Note
You can download this example as a Jupyter notebook or try it out directly in Google Colab.
9. Explainable Reinforcement Learning Tutorial
Welcome to this tutorial on Explainable Reinforcement Learning (XRL)! In this guide, we will explore how to interpret and explain the decisions made by reinforcement learning agents using the SHAP (SHapley Additive exPlanations) library. Using a multi-agent electricity market simulation in which a learning power plant bids its capacity, we’ll demonstrate how to compute and visualize feature attributions for the agent’s actions.
Table of Contents
1. Introduction
1.1 Running a MADRL Simulation
2. Explainable AI and SHAP Values
2.1 Understanding Explainable AI
2.2 Introduction to SHAP Values
3. Calculating SHAP Values
3.1 Loading and Preparing Data
3.2 Creating a SHAP Explainer
4. Visualizing SHAP Values
5. Conclusion
6. Additional Resources
1. Introduction
Reinforcement Learning (RL) has achieved remarkable success in various domains, such as game playing, robotics, and autonomous systems. However, RL models, particularly those using deep neural networks, are often seen as black boxes due to their complex architectures and non-linear computations. This opacity makes it challenging to understand and trust the decisions made by RL agents, especially in critical applications where transparency is essential.
Explainable Reinforcement Learning (XRL) aims to bridge this gap by providing insights into an agent’s decision-making process. By leveraging explainability techniques, we can:
- Interpret the actions of an RL agent.
- Understand the influence of input features on decisions.
- Potentially improve the model’s performance, fairness, and transparency.
In this tutorial, we will demonstrate how to apply SHAP values to a trained actor neural network in an RL framework to explain the agent’s actions.
1.1 Running a MADRL Simulation
In this tutorial, we will simulate RL agents using a Multi-Agent Deep Reinforcement Learning (MADRL) approach. The agents operate in a market-splitting environment where they interact and learn optimal strategies over time. Here’s a breakdown of the key components:
Observations: Each agent receives observations, including market forecasts, unit-specific information, and past actions.
Actions: The agents decide on bidding strategies, such as bid prices for both inflexible and flexible capacities.
Rewards: The agents are rewarded based on profits and opportunity costs, helping them learn optimal bidding strategies.
Algorithm: We use a multi-agent version of the TD3 (Twin Delayed Deep Deterministic Policy Gradient) algorithm, which aims to keep learning stable even in non-stationary environments, such as a market in which several agents learn at the same time.
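As a rough orientation, the sketch below shows in plain Python the three ingredients that distinguish TD3 from DDPG: a clipped double-Q target, target policy smoothing, and delayed actor updates. It is a generic illustration on scalar values, not ASSUME’s MATD3 implementation; the function names and the illustrative defaults (`noise_clip=0.5`, `policy_delay=2`) are our own assumptions. In multi-agent variants such as MATD3, the critics typically also see the observations and actions of the other agents.

```python
def td3_critic_target(reward, next_q1, next_q2, gamma=0.99, done=False):
    """Clipped double-Q target: r + gamma * min(Q1'(s', a'), Q2'(s', a'))."""
    if done:
        # No bootstrapping from a terminal state
        return reward
    # Taking the smaller of the two target critics counters overestimation
    return reward + gamma * min(next_q1, next_q2)


def smoothed_target_action(target_action, noise, noise_clip=0.5, low=-1.0, high=1.0):
    """Target policy smoothing: add clipped noise to the target actor's action."""
    clipped_noise = max(-noise_clip, min(noise_clip, noise))
    return max(low, min(high, target_action + clipped_noise))


def should_update_actor(critic_update_step, policy_delay=2):
    """Delayed policy updates: train the actor only every `policy_delay` critic updates."""
    return critic_update_step % policy_delay == 0
```

For example, `td3_critic_target(1.0, 10.0, 8.0, gamma=0.5)` bootstraps from the smaller estimate, 8.0, and yields 5.0.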
For a more detailed explanation of the RL configurations, refer to the Deep Reinforcement Learning Tutorial.
Key Aspects of the Simulation
Agents require observations to make informed decisions, which include:
Residual Load Forecast: Forecasted net demand (electricity demand minus renewable generation) over the next 24 hours.
Price Forecast: Forecasted market prices over the next 24 hours.
Marginal Cost: The current marginal cost of operating the agent’s power-generating unit.
Previous Output: The agent’s dispatched capacity (energy production) from the previous time step.
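To illustrate how these inputs form a single observation, here is a hypothetical helper that computes the residual load and concatenates the forecasts with the two unit-specific values into a 50-entry vector (24 + 24 + 2). The order follows the feature names used later in this tutorial (price forecasts, residual load forecasts, previous output, marginal cost); the helper names are our own, and ASSUME’s learning strategy may order or normalize these values differently, so no scaling is applied here.

```python
def residual_load(demand_mw, renewable_mw):
    """Residual load per hour: demand minus renewable generation."""
    return [d - r for d, r in zip(demand_mw, renewable_mw)]


def build_observation(price_forecast, residual_load_forecast, previous_output_mw, marginal_cost):
    """Concatenate the agent's inputs into one flat observation vector (unscaled)."""
    if len(price_forecast) != 24 or len(residual_load_forecast) != 24:
        raise ValueError("expected 24-hour forecasts")
    return [*price_forecast, *residual_load_forecast, previous_output_mw, marginal_cost]
```

For example, a flat demand of 100 MW with 30 MW of renewable generation gives a residual load of 70 MW in every hour, and the resulting observation has 50 entries.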
Agent Actions
The action space for the agents is two-dimensional and consists of:
Bid Price for Inflexible Capacity (p_inflex): The price at which the agent offers its minimum power output (must-run capacity) to the market.
Bid Price for Flexible Capacity (p_flex): The price for the additional capacity above the minimum output that the agent can flexibly adjust.
1.1.1 Install Assume and Required Packages
In this section, we will install the necessary packages to run the Assume framework along with other dependencies. The process is the same as in the other Assume tutorials.
The following commands will install Assume and its dependencies for reinforcement learning, along with additional libraries such as Plotly for visualization. Make sure to install these before running the main code.
[ ]:
!pip install 'assume-framework[learning]'
!pip install plotly
!pip install nbconvert
!git clone --depth=1 https://github.com/assume-framework/assume.git assume-repo
Define paths to differentiate between Colab and local usage. On Google Colab, the inputs folder is located at a different relative path than in your local environment, so the path is selected based on where the code is executed. Adjust the paths if your setup differs.
[ ]:
import importlib.util
import pandas as pd
# import plotly for visualization
import plotly.graph_objects as go
# import yaml for reading and writing YAML files
import yaml
# Check if 'google.colab' is available
IN_COLAB = importlib.util.find_spec("google.colab") is not None
colab_inputs_path = "inputs"
local_inputs_path = "../inputs"
inputs_path = colab_inputs_path if IN_COLAB else local_inputs_path
print(inputs_path)
1.1.2 Create and Load Example Files from Market Splitting Tutorial
To define the RL agent, we need the input files generated by the Market Zone Splitting tutorial: the power plant and demand data in tutorial_08, on which the simulation in this tutorial is built.
If you are working in Google Colab, execute the following cells to download and run the necessary notebook automatically. If you are working on your local machine, simply open the respective tutorial notebook and execute it manually.
[ ]:
# This cell is meant for Google Colab. For local execution, open
# 08_market_zone_coupling.ipynb in assume/examples/notebooks/ and run it manually instead.
# Change into the notebooks folder of the cloned repository:
%cd assume-repo/examples/notebooks/
# Execute the Market Zone Splitting tutorial:
!jupyter nbconvert --to notebook --execute --ExecutePreprocessor.timeout=60 --output output.ipynb 08_market_zone_coupling.ipynb
# Return to the content folder:
%cd /content
# Copy the generated inputs directory to the working folder:
!cp -r assume-repo/examples/notebooks/inputs .
[ ]:
import os
# Define the input directory
input_dir = os.path.join(inputs_path, "tutorial_08")
# Read the DataFrames from CSV files
powerplant_units = pd.read_csv(os.path.join(input_dir, "powerplant_units.csv"))
demand_df = pd.read_csv(os.path.join(input_dir, "demand_df.csv"))
print(f"Input CSV files have been read from '{input_dir}'.")
1.1.3 Transform the Scenario into a Learning Example
The following cells show how we can convert any pre-configured scenario in Assume into a learning example.
Define a Learning Power Plant
In this example, we place a learning nuclear power plant in the southern zone. This plant has five times the maximum power of a typical plant, which allows us to create a scenario where its actions have a noticeable impact on market prices.
[ ]:
# Create scarcity in southern Germany by limiting the number of power plants
# (.copy() avoids pandas' SettingWithCopyWarning when modifying the slice)
powerplant_units = powerplant_units.iloc[:20].copy()
# Assign the RL-controlled power plant and give it market power
powerplant_units.loc[19, "bidding_zonal"] = "pp_learning"
powerplant_units.loc[19, "max_power"] = 5000  # Set maximum power to 5000 MW
# Assign a specific RL unit operator to the plant
powerplant_units.loc[19, "unit_operator"] = "Operator-RL"
# Set the 'name' column as the index
powerplant_units.set_index("name", inplace=True, drop=True)
# Save the updated power plant units to a CSV file
powerplant_units.to_csv(os.path.join(input_dir, "powerplant_units.csv"))
# Show the last 10 entries
powerplant_units.tail(10)
Configure Learning Hyperparameters in YAML
The following YAML configuration contains the learning-specific hyperparameters that will guide the RL agent’s training process. Below is a brief description of these hyperparameters:
continue_learning (False): Whether to continue training from a previously saved state or start fresh.
max_bid_price (100): The maximum allowable bid price for the agent, used to scale the actor’s output.
algorithm ("matd3"): The learning algorithm to be used, in this case MATD3 (Multi-Agent Twin Delayed Deep Deterministic Policy Gradient).
learning_rate (0.001): The rate at which the model’s parameters are updated during training.
training_episodes (15): The total number of episodes for training the agent.
episodes_collecting_initial_experience (3): Number of episodes dedicated to collecting initial experience before actual training begins, during which the agent follows a random policy.
train_freq ("4h"): Frequency of model training, in this case every 4 hours.
gradient_steps (-1): The number of gradient updates to perform at each training step. A value of -1 typically means that all collected experience will be used for training.
batch_size (256): The size of the mini-batch used for training.
gamma (0.99): The discount factor for future rewards, balancing short-term vs. long-term reward importance.
device ("cpu"): The computational device for training. In this case, the CPU is used.
noise_sigma (0.1): The standard deviation of the exploration noise added to actions.
noise_scale (1) and noise_dt (1): Parameters controlling the scale and time step of the exploration noise. Since both are set to 1, no decay is applied.
validation_episodes_interval (3): The interval (in episodes) at which validation is performed during training.
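The sketch below illustrates how noise_sigma, noise_scale, noise_dt, and max_bid_price could interact when turning an actor output into bid prices. It assumes that the actor outputs values in (-1, 1), that the noise scale is multiplied by noise_dt after every draw (consistent with the statement above that values of 1 mean no decay), and that the noisy action is clipped and multiplied by max_bid_price. The class and function names are hypothetical; this is not ASSUME’s code.

```python
import random


class DecayingGaussianNoise:
    """Hypothetical Gaussian exploration noise whose scale is multiplied by `dt` after every draw."""

    def __init__(self, action_dim, sigma=0.1, scale=1.0, dt=1.0, seed=None):
        self.action_dim = action_dim
        self.sigma = sigma  # corresponds to noise_sigma
        self.scale = scale  # corresponds to noise_scale
        self.dt = dt  # corresponds to noise_dt
        self.rng = random.Random(seed)

    def sample(self):
        noise = [self.scale * self.rng.gauss(0.0, self.sigma) for _ in range(self.action_dim)]
        self.scale *= self.dt  # dt == 1 keeps the scale constant, i.e. no decay
        return noise


def to_bid_prices(actor_output, noise, max_bid_price=100.0):
    """Add noise, clip each action to [-1, 1], and scale it to a bid price in EUR/MWh."""
    return [max(-1.0, min(1.0, a + n)) * max_bid_price for a, n in zip(actor_output, noise)]


# Usage: two actions (p_inflex, p_flex) with the hyperparameters from the config
noise = DecayingGaussianNoise(action_dim=2, sigma=0.1, scale=1.0, dt=1.0, seed=0)
bids = to_bid_prices([0.3, 0.9], noise.sample(), max_bid_price=100.0)
```

With noise_dt below 1, for example 0.5, the scale would halve after every draw, gradually shifting the agent from exploration to exploitation.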
[ ]:
# YAML configuration for the RL training
config = {
    "zonal_case": {
        "start_date": "2019-01-01 00:00",
        "end_date": "2019-01-01 23:00",
        "time_step": "1h",
        "save_frequency_hours": 4,
        # A boolean: the strings "True" and "False" would both be truthy
        "learning_mode": True,
        "markets_config": {
            "zonal": {
                "operator": "EOM_operator",
                "product_type": "energy",
                "products": [{"duration": "1h", "count": 1, "first_delivery": "1h"}],
                "opening_frequency": "1h",
                "opening_duration": "1h",
                "volume_unit": "MWh",
                "maximum_bid_volume": 100000,
                "maximum_bid_price": 3000,
                "minimum_bid_price": -500,
                "price_unit": "EUR/MWh",
                "market_mechanism": "pay_as_clear_complex",
                "additional_fields": ["bid_type", "node"],
                "param_dict": {"network_path": ".", "zones_identifier": "zone_id"},
            }
        },
        "learning_config": {
            "continue_learning": False,
            "max_bid_price": 100,
            "algorithm": "matd3",
            "learning_rate": 0.001,
            "training_episodes": 15,
            "episodes_collecting_initial_experience": 3,
            "train_freq": "4h",
            "gradient_steps": -1,
            "batch_size": 256,
            "gamma": 0.99,
            "device": "cpu",
            "noise_sigma": 0.1,
            "noise_scale": 1,
            "noise_dt": 1,
            "validation_episodes_interval": 3,
        },
    }
}
# Define the path for the configuration file
config_path = os.path.join(input_dir, "config.yaml")
# Save the configuration to a YAML file
with open(config_path, "w") as file:
    yaml.dump(config, file, sort_keys=False)
print(f"Configuration YAML file has been saved to '{config_path}'.")
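To make this setup compatible with XRL, we need to extend the logging of the learning process: the SHAP analysis requires the observations the agent encountered during training. ASSUME does not export these natively, so for the purpose of this tutorial we override the run_learning function so that it saves the observations stored in the replay buffer to a buffer_obs.json file.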
[ ]:
# @title Overwrite run_learning function with enhanced logging
import json
import logging
import os
from collections import defaultdict
from pathlib import Path
import numpy as np
import yaml
from tqdm import tqdm
from assume.common.exceptions import AssumeException
from assume.scenario.loader_csv import (
    load_config_and_create_forecaster,
    setup_world,
)
from assume.world import World
logger = logging.getLogger(__name__)
def run_learning(
    world: World,
    inputs_path: str,
    scenario: str,
    study_case: str,
    verbose: bool = False,
) -> None:
    """
    Train Deep Reinforcement Learning (DRL) agents to act in a simulated market environment.
    This function runs multiple episodes of simulation to train DRL agents, performs evaluation, and saves the best runs. It maintains the buffer and learned agents in memory to avoid resetting them with each new run.
    In addition, it exports the observations stored in the replay buffer to buffer_obs/buffer_obs.json for the SHAP analysis.
    Args:
        world (World): An instance of the World class representing the simulation environment.
        inputs_path (str): The path to the folder containing input files necessary for the simulation.
        scenario (str): The name of the scenario for the simulation.
        study_case (str): The specific study case for the simulation.
        verbose (bool): If False, only warnings and errors are logged.
    Note:
        - The function uses a ReplayBuffer to store experiences for training the DRL agents.
        - It iterates through training episodes, updating the agents and evaluating their performance at regular intervals.
        - Initial exploration is active at the beginning and is disabled after a certain number of episodes to improve the performance of DRL algorithms.
        - Upon completion of training, the function performs an evaluation run using the best policy learned during training.
        - The best policies are chosen based on the average reward obtained during the evaluation runs, and they are saved for future use.
    """
    from assume.reinforcement_learning.buffer import ReplayBuffer
    if not verbose:
        logger.setLevel(logging.WARNING)
    # remove csv path so that nothing is written while learning
    temp_csv_path = world.export_csv_path
    world.export_csv_path = ""
    # initialize policies already here to set the obs_dim and act_dim in the learning role
    actors_and_critics = None
    world.learning_role.initialize_policy(actors_and_critics=actors_and_critics)
    world.output_role.del_similar_runs()
    # check if we already stored policies for this simulation
    save_path = world.learning_config["trained_policies_save_path"]
    if Path(save_path).is_dir():
        # we are in learning mode and about to train new policies, which might overwrite existing ones
        accept = input(
            f"{save_path=} exists - should we overwrite current learnings? (y/N) "
        )
        if not accept.lower().startswith("y"):
            # stop here - do not start learning or save anything
            raise AssumeException("don't overwrite existing strategies")
    # -----------------------------------------
    # Load scenario data to reuse across episodes
    scenario_data = load_config_and_create_forecaster(inputs_path, scenario, study_case)
    # -----------------------------------------
    # Information that needs to be stored across episodes, aka one simulation run
    inter_episodic_data = {
        "buffer": ReplayBuffer(
            buffer_size=int(world.learning_config.get("replay_buffer_size", 5e5)),
            obs_dim=world.learning_role.rl_algorithm.obs_dim,
            act_dim=world.learning_role.rl_algorithm.act_dim,
            n_rl_units=len(world.learning_role.rl_strats),
            device=world.learning_role.device,
            float_type=world.learning_role.float_type,
        ),
        "actors_and_critics": None,
        "max_eval": defaultdict(lambda: -1e9),
        "all_eval": defaultdict(list),
        "avg_all_eval": [],
        "episodes_done": 0,
        "eval_episodes_done": 0,
        "noise_scale": world.learning_config.get("noise_scale", 1.0),
    }
    # -----------------------------------------
    validation_interval = min(
        world.learning_role.training_episodes,
        world.learning_config.get("validation_episodes_interval", 5),
    )
    eval_episode = 1
    for episode in tqdm(
        range(1, world.learning_role.training_episodes + 1),
        desc="Training Episodes",
    ):
        # TODO normally, loading twice should not create issues, somehow a scheduling issue is raised currently
        if episode != 1:
            setup_world(
                world=world,
                scenario_data=scenario_data,
                study_case=study_case,
                episode=episode,
            )
        # -----------------------------------------
        # Give the newly initialized learning role the needed information across episodes
        world.learning_role.load_inter_episodic_data(inter_episodic_data)
        world.run()
        # -----------------------------------------
        # Store updated information across episodes
        inter_episodic_data = world.learning_role.get_inter_episodic_data()
        inter_episodic_data["episodes_done"] = episode
        # evaluation run:
        if (
            episode % validation_interval == 0
            and episode
            >= world.learning_role.episodes_collecting_initial_experience
            + validation_interval
        ):
            world.reset()
            # load evaluation run
            setup_world(
                world=world,
                scenario_data=scenario_data,
                study_case=study_case,
                perform_evaluation=True,
                eval_episode=eval_episode,
            )
            world.learning_role.load_inter_episodic_data(inter_episodic_data)
            world.run()
            total_rewards = world.output_role.get_sum_reward()
            avg_reward = np.mean(total_rewards)
            # check reward improvement in evaluation run
            # and store best run in eval folder
            terminate = world.learning_role.compare_and_save_policies(
                {"avg_reward": avg_reward}
            )
            inter_episodic_data["eval_episodes_done"] = eval_episode
            # if we have not improved in the last x evaluations, we stop loop
            if terminate:
                break
            eval_episode += 1
        world.reset()
        # if at end of simulation save last policies
        if episode == (world.learning_role.training_episodes):
            world.learning_role.rl_algorithm.save_params(
                directory=f"{world.learning_role.trained_policies_save_path}/last_policies"
            )
        # container shutdown implicitly with new initialisation
    # Export the replay buffer observations to buffer_obs.json for the SHAP analysis.
    # This happens after the loop so the file is also written if training stops early.
    observations = np.asarray(inter_episodic_data["buffer"].observations)
    # Drop all-zero rows to keep the file small (load_observations applies the same filter)
    filled = ~np.all(observations.reshape(len(observations), -1) == 0, axis=1)
    path = f"{world.learning_role.trained_policies_save_path}/buffer_obs"
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "buffer_obs.json"), "w") as f:
        json.dump(observations[filled].tolist(), f)
    logger.info("################")
    logger.info("Training finished, Start evaluation run")
    world.export_csv_path = temp_csv_path
    world.reset()
    # load scenario for evaluation
    setup_world(
        world=world,
        scenario_data=scenario_data,
        study_case=study_case,
        terminate_learning=True,
    )
    world.learning_role.load_inter_episodic_data(inter_episodic_data)
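To see when run_learning triggers an evaluation run, the following hypothetical helper re-evaluates the loop’s condition (episode divisible by the validation interval and past the initial-experience phase plus one interval) for every episode. It ignores early termination, which can end training sooner.

```python
def evaluation_episodes(training_episodes, validation_interval, initial_experience_episodes):
    """Episodes after which run_learning performs an evaluation run (ignoring early stopping)."""
    # run_learning caps the interval at the number of training episodes
    interval = min(training_episodes, validation_interval)
    return [
        episode
        for episode in range(1, training_episodes + 1)
        if episode % interval == 0
        and episode >= initial_experience_episodes + interval
    ]
```

With the configuration above (15 training episodes, an interval of 3, and 3 initial-experience episodes), evaluations happen after episodes 6, 9, 12, and 15.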
Run the Example Case
Now we run the example case as done previously in the market zone tutorial. The main difference here is that we call the run_learning() function, which iterates multiple times over the simulation horizon for reinforcement learning.
[ ]:
# Import necessary classes and functions from the Assume framework
from assume import World
from assume.scenario.loader_csv import load_scenario_folder
# Define paths for input and output data
csv_path = "outputs"
# Define the data format and database URI for storing results
# Use "local_db" for SQLite or "timescale" for TimescaleDB
os.makedirs(csv_path, exist_ok=True)
os.makedirs("local_db", exist_ok=True)
data_format = "local_db" # Options: "local_db" (SQLite) or "timescale" (TimescaleDB)
# Set the database URI based on the selected data format
if data_format == "local_db":
    db_uri = "sqlite:///local_db/assume_db.db"  # SQLite database
elif data_format == "timescale":
    db_uri = "postgresql://assume:assume@localhost:5432/assume"  # TimescaleDB
# Create the World instance with the specified database
world = World(database_uri=db_uri, export_csv_path=csv_path)
# Load the scenario configuration
# - world: World instance
# - inputs_path: Folder containing input data
# - scenario: Scenario subfolder in inputs
# - study_case: Which configuration (case) to use for the simulation
load_scenario_folder(
    world,
    inputs_path=inputs_path,
    scenario="tutorial_08",
    study_case="zonal_case",
)
# If learning mode is enabled, run the reinforcement learning loop
if world.learning_config.get("learning_mode", False):
    run_learning(
        world,
        inputs_path=inputs_path,
        scenario="tutorial_08",
        study_case="zonal_case",
    )
# Run the simulation
world.run()
Compare the Results
Next, we use the same code from the market zone tutorial to generate a Plotly graph displaying market clearing prices over time for each zone.
[ ]:
# Import Plotly for creating interactive visualizations
import plotly.graph_objects as go
# Define the path to the simulation output directory
output_dir = "outputs/tutorial_08_zonal_case"
market_meta_path = os.path.join(output_dir, "market_meta.csv")
# Load the market metadata from the CSV file
market_meta = pd.read_csv(market_meta_path, index_col="time", parse_dates=True)
market_meta = market_meta.drop(
columns=market_meta.columns[0]
) # Drop the first unnamed column
# Extract unique zones from the "node" column
zones = market_meta["node"].unique()
# Initialize an empty DataFrame to store clearing prices for each zone
clearing_prices_df = pd.DataFrame()
# Populate the DataFrame with clearing prices for each zone
for zone in zones:
    zone_data = market_meta[market_meta["node"] == zone][["price"]]
    zone_data = zone_data.rename(columns={"price": f"{zone}_price"})
    clearing_prices_df = (
        pd.merge(
            clearing_prices_df,
            zone_data,
            left_index=True,
            right_index=True,
            how="outer",
        )
        if not clearing_prices_df.empty
        else zone_data
    )
# Sort the DataFrame by time
clearing_prices_df = clearing_prices_df.sort_index()
# Initialize the Plotly figure
fig = go.Figure()
# Plot clearing prices for each zone
for zone in zones:
    fig.add_trace(
        go.Scatter(
            x=clearing_prices_df.index,
            y=clearing_prices_df[f"{zone}_price"],
            mode="lines",
            name=f"{zone} - Simulation",
            line=dict(width=2),
        )
    )
# Customize the layout for better aesthetics and interaction
fig.update_layout(
title="Clearing Prices per Zone Over Time: Simulation Results",
xaxis_title="Time",
yaxis_title="Clearing Price (EUR/MWh)",
legend_title="Market Zones",
xaxis=dict(
tickangle=45, # Rotate x-axis labels for readability
type="date", # Ensure x-axis is treated as dates
),
hovermode="x unified", # Unified hover to compare values across zones at the same time
template="plotly_white", # Use a clean white background
width=1000,
height=600,
)
# Display the interactive plot
fig.show()
2. Explainable AI and SHAP Values
Prerequisites
To follow along with this tutorial, we need some additional libraries.
matplotlib
shap
scikit-learn
[ ]:
!pip install matplotlib
!pip install shap==0.42.1
!pip install scikit-learn==1.3.0
2.1 Understanding Explainable AI
Explainable AI (XAI) refers to techniques and methods that make the behavior and decisions of AI systems understandable to humans. In the context of complex models like deep neural networks, XAI helps to:
- Increase Transparency: Providing insights into how models make decisions.
- Build Trust: Users and stakeholders can trust AI systems if they understand them.
- Ensure Compliance: Regulatory requirements often demand explainability.
- Improve Models: Identifying weaknesses or biases in models.
2.2 Introduction to SHAP Values
Shapley values are a method from cooperative game theory used to explain the contribution of each feature to the prediction of a machine learning model, such as a neural network. They provide an interpretability technique by distributing the “payout” (the prediction) among the input features, attributing the importance of each feature to the prediction.
For a given prediction, the Shapley value of a feature represents the average contribution of that feature to the prediction, considering all possible combinations of other features.
Marginal Contribution: Given a subset of the other features, the marginal contribution of a feature is the difference between the prediction with and without that feature added to the subset.
Average over all subsets: The Shapley value is a weighted average of these marginal contributions over all possible subsets of the other features.
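The formula for the Shapley value \(\phi_i\) of feature \(i\) is:
\[
\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(|N| - |S| - 1)!}{|N|!}\,\bigl[f(S \cup \{i\}) - f(S)\bigr]
\]
Where:
- \(N\) is the set of all features, and \(|N|\) is their number.
- \(S\) is a subset of features that does not contain \(i\).
- \(f(S)\) is the model’s prediction when using only the features in subset \(S\).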
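As a worked example of the formula, take two features, \(N = \{1, 2\}\). For feature 1 the subsets are \(S = \emptyset\) and \(S = \{2\}\), each with weight \(\tfrac{1}{2}\):
\[
\phi_1 = \frac{0!\,1!}{2!}\bigl[f(\{1\}) - f(\emptyset)\bigr] + \frac{1!\,0!}{2!}\bigl[f(\{1,2\}) - f(\{2\})\bigr]
\]
With the hypothetical values \(f(\emptyset) = 0\), \(f(\{1\}) = 3\), \(f(\{2\}) = 1\), and \(f(\{1,2\}) = 6\), this gives \(\phi_1 = \tfrac{1}{2}\cdot 3 + \tfrac{1}{2}\cdot 5 = 4\) and, analogously, \(\phi_2 = \tfrac{1}{2}\cdot 1 + \tfrac{1}{2}\cdot 3 = 2\). The two values sum to \(f(\{1,2\}) - f(\emptyset) = 6\).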
The shap library is a popular tool for computing Shapley values for machine learning models, including neural networks.
Why use SHAP in RL?
- Model-Agnostic: Applicable to any machine learning model, including neural networks.
- Local Explanations: Provides explanations for individual predictions (actions).
- Consistency: Ensures that features contributing more to the prediction have higher Shapley values.
Properties of SHAP:
1. Local Accuracy: The sum of Shapley values equals the difference between the model output and the expected output.
2. Missingness: Features not present in the model have zero Shapley value.
3. Consistency: If a model changes so that a feature contributes more to the prediction, the Shapley value of that feature should not decrease.
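To make the formula and the properties concrete, the following brute-force sketch computes exact Shapley values by enumerating every coalition for a hypothetical three-feature toy model in which feature 2 is never used and features 0 and 1 interact. The function names and the toy model are our own illustration. Enumeration needs \(2^{|N|-1}\) coalitions per feature, so it is only feasible for a handful of features; for our 50 observation features, SHAP’s explainers approximate these values instead.

```python
from itertools import combinations
from math import factorial


def exact_shapley(value_fn, n_features):
    """Exact Shapley values by enumerating every coalition (2**(n-1) subsets per feature)."""
    phi = [0.0] * n_features
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for size in range(len(others) + 1):
            # Shapley weight |S|! (|N| - |S| - 1)! / |N|!
            weight = factorial(size) * factorial(n_features - size - 1) / factorial(n_features)
            for subset in combinations(others, size):
                coalition = frozenset(subset)
                # Marginal contribution of feature i to this coalition
                phi[i] += weight * (value_fn(coalition | {i}) - value_fn(coalition))
    return phi


def toy_model(coalition):
    """Hypothetical model: base value 1, feature 0 adds 2, features 0 and 1 together add 4 more."""
    value = 1.0
    if 0 in coalition:
        value += 2.0
    if 0 in coalition and 1 in coalition:
        value += 4.0
    return value


phi = exact_shapley(toy_model, 3)
```

For this toy model the sketch returns approximately (4, 2, 0): the values sum to \(f(N) - f(\emptyset) = 6\) (local accuracy), and the unused feature 2 receives zero (missingness).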
3. Calculating SHAP Values
We will work with:
Observations (input_data): These are the inputs to our actor neural network, representing the state of the environment.
Trained Actor Model: A neural network representing the decision-making of one RL power plant that outputs actions based on the observations.
Our goal is to:
Load the observations and the trained actor model.
Use the model to predict actions.
Apply SHAP to explain the model’s predictions.
3.1 Loading and Preparing Data
First, let’s load the necessary libraries and the data.
[ ]:
import matplotlib.pyplot as plt
import pandas as pd
import shap
import torch as th
from sklearn.model_selection import train_test_split
We define a utility function to load the logged observations from a specified path. The function removes rows in which all values are zero and, because analyzing the SHAP values for all observations would make this notebook slow and lengthy, keeps only the first 300 observations.
[ ]:
# @title Load observations function
def load_observations(path, feature_names):
    """Load logged observations and return them as a DataFrame and a NumPy array."""
    # Load observations
    obs_path = f"{path}/buffer_obs.json"
    print(obs_path)
    with open(obs_path) as file:
        json_data = json.load(file)
    # Convert the list of lists into a 2D numpy array
    input_data = np.array(json_data)
    input_data = np.squeeze(input_data)
    print("Observations loaded:", len(input_data))
    # Filter out rows in which all values are 0
    input_data = input_data[~np.all(input_data == 0, axis=1)]
    print("Non-zero observations:", len(input_data))
    # Keep only the first 300 observations to limit the SHAP computation time
    input_data = input_data[:300]
    return pd.DataFrame(input_data, columns=feature_names), input_data
Load Observations and Input Data
Load the observations and input data using the utility function.
[ ]:
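# Path to the observations logged by the overridden run_learning function
path = os.path.join(input_dir, "learned_strategies/zonal_case/buffer_obs")
# Feature names for the 50 observation entries; their order must match
# the order in which the RL strategy assembles its observation vector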
names_1 = ["price forecast t+" + str(x) for x in range(1, 25)]
names_2 = ["residual load forecast t+" + str(x) for x in range(1, 25)]
feature_names = names_1 + names_2 + ["total capacity t-1"] + ["marginal costs t-1"]
df_obs, input_data = load_observations(path, feature_names)
df_obs
Load the Trained Actor Model
We initialize and load the trained actor neural network. To do so, we instantiate ASSUME’s MLPActor class with the observation and action dimensions used during training and then load the saved weights.
[ ]:
from assume.reinforcement_learning.neural_network_architecture import MLPActor
# Initialize the model
obs_dim = len(feature_names)
act_dim = 2 # Adjust if your model outputs a different number of actions
model = MLPActor(obs_dim=obs_dim, act_dim=act_dim, float_type=th.float)
[ ]:
# which actor is the RL actor
ACTOR_NUM = len(powerplant_units) # 20
# Path to actor we want to analyse
actor_path = os.path.join(
input_dir,
f"learned_strategies/zonal_case/avg_reward_eval_policies/actors/actor_Unit {ACTOR_NUM}.pt",
)
# Load the trained model parameters
model_state = th.load(actor_path, map_location=th.device("cpu"))
model.load_state_dict(model_state["actor"])
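Next, we compute the actions the trained actor chooses for each of the loaded observations.
[ ]:
model.eval()  # switch to inference mode
actions = []
with th.no_grad():  # no gradients are needed for prediction
    for obs in input_data:
        obs_tensor = th.tensor(obs, dtype=th.float)
        actions.append(model(obs_tensor))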
3.2 Creating a SHAP Explainer
In the next step, we create the SHAP explainer. In this example we use the Kernel SHAP method, but you can easily switch it out for Deep SHAP (shap.DeepExplainer), which is specialized for deep learning models. The SHAP Kernel Explainer is a model-agnostic method for computing SHAP values: it only needs a function that maps inputs to outputs, so it can be applied to any machine learning model, including black-box models like neural networks, decision trees, or ensemble models. For each prediction, it fits a weighted linear regression over different combinations (coalitions) of input features, simulating the absence of a feature by replacing its value with values from a background dataset. The coefficients of this regression are the SHAP values, which represent the marginal contribution of each feature to that particular prediction.
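The weights in this regression come from the Shapley kernel, \(\pi(z') = \frac{M-1}{\binom{M}{|z'|}\,|z'|\,(M-|z'|)}\), where \(M\) is the number of features and \(|z'|\) the number of features present in a coalition. The sketch below (our own helper, not part of shap) evaluates it. Very small and nearly complete coalitions get the largest weights, which is why Kernel SHAP can get reasonable estimates from a limited number of sampled coalitions (the nsamples argument of shap_values) instead of all \(2^{50}\) coalitions of our observation features.

```python
from math import comb, inf


def shapley_kernel_weight(n_features: int, coalition_size: int) -> float:
    """Kernel SHAP weight for a coalition with `coalition_size` of `n_features` features present."""
    if coalition_size in (0, n_features):
        # The empty and full coalitions get infinite weight; in practice they are
        # enforced as constraints (base value and local accuracy) rather than sampled.
        return inf
    return (n_features - 1) / (
        comb(n_features, coalition_size) * coalition_size * (n_features - coalition_size)
    )
```

For our 50 features, a coalition containing a single feature gets a weight of 0.02, while coalitions of about half the features get a weight many orders of magnitude smaller.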
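Kernel SHAP therefore needs two sets of observations: background data, used to simulate absent features, and the samples whose predictions we want to explain. We split the observation and action data into training and test sets, use the training observations as background data, and explain the predictions for the test observations.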
[ ]:
# @title Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
    input_data, actions, test_size=0.15, random_state=42
)
# Convert data to tensors (the actions are already tensors, so we stack them)
X_train_tensor = th.tensor(X_train, dtype=th.float32)
X_test_tensor = th.tensor(X_test, dtype=th.float32)
y_train_tensor = th.stack(y_train).float()
y_test_tensor = th.stack(y_test).float()
We define a prediction function compatible with SHAP and create a Kernel SHAP explainer.
[ ]:
# @title Define a prediction function for generating actions for the SHAP explainer
def model_predict(X):
    """Map a 2D array of observations to the actor's actions as a NumPy array."""
    X_tensor = th.tensor(X, dtype=th.float32)
    model.eval()
    with th.no_grad():
        return model(X_tensor).numpy()
[ ]:
# Create the SHAP Kernel Explainer
explainer = shap.KernelExplainer(model_predict, X_train)
[ ]:
# Calculate SHAP values for the test set
shap_values = explainer.shap_values(X_test)
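A quick sanity check is the local accuracy property: the base value plus the sum of a sample’s SHAP values should reproduce the actor’s output. With the shap version pinned above, we expect shap_values to be a list with one (n_samples, n_features) array per action and explainer.expected_value to hold one base value per action. The helper below is our own sketch, not part of shap; it returns the largest deviation, which should be close to zero.

```python
def check_local_accuracy(expected_value, shap_values, predictions):
    """Largest gap between (base value + sum of SHAP values) and the model output.

    expected_value: one base value per output
    shap_values: one (n_samples, n_features) array or nested list per output
    predictions: model outputs with shape (n_samples, n_outputs)
    """
    max_gap = 0.0
    for k, base in enumerate(expected_value):
        for row_shap, row_pred in zip(shap_values[k], predictions):
            reconstructed = base + sum(row_shap)
            max_gap = max(max_gap, abs(reconstructed - row_pred[k]))
    return max_gap
```

In the notebook, it can be called as check_local_accuracy(explainer.expected_value, shap_values, model_predict(X_test)).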
4. Visualizing SHAP Values
We generate summary plots to visualize feature importance for each output dimension.
[ ]:
# Summary plot for the first output dimension
shap.summary_plot(shap_values[0], X_test, feature_names=feature_names, show=False)
plt.title("Summary Plot for Output Dimension 0, p_inflex")
plt.show()
# Summary plot for the second output dimension
shap.summary_plot(shap_values[1], X_test, feature_names=feature_names, show=False)
plt.title("Summary Plot for Output Dimension 1, p_flex")
plt.show()
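# Bar plots of the mean absolute SHAP value per feature;
# titles are set with plt.title, as for the plots above
shap.summary_plot(
    shap_values[0],
    X_test,
    feature_names=feature_names,
    plot_type="bar",
    show=False,
)
plt.title("Summary Bar Plot for Output Dimension 0, p_inflex")
plt.show()
shap.summary_plot(
    shap_values[1],
    X_test,
    feature_names=feature_names,
    plot_type="bar",
    show=False,
)
plt.title("Summary Bar Plot for Output Dimension 1, p_flex")
plt.show()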
The SHAP summary plots show the impact of each feature on the model’s predictions for each output dimension (action). Features with larger absolute SHAP values have a more significant influence on the decision-making process of the RL agent. In the bar plots, features are ranked by their mean absolute SHAP value over the test set.
Positive SHAP Value: Indicates that the feature pushes the predicted action value (the bid price) above the base value, i.e. the average prediction over the background data.
Negative SHAP Value: Indicates that the feature pushes the predicted action value below the base value.
By analyzing these plots, we can identify which features are most influential and understand how changes in feature values affect the agent’s actions.
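If you prefer the ranking from the bar plots as numbers, the hypothetical helper below computes the mean absolute SHAP value per feature and returns the top features. It works on a nested list or a 2D NumPy array such as shap_values[0].

```python
def rank_features_by_mean_abs_shap(shap_matrix, feature_names, top_k=5):
    """Rank features by their mean absolute SHAP value across samples."""
    n_samples = len(shap_matrix)
    scores = [
        (name, sum(abs(row[j]) for row in shap_matrix) / n_samples)
        for j, name in enumerate(feature_names)
    ]
    # Most influential features first
    scores.sort(key=lambda item: item[1], reverse=True)
    return scores[:top_k]
```

For example, rank_features_by_mean_abs_shap(shap_values[0], feature_names) lists the five observation features with the largest average influence on p_inflex.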
5. Conclusion
In this tutorial, we’ve demonstrated how to apply SHAP to a reinforcement learning agent to explain its decision-making process. By interpreting the SHAP values, we gain valuable insights into which features influence the agent’s actions, enhancing transparency and trust in the model.
Explainability is crucial, especially when deploying RL agents in real-world applications where understanding the rationale behind decisions is essential for safety, fairness, and compliance.
6. Additional Resources
SHAP Documentation: https://shap.readthedocs.io/en/latest/
PyTorch Documentation: https://pytorch.org/docs/stable/index.html
Reinforcement Learning Introduction: Richard S. Sutton and Andrew G. Barto, “Reinforcement Learning: An Introduction”
Interpretable Machine Learning Book: https://christophm.github.io/interpretable-ml-book/
Feel free to experiment with the code and explore different explainability techniques. Happy learning!