Author: Nele Albers
Date: January 2025
Let's compare the optimal policy, the policy of always using human feedback, the policy of using human feedback 50% of the time, and the policy of never using human feedback over time, thus reproducing Figure 2. We also visualize the transition functions for giving and not giving human feedback (Supplementary Figure 1).
Required files:
- Data/data_rl_samples_abstracted<FEAT_SEL><NUM_VALS_PER_FEATURE>.csv
- Data/all_abstract_states_with_session.csv
- Intermediate_Results/_reward_func_<path_to_save>, _trans_func_<path_to_save>, and _qvals_<path_to_save> (pickled reward function, transition function, and Q-values)
Created files:
- Figures/policy_comp.pdf
- Figures/Network_plot_transition_function_human_support.png
- Figures/Network_plot_transition_function_nosupport.png
Authored by Nele Albers, Francisco S. Melo, Mark A. Neerincx, Olya Kudina, and Willem-Paul Brinkman.
Let's import the packages we need.
import graphviz # for network plot
from matplotlib_inline.backend_inline import set_matplotlib_formats # replaces the deprecated IPython.display import
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import seaborn as sns
And we define some variables we need throughout.
FEAT_SEL = [0, 1, 2] # Selected base state features
NUM_VALS_PER_FEATURE = [3, 2, 2] # Number of values per state feature
DISCOUNT_FACTOR = 0.85
NUM_ACTIONS = 2
PATH = "Intermediate_Results/" # Pre-fix for path for storing results
path_to_save = str(str(FEAT_SEL[0]) + str(FEAT_SEL[1]) + str(FEAT_SEL[2]) + "_" + str(DISCOUNT_FACTOR) + "_" + str(NUM_VALS_PER_FEATURE))
states = [[i, j, k] for i in range(NUM_VALS_PER_FEATURE[0]) for j in range(NUM_VALS_PER_FEATURE[1]) for k in range(NUM_VALS_PER_FEATURE[2])]
num_states = len(states)
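For later reference, this enumeration fixes the state-to-index mapping used throughout the notebook. A quick sketch (using only the states list defined above) to print that mapping:
# Each abstract state [i, j, k] gets the index at which it appears in the states list,
# e.g., [0, 0, 0] -> 0 and [2, 1, 1] -> 11.
for idx, state in enumerate(states):
    print(idx, state)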
And we set some things for the styling of our plot later.
sns.set()
sns.set_style("white")
set_matplotlib_formats('retina')
med_fontsize = 22
small_fontsize = 18
extrasmall_fontsize = 15
sns.set_context("paper", rc={"font.size":small_fontsize,"axes.titlesize":med_fontsize,"axes.labelsize":med_fontsize,
'xtick.labelsize':small_fontsize, 'ytick.labelsize':small_fontsize,
'legend.fontsize':extrasmall_fontsize,'legend.title_fontsize': extrasmall_fontsize})
Let's load the data with the transition samples.
data = pd.read_csv("Data/data_rl_samples_abstracted" + str(FEAT_SEL) + str(NUM_VALS_PER_FEATURE) + ".csv",
converters={'s0': eval, 's1': eval})
data_train = data.copy(deep=True)
We also load the data on all states. This includes the first states of people with no data from session 2.
df_all_states = pd.read_csv("Data/all_abstract_states_with_session.csv",
converters={'state': eval})
Let's load the previously computed dynamics and Q-values.
with open(PATH + "_reward_func_" + path_to_save, "rb") as f:
reward_func = pickle.load(f)
with open(PATH + "_trans_func_" + path_to_save, "rb") as f:
trans_func = pickle.load(f)
with open(PATH + "_qvals_" + path_to_save, "rb") as f:
q_vals = pickle.load(f)
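As a quick sanity check (a sketch; the shapes noted in the comments are assumptions based on how these objects are indexed later in this notebook, not guaranteed by the upstream code), we can inspect the dimensions of what we just loaded:
# Assumed indexing: trans_func[state][action][next_state], reward_func[state][action][next_state],
# and q_vals[state][action].
print("trans_func shape:", np.shape(trans_func))
print("reward_func shape:", np.shape(reward_func))
print("q_vals shape:", np.shape(q_vals))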
Here we compute the policies we want to compare, as well as the rewards and transitions when following those policies.
opt_policy = [np.argmax(q_vals[state]) for state in range(len(q_vals))]
trans_func_opt_policy = np.array([trans_func[state][opt_policy[state]] for state in range(num_states)])
reward_func_opt_policy = np.array([reward_func[state][opt_policy[state]] for state in range(num_states)])
print("Optimal policy:", [int(i) for i in opt_policy])
# Here we never provide feedback
nosupport_policy = [0 for state in range(len(q_vals))]
trans_func_nosupport_policy = np.array([trans_func[state][nosupport_policy[state]] for state in range(num_states)])
reward_func_nosupport_policy = np.array([reward_func[state][nosupport_policy[state]] for state in range(num_states)])
# Here we always provide feedback
always_policy = [1 for state in range(len(q_vals))]
trans_func_always_policy = np.array([trans_func[state][always_policy[state]] for state in range(num_states)])
reward_func_always_policy = np.array([reward_func[state][always_policy[state]] for state in range(num_states)])
# Let's compute the average of the transition functions for the 2 actions in a state.
# This is if we choose each action 50% of the time. Of course this is a theoretical construct.
trans_func_avg_policy = np.array([sum(trans_func[state])/NUM_ACTIONS for state in range(num_states)])
reward_func_avg_policy = np.array([sum(reward_func[state])/NUM_ACTIONS for state in range(num_states)])
Optimal policy: [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1]
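To make this action list easier to read, a small sketch (using only the states and opt_policy variables defined above) prints which action the optimal policy picks in each named state, where 0 means no human feedback and 1 means human feedback:
# The action encoding matches the policies above: 0 = no feedback, 1 = feedback.
for state, action in zip(states, opt_policy):
    print("State", "".join(str(v) for v in state), "-> action", int(action))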
As the starting population for our simulation, we use the distribution of people across states that we observed in the first session of our study.
all_states_s1 = df_all_states[df_all_states['session'] == 1]
all_states_s1 = all_states_s1.reset_index(drop=True)
all_states_count = np.zeros(num_states)
for p in range(len(all_states_s1)):
    state = list(np.take(all_states_s1.iloc[p]["state"], FEAT_SEL))
    state_idx = states.index(state)
    all_states_count[state_idx] += 1
all_states_frac = all_states_count/sum(all_states_count)
print("Fraction of people in each state in session 1:", np.round(all_states_frac, 2))
Fraction of people in each state in session 1: [0.34 0.17 0.06 0.05 0.07 0.05 0.02 0.04 0.03 0.07 0.02 0.07]
Now we actually simulate the four different policies over time.
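Each simulation step multiplies the population distribution by the policy's transition matrix and adds the expected one-step reward. As a compact sketch of what one iteration of the inner loop below does (a hypothetical helper, not used by the code below):
def simulate_step(pop, trans, rew):
    # Expected one-step reward: sum over s and s' of pop(s) * T(s, s') * R(s, s').
    step_reward = sum(((trans * rew).T).dot(pop))
    # New population distribution: p'(s') = sum over s of pop(s) * T(s, s').
    new_pop = (trans.T).dot(pop)
    return new_pop, step_reward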
num_steps_list = [1, 2, 3, 5, 10, 20, 30, 50, 100]
initial_pop = all_states_frac
initial_pop_size = sum(initial_pop)
rew_opt_policy_list = []
rew_nosupport_policy_list = []
rew_avg_policy_list = []
rew_always_policy_list = []
for num_steps in num_steps_list:
    print("\nNumber of time steps:", num_steps)

    pop_opt_policy = initial_pop
    pop_nosupport_policy = initial_pop
    pop_avg_policy = initial_pop
    pop_always_policy = initial_pop

    rew_opt_policy = 0
    rew_nosupport_policy = 0
    rew_avg_policy = 0
    rew_always_policy = 0

    for t in range(num_steps):
        trans_time_rew = trans_func_opt_policy * reward_func_opt_policy # element-wise multiplication
        rew_opt_policy += sum((trans_time_rew.T).dot(pop_opt_policy)) # total reward for transitions
        pop_opt_policy = (trans_func_opt_policy.T).dot(pop_opt_policy) # new population

        trans_time_rew = trans_func_nosupport_policy * reward_func_nosupport_policy # element-wise multiplication
        rew_nosupport_policy += sum((trans_time_rew.T).dot(pop_nosupport_policy)) # total reward for transitions
        pop_nosupport_policy = (trans_func_nosupport_policy.T).dot(pop_nosupport_policy) # new population

        trans_time_rew = trans_func_avg_policy * reward_func_avg_policy # element-wise multiplication
        rew_avg_policy += sum((trans_time_rew.T).dot(pop_avg_policy)) # total reward for transitions
        pop_avg_policy = (trans_func_avg_policy.T).dot(pop_avg_policy) # new population

        trans_time_rew = trans_func_always_policy * reward_func_always_policy # element-wise multiplication
        rew_always_policy += sum((trans_time_rew.T).dot(pop_always_policy)) # total reward for transitions
        pop_always_policy = (trans_func_always_policy.T).dot(pop_always_policy) # new population

    print("Opt policy:", round(rew_opt_policy/(num_steps * initial_pop_size), 2))
    rew_opt_policy_list.append(rew_opt_policy/(num_steps * initial_pop_size))
    print("No support policy:", round(rew_nosupport_policy/(num_steps * initial_pop_size), 2))
    rew_nosupport_policy_list.append(rew_nosupport_policy/(num_steps * initial_pop_size))
    print("Avg policy:", round(rew_avg_policy/(num_steps * initial_pop_size), 2))
    rew_avg_policy_list.append(rew_avg_policy/(num_steps * initial_pop_size))
    print("Always support policy:", round(rew_always_policy/(num_steps * initial_pop_size), 2))
    rew_always_policy_list.append(rew_always_policy/(num_steps * initial_pop_size))
Number of time steps: 1
Opt policy: 0.55
No support policy: 0.53
Avg policy: 0.54
Always support policy: 0.55

Number of time steps: 2
Opt policy: 0.56
No support policy: 0.53
Avg policy: 0.55
Always support policy: 0.57

Number of time steps: 3
Opt policy: 0.57
No support policy: 0.53
Avg policy: 0.55
Always support policy: 0.57

Number of time steps: 5
Opt policy: 0.59
No support policy: 0.54
Avg policy: 0.56
Always support policy: 0.59

Number of time steps: 10
Opt policy: 0.61
No support policy: 0.55
Avg policy: 0.58
Always support policy: 0.61

Number of time steps: 20
Opt policy: 0.63
No support policy: 0.56
Avg policy: 0.59
Always support policy: 0.63

Number of time steps: 30
Opt policy: 0.64
No support policy: 0.56
Avg policy: 0.6
Always support policy: 0.64

Number of time steps: 50
Opt policy: 0.65
No support policy: 0.57
Avg policy: 0.61
Always support policy: 0.64

Number of time steps: 100
Opt policy: 0.66
No support policy: 0.57
Avg policy: 0.61
Always support policy: 0.64
Let's compute the effort corresponding to the average rewards after 100 time steps. A reward of 0.5 corresponds to the mean effort, and a reward of 1 corresponds to an effort of 10.
effort_mean = data_train["effort"].mean()
slope = (10 - effort_mean)/(1 - 0.5)
intercept = effort_mean - 0.5 * slope
print("Intercept:", round(intercept, 2), "Slope:", round(slope, 2))
def get_effort(reward):
    return intercept + reward * slope
effort_opt_policy_100_steps = get_effort(rew_opt_policy_list[-1])
effort_always_policy_100_steps = get_effort(rew_always_policy_list[-1])
effort_avg_policy_100_steps = get_effort(rew_avg_policy_list[-1])
effort_nosupport_policy_100_steps = get_effort(rew_nosupport_policy_list[-1])
print("Effort opt policy after 100 steps:", round(effort_opt_policy_100_steps, 2))
print("Effort always policy after 100 steps:", round(effort_always_policy_100_steps, 2))
print("Effort avg policy after 100 steps:", round(effort_avg_policy_100_steps, 2))
print("Effort no support policy after 100 steps:", round(effort_nosupport_policy_100_steps, 2))
print("Always - no support:", round(effort_always_policy_100_steps - effort_nosupport_policy_100_steps, 2))
Intercept: 1.47 Slope: 8.53
Effort opt policy after 100 steps: 7.08
Effort always policy after 100 steps: 6.97
Effort avg policy after 100 steps: 6.68
Effort no support policy after 100 steps: 6.32
Always - no support: 0.65
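As a quick check of this linear mapping (a sketch using only names defined above), the two anchor points should be recovered exactly: a reward of 0.5 maps back to the mean effort, and a reward of 1 maps to an effort of 10.
# Both identities follow directly from how slope and intercept are defined above.
print(round(get_effort(0.5), 2))  # equals round(effort_mean, 2)
print(round(get_effort(1), 2))    # equals 10.0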
And we compute how large the corresponding increase is relative to the mean effort, in percent.
percent_increase_opt_policy = effort_opt_policy_100_steps * 100 / effort_mean - 100
percent_increase_always_policy = effort_always_policy_100_steps * 100 / effort_mean - 100
percent_increase_avg_policy = effort_avg_policy_100_steps * 100 / effort_mean - 100
percent_increase_nosupport_policy = effort_nosupport_policy_100_steps * 100 / effort_mean - 100
print("Percentage increase opt. policy:", round(percent_increase_opt_policy, 2), "%")
print("Percentage increase always policy:", round(percent_increase_always_policy, 2), "%")
print("Percentage increase avg policy:", round(percent_increase_avg_policy, 2), "%")
print("Percentage increase no support policy:", round(percent_increase_nosupport_policy, 2), "%")
Percentage increase opt. policy: 23.36 %
Percentage increase always policy: 21.52 %
Percentage increase avg policy: 16.41 %
Percentage increase no support policy: 10.18 %
And let's plot this.
plt.figure(figsize=(10,10))
fig_lower_bound = 0 # y-axis lower limit for figure
fig_higher_bound = 1 # y-axis upper limit for figure
x_vals = np.arange(len(num_steps_list))
# Plot average reward for the four policies
plt.plot(x_vals, rew_opt_policy_list, color = 'deepskyblue', label = "Optimal policy")
plt.plot(x_vals, rew_always_policy_list, color = 'darkblue', label = "Always human feedback",
linestyle = "dashed")
plt.plot(x_vals, rew_avg_policy_list, color = 'slategray', label = "50% human feedback",
linestyle = "dashdot")
plt.plot(x_vals, rew_nosupport_policy_list, color = 'black', label = "Never human feedback",
linestyle = "dotted")
plt.legend()
plt.ylabel("Mean reward per previous activity assignment")
plt.xlabel("Number of time steps")
plt.ylim([fig_lower_bound, fig_higher_bound])
plt.xticks(x_vals, num_steps_list)
plt.savefig("Figures/policy_comp.pdf", dpi=1500,
bbox_inches='tight', pad_inches=0)
Let's see what the transition functions for giving and not giving feedback look like. We color edges between states based on whether the transition leads to a state with a higher (or the highest) reward for not giving feedback (blue), a lower (or the lowest) reward for not giving feedback (red), or the same reward for not giving feedback (black).
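This coloring rule is applied inline in both plotting loops below; as a compact sketch of the rule on its own (a hypothetical helper, not used by the code below; rew is the list of per-state rewards for not giving feedback):
def edge_color(rew, s0, s1):
    # Transition to a higher-reward state, or into the highest-reward state
    if rew[s1] > rew[s0] or rew[s1] == max(rew):
        return 'blue'
    # Transition to a lower-reward state, or into the lowest-reward state
    elif rew[s1] < rew[s0] or rew[s1] == min(rew):
        return 'crimson'
    # Otherwise the reward stays the same (and is neither the highest nor the lowest)
    else:
        return 'black'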
Let's define some variables we need for both plots.
fontsize = "18" # font size for labels in plot
states_names = ["000", "001", "010", "011", "100", "101", "110", "111", "200", "201", "210", "211"]
scale_factor = 0.2
min_weight = 1/num_states # min. weight for edges to be plotted
We start with the transition function for always giving human feedback.
# Rewards for not giving human feedback (used for the node labels and edge colors)
rew_no_support = [reward_func[s][0] for s in range(len(reward_func))]
# format specifies in what file type the graph will be saved. Can also use 'pdf'.
GA = graphviz.Digraph(filename = "Figures/Network_plot_transition_function_human_support",
engine="neato", format='png')
GA.node('000',pos='-1,-3.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[0], 2)), fontsize=fontsize)
GA.node('001',pos='-2.73,-2.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[1], 2)), fontsize=fontsize)
GA.node('010',pos='-3.73,-1!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[2], 2)), fontsize=fontsize)
GA.node('011',pos='-3.73,1!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[3], 2)), fontsize=fontsize)
GA.node('100',pos='-2.73,2.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[4], 2)), fontsize=fontsize)
GA.node('101',pos='-1,3.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[5], 2)), fontsize=fontsize)
GA.node('110',pos='1,3.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[6], 2)), fontsize=fontsize)
GA.node('111',pos='2.73,2.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[7], 2)), fontsize=fontsize)
GA.node('200',pos='3.73,1!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[8], 2)), fontsize=fontsize)
GA.node('201',pos='3.73,-1!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[9], 2)), fontsize=fontsize)
GA.node('210',pos='2.73,-2.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[10], 2)), fontsize=fontsize)
GA.node('211',pos='1,-3.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[11], 2)), fontsize=fontsize)
for s0 in range(num_states):
    for s1 in range(num_states):
        edge_width = trans_func[s0][1][s1]
        if edge_width >= min_weight:
            print("State " + states_names[s0] + " -> State " + states_names[s1] + ":", round(edge_width, 2))

            # Get to a state with higher value or stay in state with highest value
            if rew_no_support[s1] > rew_no_support[s0] or rew_no_support[s1] == max(rew_no_support):
                color = 'blue'
            # Move to a state with lower value or stay in state with lowest value
            elif rew_no_support[s1] < rew_no_support[s0] or rew_no_support[s1] == min(rew_no_support):
                color = 'crimson'
            # Stay in non-highest-value state
            else:
                color = 'black'

            GA.edge(states_names[s0],
                    states_names[s1],
                    penwidth = str(edge_width/scale_factor),
                    arrowhead = 'normal',
                    arrowsize = str(1),
                    color=color)
# save plot
GA.render()
GA
State 000 -> State 000: 0.71
State 000 -> State 001: 0.14
State 001 -> State 000: 0.19
State 001 -> State 001: 0.49
State 001 -> State 101: 0.09
State 010 -> State 000: 0.15
State 010 -> State 010: 0.45
State 010 -> State 011: 0.1
State 010 -> State 210: 0.15
State 011 -> State 001: 0.17
State 011 -> State 011: 0.21
State 011 -> State 111: 0.25
State 011 -> State 200: 0.08
State 011 -> State 211: 0.17
State 100 -> State 000: 0.17
State 100 -> State 001: 0.14
State 100 -> State 100: 0.24
State 100 -> State 101: 0.17
State 100 -> State 200: 0.1
State 101 -> State 001: 0.21
State 101 -> State 101: 0.25
State 101 -> State 111: 0.08
State 101 -> State 201: 0.25
State 101 -> State 211: 0.08
State 110 -> State 000: 0.1
State 110 -> State 010: 0.1
State 110 -> State 100: 0.1
State 110 -> State 101: 0.1
State 110 -> State 111: 0.1
State 110 -> State 210: 0.2
State 110 -> State 211: 0.3
State 111 -> State 111: 0.29
State 111 -> State 211: 0.59
State 200 -> State 100: 0.13
State 200 -> State 101: 0.13
State 200 -> State 200: 0.33
State 200 -> State 210: 0.13
State 201 -> State 101: 0.18
State 201 -> State 201: 0.36
State 201 -> State 211: 0.21
State 210 -> State 200: 0.11
State 210 -> State 210: 0.89
State 211 -> State 211: 0.85
And now the transition function for never giving human feedback.
# Rewards for not giving human feedback (used for the node labels and edge colors)
rew_no_support = [reward_func[s][0] for s in range(len(reward_func))]
# format specifies in what file type the graph will be saved. Can also use 'pdf'.
GA = graphviz.Digraph(filename = "Figures/Network_plot_transition_function_nosupport",
engine="neato", format='png')
GA.node('000',pos='-1,-3.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[0], 2)), fontsize=fontsize)
GA.node('001',pos='-2.73,-2.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[1], 2)), fontsize=fontsize)
GA.node('010',pos='-3.73,-1!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[2], 2)), fontsize=fontsize)
GA.node('011',pos='-3.73,1!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[3], 2)), fontsize=fontsize)
GA.node('100',pos='-2.73,2.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[4], 2)), fontsize=fontsize)
GA.node('101',pos='-1,3.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[5], 2)), fontsize=fontsize)
GA.node('110',pos='1,3.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[6], 2)), fontsize=fontsize)
GA.node('111',pos='2.73,2.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[7], 2)), fontsize=fontsize)
GA.node('200',pos='3.73,1!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[8], 2)), fontsize=fontsize)
GA.node('201',pos='3.73,-1!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[9], 2)), fontsize=fontsize)
GA.node('210',pos='2.73,-2.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[10], 2)), fontsize=fontsize)
GA.node('211',pos='1,-3.73!', fillcolor = 'lightgray', style = 'filled', xlabel="R(s,0)=" + str(round(rew_no_support[11], 2)), fontsize=fontsize)
for s0 in range(num_states):
    for s1 in range(num_states):
        edge_width = trans_func[s0][0][s1]
        if edge_width >= min_weight:
            print("State " + states_names[s0] + " -> State " + states_names[s1] + ":", round(edge_width, 2))

            # Get to a state with higher value or stay in state with highest value
            if rew_no_support[s1] > rew_no_support[s0] or rew_no_support[s1] == max(rew_no_support):
                color = 'blue'
            # Move to a state with lower value or stay in state with lowest value
            elif rew_no_support[s1] < rew_no_support[s0] or rew_no_support[s1] == min(rew_no_support):
                color = 'crimson'
            # Stay in non-highest-value state
            else:
                color = 'black'

            GA.edge(states_names[s0],
                    states_names[s1],
                    penwidth = str(edge_width/scale_factor),
                    arrowhead = 'normal',
                    arrowsize = str(1),
                    color=color)
# save plot
GA.render()
GA
State 000 -> State 000: 0.77
State 001 -> State 000: 0.17
State 001 -> State 001: 0.46
State 001 -> State 101: 0.1
State 010 -> State 000: 0.38
State 010 -> State 010: 0.27
State 011 -> State 001: 0.22
State 011 -> State 011: 0.19
State 011 -> State 111: 0.32
State 011 -> State 211: 0.09
State 100 -> State 000: 0.22
State 100 -> State 100: 0.38
State 100 -> State 200: 0.11
State 101 -> State 001: 0.21
State 101 -> State 101: 0.32
State 101 -> State 201: 0.16
State 101 -> State 211: 0.1
State 110 -> State 000: 0.14
State 110 -> State 100: 0.11
State 110 -> State 110: 0.11
State 110 -> State 210: 0.27
State 110 -> State 211: 0.14
State 111 -> State 001: 0.11
State 111 -> State 111: 0.37
State 111 -> State 211: 0.35
State 200 -> State 100: 0.1
State 200 -> State 200: 0.5
State 200 -> State 201: 0.1
State 201 -> State 101: 0.11
State 201 -> State 201: 0.49
State 201 -> State 211: 0.24
State 210 -> State 110: 0.1
State 210 -> State 210: 0.56
State 210 -> State 211: 0.14
State 211 -> State 111: 0.09
State 211 -> State 211: 0.75