Policy Gradient
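This notebook trains a stochastic neural-network policy with a policy gradient method, first on a small grid world and then on a highway lane-change MDP. The `policy_gradient` routine used below presumably follows the basic REINFORCE scheme: sample episodes with the current policy $\pi_\theta$ and ascend an estimate of

$$
\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\Big[\sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau)\Big],
$$

where $\tau$ is a sampled trajectory and $R(\tau)$ its return; the exact estimator (e.g. reward-to-go or a baseline) depends on the implementation.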
In [1]:
import os

from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
    make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
    GridMDP,
    policy_gradient,
    derive_deterministic_policy,
    GRID_MDP_DICT,
    HIGHWAY_MDP_DICT,
    LC_RIGHT_ACTION,
    STAY_IN_LANE_ACTION,
)

HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
TOY EXAMPLE
In [2]:
grid_mdp = GridMDP(**GRID_MDP_DICT)
In [3]:
policy = CategoricalPolicy(
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
    actions=list(grid_mdp.actions),
)
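The `sizes` argument suggests a small multi-layer perceptron that maps the state (here a 2-dimensional grid position) to a categorical distribution over the MDP's actions. Below is a minimal sketch of such a policy, assuming a PyTorch backend; the class name and methods are hypothetical stand-ins, not the actual `CategoricalPolicy` implementation:

import torch
import torch.nn as nn
from torch.distributions import Categorical

class TinyCategoricalPolicy(nn.Module):
    """Hypothetical stand-in for CategoricalPolicy: MLP -> logits -> Categorical."""

    def __init__(self, sizes, actions):
        super().__init__()
        layers = []
        for in_dim, out_dim in zip(sizes[:-1], sizes[1:]):
            layers += [nn.Linear(in_dim, out_dim), nn.Tanh()]
        self.net = nn.Sequential(*layers[:-1])  # no activation on the output logits
        self.actions = actions

    def distribution(self, state):
        logits = self.net(torch.as_tensor(state, dtype=torch.float32))
        return Categorical(logits=logits)

    def sample_action(self, state):
        return self.actions[self.distribution(state).sample().item()]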
In [4]:
model_checkpoints = policy_gradient(
    mdp=grid_mdp,
    policy=policy,
    iterations=100,
    return_history=True,
)
iteration: 1; return: -1.958; episode_length: 34.020
iteration: 2; return: -1.798; episode_length: 33.046
iteration: 3; return: -1.260; episode_length: 22.665
iteration: 4; return: -1.027; episode_length: 18.454
iteration: 5; return: -0.972; episode_length: 14.432
iteration: 6; return: -0.815; episode_length: 13.452
iteration: 7; return: -0.553; episode_length: 12.414
iteration: 8; return: -0.525; episode_length: 11.220
iteration: 9; return: -0.380; episode_length: 9.903
iteration: 10; return: -0.213; episode_length: 11.190
iteration: 11; return: -0.130; episode_length: 11.589
iteration: 12; return: -0.090; episode_length: 12.678
iteration: 13; return: -0.096; episode_length: 13.276
iteration: 14; return: -0.060; episode_length: 15.457
iteration: 15; return: -0.029; episode_length: 16.516
iteration: 16; return: -0.052; episode_length: 16.756
iteration: 17; return: -0.008; episode_length: 18.207
iteration: 18; return: -0.007; episode_length: 18.185
iteration: 19; return: -0.110; episode_length: 18.445
iteration: 20; return: 0.029; episode_length: 16.645
iteration: 21; return: 0.052; episode_length: 18.589
iteration: 22; return: 0.054; episode_length: 15.216
iteration: 23; return: 0.074; episode_length: 15.021
iteration: 24; return: 0.087; episode_length: 14.991
iteration: 25; return: 0.134; episode_length: 12.733
iteration: 26; return: 0.124; episode_length: 13.161
iteration: 27; return: 0.178; episode_length: 12.828
iteration: 28; return: 0.216; episode_length: 11.983
iteration: 29; return: 0.192; episode_length: 11.712
iteration: 30; return: 0.208; episode_length: 11.501
iteration: 31; return: 0.166; episode_length: 11.497
iteration: 32; return: 0.139; episode_length: 11.690
iteration: 33; return: 0.224; episode_length: 11.662
iteration: 34; return: 0.173; episode_length: 12.116
iteration: 35; return: 0.225; episode_length: 11.921
iteration: 36; return: 0.275; episode_length: 12.519
iteration: 37; return: 0.252; episode_length: 12.636
iteration: 38; return: 0.168; episode_length: 12.607
iteration: 39; return: 0.279; episode_length: 12.198
iteration: 40; return: 0.266; episode_length: 12.568
iteration: 41; return: 0.324; episode_length: 12.612
iteration: 42; return: 0.258; episode_length: 12.866
iteration: 43; return: 0.286; episode_length: 12.184
iteration: 44; return: 0.288; episode_length: 12.290
iteration: 45; return: 0.269; episode_length: 12.368
iteration: 46; return: 0.280; episode_length: 11.800
iteration: 47; return: 0.295; episode_length: 12.175
iteration: 48; return: 0.298; episode_length: 11.664
iteration: 49; return: 0.353; episode_length: 11.525
iteration: 50; return: 0.287; episode_length: 11.477
iteration: 51; return: 0.342; episode_length: 11.190
iteration: 52; return: 0.326; episode_length: 11.604
iteration: 53; return: 0.336; episode_length: 11.990
iteration: 54; return: 0.384; episode_length: 11.394
iteration: 55; return: 0.356; episode_length: 11.270
iteration: 56; return: 0.377; episode_length: 11.529
iteration: 57; return: 0.364; episode_length: 11.049
iteration: 58; return: 0.359; episode_length: 11.662
iteration: 59; return: 0.363; episode_length: 11.002
iteration: 60; return: 0.294; episode_length: 11.669
iteration: 61; return: 0.300; episode_length: 11.784
iteration: 62; return: 0.401; episode_length: 11.100
iteration: 63; return: 0.376; episode_length: 11.535
iteration: 64; return: 0.357; episode_length: 10.950
iteration: 65; return: 0.426; episode_length: 11.375
iteration: 66; return: 0.460; episode_length: 10.736
iteration: 67; return: 0.413; episode_length: 10.978
iteration: 68; return: 0.421; episode_length: 11.381
iteration: 69; return: 0.425; episode_length: 11.138
iteration: 70; return: 0.472; episode_length: 11.004
iteration: 71; return: 0.488; episode_length: 10.872
iteration: 72; return: 0.487; episode_length: 10.803
iteration: 73; return: 0.532; episode_length: 10.241
iteration: 74; return: 0.496; episode_length: 10.872
iteration: 75; return: 0.490; episode_length: 10.493
iteration: 76; return: 0.488; episode_length: 10.533
iteration: 77; return: 0.502; episode_length: 10.427
iteration: 78; return: 0.523; episode_length: 10.339
iteration: 79; return: 0.566; episode_length: 10.590
iteration: 80; return: 0.537; episode_length: 10.651
iteration: 81; return: 0.581; episode_length: 10.247
iteration: 82; return: 0.558; episode_length: 10.581
iteration: 83; return: 0.573; episode_length: 10.074
iteration: 84; return: 0.599; episode_length: 10.305
iteration: 85; return: 0.596; episode_length: 10.000
iteration: 86; return: 0.597; episode_length: 10.074
iteration: 87; return: 0.599; episode_length: 10.315
iteration: 88; return: 0.589; episode_length: 10.068
iteration: 89; return: 0.614; episode_length: 10.052
iteration: 90; return: 0.607; episode_length: 9.929
iteration: 91; return: 0.584; episode_length: 10.455
iteration: 92; return: 0.612; episode_length: 10.283
iteration: 93; return: 0.618; episode_length: 9.758
iteration: 94; return: 0.633; episode_length: 9.779
iteration: 95; return: 0.630; episode_length: 9.752
iteration: 96; return: 0.613; episode_length: 9.892
iteration: 97; return: 0.603; episode_length: 9.596
iteration: 98; return: 0.620; episode_length: 9.804
iteration: 99; return: 0.632; episode_length: 9.351
iteration: 100; return: 0.637; episode_length: 9.418
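For reference, one policy gradient update typically looks like the following sketch: roll out a batch of episodes with the current policy, weight each action's log-probability by the episode return, and take one gradient step on the negated objective. The sketch assumes the hypothetical TinyCategoricalPolicy above and episodes given as (states, action indices, return) tuples; the actual `policy_gradient` implementation may differ, e.g. by using reward-to-go or a baseline.

import torch

def reinforce_update(policy, episodes, optimizer):
    """One REINFORCE-style step; `episodes` is a list of
    (states, action_indices, episode_return) tuples -- a hypothetical format."""
    loss = torch.tensor(0.0)
    for states, action_indices, episode_return in episodes:
        for state, action_index in zip(states, action_indices):
            log_prob = policy.distribution(state).log_prob(torch.tensor(action_index))
            # gradient ascent on E[log pi(a|s) * R]  ==  descent on the negative
            loss = loss - log_prob * episode_return
    loss = loss / len(episodes)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()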
In [5]:
policy_array = [
    derive_deterministic_policy(mdp=grid_mdp, policy=model)
    for model in model_checkpoints
]
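`derive_deterministic_policy` presumably extracts, for every state, the action with the highest probability under the trained stochastic policy, so that the result can be drawn as one arrow per grid cell. A minimal sketch under that assumption; `get_states` and `action_probabilities` are hypothetical accessors, not the library's actual API:

def greedy_policy_from(mdp, policy):
    """Hypothetical sketch: pick the most probable action in every state."""
    return {
        state: max(
            mdp.actions,
            key=lambda action: policy.action_probabilities(state)[action],
        )
        for state in mdp.get_states()  # assumed accessor
    }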
In [6]:
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=4, rows=3, policy_over_time=policy_array
)
In [7]:
interactive_widgets = os.getenv("CI") != "true"  # non-interactive in CI
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    plot_policy_step_grid_map(100)
HIGHWAY EXAMPLE
In [8]:
if False:
    # We will change this to True later on, to see the effect.
    HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
        (0.4, LC_RIGHT_ACTION),
        (0.6, STAY_IN_LANE_ACTION),
    ]
In [9]:
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
In [10]:
policy = CategoricalPolicy(
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
    actions=list(highway_mdp.actions),
)
In [11]:
model_checkpoints = policy_gradient(
    mdp=highway_mdp,
    policy=policy,
    iterations=200,
    return_history=True,
)
iteration: 1; return: -62.609; episode_length: 9.889
iteration: 2; return: -57.321; episode_length: 9.872
iteration: 3; return: -53.203; episode_length: 9.958
iteration: 4; return: -49.116; episode_length: 10.144
iteration: 5; return: -44.359; episode_length: 10.551
iteration: 6; return: -35.690; episode_length: 11.501
iteration: 7; return: -35.388; episode_length: 12.533
iteration: 8; return: -33.855; episode_length: 13.166
iteration: 9; return: -34.396; episode_length: 13.488
iteration: 10; return: -36.870; episode_length: 14.752
iteration: 11; return: -35.730; episode_length: 14.513
iteration: 12; return: -33.986; episode_length: 14.110
iteration: 13; return: -33.828; episode_length: 13.856
iteration: 14; return: -33.851; episode_length: 13.848
iteration: 15; return: -31.328; episode_length: 13.246
iteration: 16; return: -30.681; episode_length: 12.010
iteration: 17; return: -30.977; episode_length: 11.603
iteration: 18; return: -29.952; episode_length: 11.007
iteration: 19; return: -30.412; episode_length: 11.066
iteration: 20; return: -30.926; episode_length: 10.947
iteration: 21; return: -30.712; episode_length: 11.086
iteration: 22; return: -28.176; episode_length: 11.429
iteration: 23; return: -28.809; episode_length: 11.795
iteration: 24; return: -29.580; episode_length: 11.739
iteration: 25; return: -28.328; episode_length: 12.383
iteration: 26; return: -28.071; episode_length: 12.170
iteration: 27; return: -30.262; episode_length: 12.348
iteration: 28; return: -29.177; episode_length: 12.150
iteration: 29; return: -27.311; episode_length: 11.689
iteration: 30; return: -28.071; episode_length: 11.116
iteration: 31; return: -26.955; episode_length: 10.838
iteration: 32; return: -27.556; episode_length: 10.277
iteration: 33; return: -29.328; episode_length: 10.126
iteration: 34; return: -29.534; episode_length: 10.016
iteration: 35; return: -28.053; episode_length: 9.872
iteration: 36; return: -26.133; episode_length: 10.285
iteration: 37; return: -26.432; episode_length: 10.339
iteration: 38; return: -25.581; episode_length: 10.866
iteration: 39; return: -26.093; episode_length: 11.363
iteration: 40; return: -25.835; episode_length: 11.781
iteration: 41; return: -26.587; episode_length: 11.895
iteration: 42; return: -26.950; episode_length: 11.993
iteration: 43; return: -26.136; episode_length: 11.767
iteration: 44; return: -26.855; episode_length: 11.891
iteration: 45; return: -24.887; episode_length: 11.266
iteration: 46; return: -25.130; episode_length: 10.703
iteration: 47; return: -25.745; episode_length: 10.539
iteration: 48; return: -25.257; episode_length: 9.996
iteration: 49; return: -25.020; episode_length: 10.012
iteration: 50; return: -25.536; episode_length: 9.835
iteration: 51; return: -24.784; episode_length: 9.762
iteration: 52; return: -24.091; episode_length: 9.724
iteration: 53; return: -24.696; episode_length: 10.083
iteration: 54; return: -23.777; episode_length: 10.064
iteration: 55; return: -23.392; episode_length: 10.320
iteration: 56; return: -23.623; episode_length: 10.305
iteration: 57; return: -23.344; episode_length: 10.559
iteration: 58; return: -23.406; episode_length: 10.362
iteration: 59; return: -23.121; episode_length: 10.107
iteration: 60; return: -23.777; episode_length: 10.132
iteration: 61; return: -22.056; episode_length: 10.066
iteration: 62; return: -23.356; episode_length: 9.913
iteration: 63; return: -23.739; episode_length: 9.814
iteration: 64; return: -22.574; episode_length: 9.831
iteration: 65; return: -22.470; episode_length: 10.014
iteration: 66; return: -22.357; episode_length: 10.093
iteration: 67; return: -22.324; episode_length: 9.962
iteration: 68; return: -21.758; episode_length: 9.673
iteration: 69; return: -22.106; episode_length: 9.856
iteration: 70; return: -20.828; episode_length: 9.768
iteration: 71; return: -21.949; episode_length: 9.850
iteration: 72; return: -21.886; episode_length: 9.860
iteration: 73; return: -22.053; episode_length: 10.167
iteration: 74; return: -21.695; episode_length: 9.848
iteration: 75; return: -21.552; episode_length: 9.633
iteration: 76; return: -21.686; episode_length: 9.703
iteration: 77; return: -21.596; episode_length: 9.514
iteration: 78; return: -21.618; episode_length: 9.433
iteration: 79; return: -21.917; episode_length: 9.443
iteration: 80; return: -21.125; episode_length: 9.449
iteration: 81; return: -20.944; episode_length: 9.424
iteration: 82; return: -21.463; episode_length: 9.607
iteration: 83; return: -21.507; episode_length: 9.585
iteration: 84; return: -21.669; episode_length: 9.564
iteration: 85; return: -21.209; episode_length: 9.614
iteration: 86; return: -20.799; episode_length: 9.600
iteration: 87; return: -20.615; episode_length: 9.677
iteration: 88; return: -21.458; episode_length: 9.732
iteration: 89; return: -21.202; episode_length: 9.739
iteration: 90; return: -20.893; episode_length: 9.717
iteration: 91; return: -20.502; episode_length: 9.692
iteration: 92; return: -20.648; episode_length: 9.568
iteration: 93; return: -20.533; episode_length: 9.594
iteration: 94; return: -20.170; episode_length: 9.481
iteration: 95; return: -20.697; episode_length: 9.533
iteration: 96; return: -20.057; episode_length: 9.477
iteration: 97; return: -19.870; episode_length: 9.422
iteration: 98; return: -20.781; episode_length: 9.309
iteration: 99; return: -20.681; episode_length: 9.274
iteration: 100; return: -20.702; episode_length: 9.214
iteration: 101; return: -20.006; episode_length: 9.232
iteration: 102; return: -19.779; episode_length: 9.240
iteration: 103; return: -19.212; episode_length: 9.215
iteration: 104; return: -19.848; episode_length: 9.259
iteration: 105; return: -19.203; episode_length: 9.315
iteration: 106; return: -20.315; episode_length: 9.272
iteration: 107; return: -19.933; episode_length: 9.345
iteration: 108; return: -19.756; episode_length: 9.320
iteration: 109; return: -19.787; episode_length: 9.261
iteration: 110; return: -19.755; episode_length: 9.371
iteration: 111; return: -19.816; episode_length: 9.398
iteration: 112; return: -20.057; episode_length: 9.456
iteration: 113; return: -18.943; episode_length: 9.494
iteration: 114; return: -19.365; episode_length: 9.456
iteration: 115; return: -20.128; episode_length: 9.303
iteration: 116; return: -19.772; episode_length: 9.348
iteration: 117; return: -20.024; episode_length: 9.363
iteration: 118; return: -19.367; episode_length: 9.371
iteration: 119; return: -19.236; episode_length: 9.280
iteration: 120; return: -18.797; episode_length: 9.313
iteration: 121; return: -19.350; episode_length: 9.225
iteration: 122; return: -19.062; episode_length: 9.334
iteration: 123; return: -19.399; episode_length: 9.240
iteration: 124; return: -19.187; episode_length: 9.282
iteration: 125; return: -19.221; episode_length: 9.225
iteration: 126; return: -19.793; episode_length: 9.167
iteration: 127; return: -18.872; episode_length: 9.155
iteration: 128; return: -19.665; episode_length: 9.212
iteration: 129; return: -19.005; episode_length: 9.167
iteration: 130; return: -19.941; episode_length: 9.238
iteration: 131; return: -18.555; episode_length: 9.159
iteration: 132; return: -18.759; episode_length: 9.152
iteration: 133; return: -19.142; episode_length: 9.236
iteration: 134; return: -19.359; episode_length: 9.361
iteration: 135; return: -18.791; episode_length: 9.440
iteration: 136; return: -19.037; episode_length: 9.267
iteration: 137; return: -19.808; episode_length: 9.330
iteration: 138; return: -18.940; episode_length: 9.341
iteration: 139; return: -18.950; episode_length: 9.270
iteration: 140; return: -19.074; episode_length: 9.270
iteration: 141; return: -19.277; episode_length: 9.238
iteration: 142; return: -19.279; episode_length: 9.189
iteration: 143; return: -18.489; episode_length: 9.159
iteration: 144; return: -18.533; episode_length: 9.139
iteration: 145; return: -19.000; episode_length: 9.105
iteration: 146; return: -19.283; episode_length: 9.074
iteration: 147; return: -18.832; episode_length: 9.074
iteration: 148; return: -19.951; episode_length: 9.078
iteration: 149; return: -19.459; episode_length: 9.043
iteration: 150; return: -19.732; episode_length: 9.058
iteration: 151; return: -19.576; episode_length: 9.063
iteration: 152; return: -19.414; episode_length: 9.056
iteration: 153; return: -19.826; episode_length: 9.056
iteration: 154; return: -19.755; episode_length: 9.105
iteration: 155; return: -18.399; episode_length: 9.109
iteration: 156; return: -19.126; episode_length: 9.126
iteration: 157; return: -19.439; episode_length: 9.176
iteration: 158; return: -18.529; episode_length: 9.163
iteration: 159; return: -18.616; episode_length: 9.152
iteration: 160; return: -19.322; episode_length: 9.161
iteration: 161; return: -19.033; episode_length: 9.240
iteration: 162; return: -18.707; episode_length: 9.232
iteration: 163; return: -19.665; episode_length: 9.426
iteration: 164; return: -19.214; episode_length: 9.309
iteration: 165; return: -19.310; episode_length: 9.338
iteration: 166; return: -19.183; episode_length: 9.332
iteration: 167; return: -19.215; episode_length: 9.336
iteration: 168; return: -19.476; episode_length: 9.412
iteration: 169; return: -18.867; episode_length: 9.270
iteration: 170; return: -19.102; episode_length: 9.251
iteration: 171; return: -18.891; episode_length: 9.231
iteration: 172; return: -19.598; episode_length: 9.265
iteration: 173; return: -18.582; episode_length: 9.223
iteration: 174; return: -18.648; episode_length: 9.189
iteration: 175; return: -19.404; episode_length: 9.191
iteration: 176; return: -18.891; episode_length: 9.229
iteration: 177; return: -18.908; episode_length: 9.174
iteration: 178; return: -18.862; episode_length: 9.219
iteration: 179; return: -19.046; episode_length: 9.229
iteration: 180; return: -18.572; episode_length: 9.178
iteration: 181; return: -19.257; episode_length: 9.178
iteration: 182; return: -18.583; episode_length: 9.148
iteration: 183; return: -18.599; episode_length: 9.168
iteration: 184; return: -19.122; episode_length: 9.150
iteration: 185; return: -18.872; episode_length: 9.159
iteration: 186; return: -19.404; episode_length: 9.152
iteration: 187; return: -18.740; episode_length: 9.150
iteration: 188; return: -19.031; episode_length: 9.128
iteration: 189; return: -18.308; episode_length: 9.071
iteration: 190; return: -18.767; episode_length: 9.165
iteration: 191; return: -18.483; episode_length: 9.056
iteration: 192; return: -19.045; episode_length: 9.098
iteration: 193; return: -18.494; episode_length: 9.045
iteration: 194; return: -19.018; episode_length: 9.043
iteration: 195; return: -19.034; episode_length: 9.034
iteration: 196; return: -19.018; episode_length: 9.071
iteration: 197; return: -19.080; episode_length: 9.128
iteration: 198; return: -19.323; episode_length: 9.018
iteration: 199; return: -18.634; episode_length: 9.029
iteration: 200; return: -18.908; episode_length: 9.045
In [12]:
policy_array = [
    derive_deterministic_policy(mdp=highway_mdp, policy=model)
    for model in model_checkpoints
]
In [13]:
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=10, rows=4, policy_over_time=policy_array
)
In [14]:
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    plot_policy_step_grid_map(200)