Policy Gradient
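This notebook trains a stochastic neural-network policy directly with the policy gradient (REINFORCE) approach, first on a small grid-world toy example and then on a highway lane-change MDP. The idea is to push the policy parameters $\theta$ in the direction that makes high-return episodes more likely. A standard form of the gradient estimate (stated here for orientation; the exact variant inside policy_gradient may differ, e.g. by using reward-to-go or a baseline) is

$$\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} \sum_{t=0}^{T_i} \nabla_\theta \log \pi_\theta\!\left(a_t^i \mid s_t^i\right) R\!\left(\tau^i\right),$$

where $\tau^i$ are episodes sampled with the current policy and $R(\tau^i)$ are their returns.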
In [1]:
import os

from behavior_generation_lecture_python.mdp.policy import CategorialPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
    make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
    GridMDP,
    policy_gradient,
    derive_deterministic_policy,
    GRID_MDP_DICT,
    HIGHWAY_MDP_DICT,
    LC_RIGHT_ACTION,
    STAY_IN_LANE_ACTION,
)

HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
TOY EXAMPLE
In [2]:
grid_mdp = GridMDP(**GRID_MDP_DICT)
In [3]:
policy = CategorialPolicy(
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
    actions=list(grid_mdp.actions),
)
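The sizes list describes a small multi-layer perceptron: the input width equals the state dimension, there is one hidden layer of 32 units, and there is one output logit per action. As a rough illustration only (a hypothetical sketch assuming a PyTorch-style MLP, not the actual implementation of CategorialPolicy), such a categorical policy could look like this:

import torch
from torch import nn
from torch.distributions import Categorical

class TinyCategoricalPolicy(nn.Module):  # hypothetical, for illustration only
    def __init__(self, state_dim, hidden_dim, num_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, num_actions),
        )

    def action_distribution(self, state):
        # state: 1-D float tensor; returns a categorical distribution over actions
        return Categorical(logits=self.net(state))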
In [4]:
model_checkpoints = policy_gradient(
    mdp=grid_mdp,
    policy=policy,
    iterations=100,
    return_history=True,
)
iteration: 1; return: -1.860; episode_length: 35.347
iteration: 2; return: -1.248; episode_length: 21.838
iteration: 3; return: -1.071; episode_length: 16.912
iteration: 4; return: -0.887; episode_length: 14.144
iteration: 5; return: -0.855; episode_length: 14.181
iteration: 6; return: -0.756; episode_length: 12.765
iteration: 7; return: -0.679; episode_length: 11.386
iteration: 8; return: -0.499; episode_length: 11.192
iteration: 9; return: -0.417; episode_length: 12.026
iteration: 10; return: -0.374; episode_length: 11.380
iteration: 11; return: -0.347; episode_length: 12.089
iteration: 12; return: -0.203; episode_length: 11.405
iteration: 13; return: -0.071; episode_length: 12.922
iteration: 14; return: -0.061; episode_length: 13.389
iteration: 15; return: -0.101; episode_length: 13.992
iteration: 16; return: -0.026; episode_length: 15.716
iteration: 17; return: -0.040; episode_length: 16.379
iteration: 18; return: -0.041; episode_length: 17.237
iteration: 19; return: -0.002; episode_length: 18.261
iteration: 20; return: -0.008; episode_length: 16.846
iteration: 21; return: 0.049; episode_length: 16.635
iteration: 22; return: 0.114; episode_length: 18.229
iteration: 23; return: 0.145; episode_length: 15.994
iteration: 24; return: 0.107; episode_length: 17.678
iteration: 25; return: 0.112; episode_length: 16.959
iteration: 26; return: 0.213; episode_length: 15.437
iteration: 27; return: 0.168; episode_length: 14.452
iteration: 28; return: 0.172; episode_length: 12.846
iteration: 29; return: 0.293; episode_length: 12.622
iteration: 30; return: 0.248; episode_length: 12.229
iteration: 31; return: 0.276; episode_length: 12.094
iteration: 32; return: 0.324; episode_length: 11.129
iteration: 33; return: 0.331; episode_length: 11.154
iteration: 34; return: 0.369; episode_length: 11.192
iteration: 35; return: 0.326; episode_length: 11.293
iteration: 36; return: 0.292; episode_length: 11.622
iteration: 37; return: 0.312; episode_length: 11.181
iteration: 38; return: 0.319; episode_length: 11.633
iteration: 39; return: 0.309; episode_length: 11.145
iteration: 40; return: 0.327; episode_length: 11.434
iteration: 41; return: 0.351; episode_length: 11.637
iteration: 42; return: 0.338; episode_length: 12.270
iteration: 43; return: 0.389; episode_length: 11.800
iteration: 44; return: 0.424; episode_length: 12.131
iteration: 45; return: 0.375; episode_length: 11.719
iteration: 46; return: 0.419; episode_length: 11.767
iteration: 47; return: 0.427; episode_length: 12.170
iteration: 48; return: 0.447; episode_length: 11.396
iteration: 49; return: 0.441; episode_length: 11.696
iteration: 50; return: 0.501; episode_length: 11.333
iteration: 51; return: 0.491; episode_length: 11.342
iteration: 52; return: 0.457; episode_length: 10.965
iteration: 53; return: 0.505; episode_length: 10.624
iteration: 54; return: 0.493; episode_length: 10.281
iteration: 55; return: 0.551; episode_length: 10.537
iteration: 56; return: 0.526; episode_length: 10.537
iteration: 57; return: 0.563; episode_length: 10.467
iteration: 58; return: 0.580; episode_length: 10.087
iteration: 59; return: 0.592; episode_length: 10.371
iteration: 60; return: 0.580; episode_length: 10.356
iteration: 61; return: 0.591; episode_length: 10.123
iteration: 62; return: 0.613; episode_length: 10.177
iteration: 63; return: 0.610; episode_length: 9.872
iteration: 64; return: 0.601; episode_length: 9.899
iteration: 65; return: 0.582; episode_length: 9.956
iteration: 66; return: 0.626; episode_length: 9.760
iteration: 67; return: 0.619; episode_length: 9.647
iteration: 68; return: 0.611; episode_length: 9.572
iteration: 69; return: 0.619; episode_length: 9.566
iteration: 70; return: 0.629; episode_length: 9.424
iteration: 71; return: 0.626; episode_length: 9.420
iteration: 72; return: 0.637; episode_length: 9.408
iteration: 73; return: 0.665; episode_length: 9.180
iteration: 74; return: 0.632; episode_length: 9.264
iteration: 75; return: 0.620; episode_length: 9.214
iteration: 76; return: 0.644; episode_length: 9.004
iteration: 77; return: 0.643; episode_length: 9.195
iteration: 78; return: 0.632; episode_length: 8.959
iteration: 79; return: 0.644; episode_length: 9.087
iteration: 80; return: 0.648; episode_length: 8.982
iteration: 81; return: 0.654; episode_length: 8.775
iteration: 82; return: 0.654; episode_length: 9.193
iteration: 83; return: 0.670; episode_length: 8.733
iteration: 84; return: 0.659; episode_length: 9.078
iteration: 85; return: 0.643; episode_length: 9.111
iteration: 86; return: 0.652; episode_length: 8.971
iteration: 87; return: 0.653; episode_length: 8.968
iteration: 88; return: 0.652; episode_length: 8.978
iteration: 89; return: 0.656; episode_length: 8.720
iteration: 90; return: 0.661; episode_length: 8.772
iteration: 91; return: 0.661; episode_length: 8.860
iteration: 92; return: 0.630; episode_length: 9.238
iteration: 93; return: 0.651; episode_length: 8.843
iteration: 94; return: 0.656; episode_length: 8.883
iteration: 95; return: 0.673; episode_length: 8.659
iteration: 96; return: 0.666; episode_length: 8.996
iteration: 97; return: 0.654; episode_length: 8.930
iteration: 98; return: 0.674; episode_length: 8.895
iteration: 99; return: 0.656; episode_length: 8.731
iteration: 100; return: 0.658; episode_length: 8.836
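Each printed iteration corresponds to one gradient update on a batch of episodes sampled with the current policy; the average return rises and the average episode length shrinks as the policy improves. The actual training loop lives inside policy_gradient; the following is only a hedged sketch of what one vanilla REINFORCE update typically looks like (function and argument names here are hypothetical, and the real implementation may differ, e.g. in batching or baselines):

import torch
from torch.distributions import Categorical

def reinforce_update(policy_net, optimizer, episodes):
    # episodes: list of (state_tensors, action_index_tensors, rewards) sampled
    # with the current policy; performs one REINFORCE step on the whole batch
    losses = []
    for states, actions, rewards in episodes:
        episode_return = sum(rewards)  # undiscounted return, kept simple here
        for state, action in zip(states, actions):
            log_prob = Categorical(logits=policy_net(state)).log_prob(action)
            losses.append(-log_prob * episode_return)  # maximize expected return
    loss = torch.stack(losses).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()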
In [5]:
policy_array = [
    derive_deterministic_policy(mdp=grid_mdp, policy=model)
    for model in model_checkpoints
]
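derive_deterministic_policy turns each stochastic model checkpoint into a deterministic state-to-action map so it can be drawn on the grid. Conceptually this is an argmax over the action distribution in every state; a minimal sketch of that idea (the helper below is hypothetical and assumes a network that returns one logit per action):

import torch

def greedy_action(policy_net, actions, state_tensor):
    # hypothetical helper: pick the most likely action under the categorical policy
    with torch.no_grad():
        logits = policy_net(state_tensor)
    return actions[int(torch.argmax(logits))]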
In [6]:
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=4, rows=3, policy_over_time=policy_array
)
In [7]:
interactive_widgets = not os.getenv("CI") == "true"  # non-interactive in CI

if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    plot_policy_step_grid_map(100)
HIGHWAY EXAMPLE
In [8]:
if False:
    # we will change this to True later on, to see the effect
    HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
        (0.4, LC_RIGHT_ACTION),
        (0.6, STAY_IN_LANE_ACTION),
    ]
In [9]:
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
In [10]:
policy = CategorialPolicy(
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
    actions=list(highway_mdp.actions),
)
In [11]:
model_checkpoints = policy_gradient(
    mdp=highway_mdp,
    policy=policy,
    iterations=200,
    return_history=True,
)
iteration: 1; return: -49.962; episode_length: 11.266
iteration: 2; return: -44.062; episode_length: 12.874
iteration: 3; return: -39.237; episode_length: 14.144
iteration: 4; return: -37.577; episode_length: 14.507
iteration: 5; return: -37.910; episode_length: 15.553
iteration: 6; return: -38.254; episode_length: 15.539
iteration: 7; return: -36.065; episode_length: 14.758
iteration: 8; return: -34.793; episode_length: 14.167
iteration: 9; return: -34.585; episode_length: 13.086
iteration: 10; return: -31.758; episode_length: 11.802
iteration: 11; return: -35.555; episode_length: 10.976
iteration: 12; return: -34.165; episode_length: 10.829
iteration: 13; return: -34.876; episode_length: 10.924
iteration: 14; return: -32.337; episode_length: 11.319
iteration: 15; return: -30.360; episode_length: 12.089
iteration: 16; return: -32.346; episode_length: 12.641
iteration: 17; return: -31.231; episode_length: 12.879
iteration: 18; return: -32.442; episode_length: 13.561
iteration: 19; return: -30.921; episode_length: 13.686
iteration: 20; return: -30.644; episode_length: 13.372
iteration: 21; return: -29.849; episode_length: 12.811
iteration: 22; return: -28.933; episode_length: 12.010
iteration: 23; return: -28.830; episode_length: 12.010
iteration: 24; return: -27.885; episode_length: 11.584
iteration: 25; return: -28.366; episode_length: 11.247
iteration: 26; return: -27.878; episode_length: 11.064
iteration: 27; return: -28.793; episode_length: 11.136
iteration: 28; return: -30.444; episode_length: 10.961
iteration: 29; return: -28.262; episode_length: 11.104
iteration: 30; return: -27.875; episode_length: 11.206
iteration: 31; return: -25.935; episode_length: 11.706
iteration: 32; return: -26.724; episode_length: 11.800
iteration: 33; return: -25.973; episode_length: 12.409
iteration: 34; return: -26.752; episode_length: 12.141
iteration: 35; return: -26.736; episode_length: 12.138
iteration: 36; return: -26.653; episode_length: 12.391
iteration: 37; return: -26.335; episode_length: 12.432
iteration: 38; return: -25.373; episode_length: 12.227
iteration: 39; return: -25.208; episode_length: 11.983
iteration: 40; return: -25.974; episode_length: 11.936
iteration: 41; return: -25.116; episode_length: 11.818
iteration: 42; return: -26.210; episode_length: 11.699
iteration: 43; return: -25.119; episode_length: 11.696
iteration: 44; return: -25.935; episode_length: 12.055
iteration: 45; return: -24.699; episode_length: 11.955
iteration: 46; return: -25.964; episode_length: 11.964
iteration: 47; return: -23.993; episode_length: 12.034
iteration: 48; return: -25.022; episode_length: 11.971
iteration: 49; return: -25.339; episode_length: 12.029
iteration: 50; return: -23.592; episode_length: 11.626
iteration: 51; return: -24.297; episode_length: 11.167
iteration: 52; return: -25.547; episode_length: 11.185
iteration: 53; return: -25.647; episode_length: 11.118
iteration: 54; return: -24.073; episode_length: 11.109
iteration: 55; return: -24.576; episode_length: 11.479
iteration: 56; return: -22.794; episode_length: 11.319
iteration: 57; return: -24.129; episode_length: 11.726
iteration: 58; return: -23.415; episode_length: 11.735
iteration: 59; return: -24.665; episode_length: 12.063
iteration: 60; return: -22.805; episode_length: 11.501
iteration: 61; return: -23.464; episode_length: 11.577
iteration: 62; return: -23.079; episode_length: 11.264
iteration: 63; return: -21.863; episode_length: 11.229
iteration: 64; return: -23.318; episode_length: 11.188
iteration: 65; return: -23.222; episode_length: 10.880
iteration: 66; return: -22.824; episode_length: 10.876
iteration: 67; return: -23.484; episode_length: 10.870
iteration: 68; return: -23.829; episode_length: 10.699
iteration: 69; return: -22.567; episode_length: 10.791
iteration: 70; return: -22.125; episode_length: 10.961
iteration: 71; return: -22.712; episode_length: 11.264
iteration: 72; return: -20.980; episode_length: 11.252
iteration: 73; return: -21.360; episode_length: 11.412
iteration: 74; return: -23.270; episode_length: 11.069
iteration: 75; return: -22.427; episode_length: 11.305
iteration: 76; return: -21.793; episode_length: 11.149
iteration: 77; return: -22.356; episode_length: 11.143
iteration: 78; return: -22.479; episode_length: 10.943
iteration: 79; return: -21.461; episode_length: 10.928
iteration: 80; return: -20.806; episode_length: 10.939
iteration: 81; return: -20.534; episode_length: 10.874
iteration: 82; return: -21.002; episode_length: 11.015
iteration: 83; return: -20.987; episode_length: 10.956
iteration: 84; return: -21.575; episode_length: 11.020
iteration: 85; return: -21.885; episode_length: 11.042
iteration: 86; return: -21.045; episode_length: 11.147
iteration: 87; return: -21.620; episode_length: 11.238
iteration: 88; return: -21.479; episode_length: 11.263
iteration: 89; return: -20.176; episode_length: 11.293
iteration: 90; return: -22.281; episode_length: 11.427
iteration: 91; return: -22.472; episode_length: 11.394
iteration: 92; return: -21.088; episode_length: 10.950
iteration: 93; return: -21.441; episode_length: 11.210
iteration: 94; return: -22.069; episode_length: 11.082
iteration: 95; return: -20.927; episode_length: 11.068
iteration: 96; return: -20.218; episode_length: 10.908
iteration: 97; return: -21.226; episode_length: 10.998
iteration: 98; return: -20.926; episode_length: 10.924
iteration: 99; return: -20.788; episode_length: 10.808
iteration: 100; return: -20.427; episode_length: 10.857
iteration: 101; return: -20.748; episode_length: 10.893
iteration: 102; return: -20.402; episode_length: 10.803
iteration: 103; return: -20.773; episode_length: 10.941
iteration: 104; return: -20.400; episode_length: 10.717
iteration: 105; return: -20.447; episode_length: 10.759
iteration: 106; return: -20.484; episode_length: 10.635
iteration: 107; return: -20.340; episode_length: 10.635
iteration: 108; return: -20.299; episode_length: 10.697
iteration: 109; return: -21.073; episode_length: 10.742
iteration: 110; return: -20.324; episode_length: 10.674
iteration: 111; return: -19.780; episode_length: 10.608
iteration: 112; return: -20.023; episode_length: 10.515
iteration: 113; return: -21.132; episode_length: 10.667
iteration: 114; return: -19.947; episode_length: 10.628
iteration: 115; return: -20.256; episode_length: 10.780
iteration: 116; return: -21.162; episode_length: 10.678
iteration: 117; return: -19.903; episode_length: 10.740
iteration: 118; return: -20.231; episode_length: 10.784
iteration: 119; return: -20.209; episode_length: 10.674
iteration: 120; return: -20.028; episode_length: 10.932
iteration: 121; return: -20.480; episode_length: 10.872
iteration: 122; return: -20.039; episode_length: 10.761
iteration: 123; return: -19.840; episode_length: 10.835
iteration: 124; return: -19.645; episode_length: 10.699
iteration: 125; return: -19.839; episode_length: 10.740
iteration: 126; return: -20.599; episode_length: 10.565
iteration: 127; return: -19.318; episode_length: 10.676
iteration: 128; return: -20.722; episode_length: 10.531
iteration: 129; return: -20.294; episode_length: 10.606
iteration: 130; return: -19.752; episode_length: 10.641
iteration: 131; return: -19.086; episode_length: 10.740
iteration: 132; return: -19.881; episode_length: 10.778
iteration: 133; return: -20.090; episode_length: 10.703
iteration: 134; return: -21.153; episode_length: 10.655
iteration: 135; return: -19.333; episode_length: 10.568
iteration: 136; return: -20.429; episode_length: 10.464
iteration: 137; return: -20.137; episode_length: 10.565
iteration: 138; return: -20.259; episode_length: 10.697
iteration: 139; return: -19.477; episode_length: 10.612
iteration: 140; return: -19.994; episode_length: 10.643
iteration: 141; return: -19.614; episode_length: 10.493
iteration: 142; return: -19.669; episode_length: 10.604
iteration: 143; return: -19.660; episode_length: 10.697
iteration: 144; return: -19.996; episode_length: 10.600
iteration: 145; return: -20.326; episode_length: 10.649
iteration: 146; return: -20.464; episode_length: 10.742
iteration: 147; return: -19.798; episode_length: 10.624
iteration: 148; return: -20.199; episode_length: 10.604
iteration: 149; return: -19.851; episode_length: 10.521
iteration: 150; return: -18.734; episode_length: 10.469
iteration: 151; return: -18.912; episode_length: 10.517
iteration: 152; return: -19.238; episode_length: 10.459
iteration: 153; return: -20.845; episode_length: 10.740
iteration: 154; return: -19.394; episode_length: 10.640
iteration: 155; return: -20.064; episode_length: 10.676
iteration: 156; return: -20.198; episode_length: 10.547
iteration: 157; return: -20.134; episode_length: 10.780
iteration: 158; return: -19.809; episode_length: 10.724
iteration: 159; return: -19.830; episode_length: 10.742
iteration: 160; return: -20.059; episode_length: 10.937
iteration: 161; return: -18.605; episode_length: 10.688
iteration: 162; return: -19.940; episode_length: 10.401
iteration: 163; return: -20.004; episode_length: 10.676
iteration: 164; return: -19.864; episode_length: 10.816
iteration: 165; return: -19.998; episode_length: 10.618
iteration: 166; return: -19.658; episode_length: 10.831
iteration: 167; return: -19.616; episode_length: 10.488
iteration: 168; return: -19.966; episode_length: 10.511
iteration: 169; return: -18.456; episode_length: 10.602
iteration: 170; return: -18.429; episode_length: 10.388
iteration: 171; return: -18.884; episode_length: 10.568
iteration: 172; return: -19.333; episode_length: 10.497
iteration: 173; return: -19.368; episode_length: 10.759
iteration: 174; return: -19.690; episode_length: 10.726
iteration: 175; return: -20.052; episode_length: 10.751
iteration: 176; return: -19.330; episode_length: 10.713
iteration: 177; return: -19.758; episode_length: 10.724
iteration: 178; return: -18.989; episode_length: 10.645
iteration: 179; return: -20.004; episode_length: 10.610
iteration: 180; return: -19.362; episode_length: 10.670
iteration: 181; return: -19.092; episode_length: 10.521
iteration: 182; return: -18.968; episode_length: 10.653
iteration: 183; return: -18.490; episode_length: 10.624
iteration: 184; return: -20.109; episode_length: 10.493
iteration: 185; return: -19.113; episode_length: 10.680
iteration: 186; return: -19.328; episode_length: 10.649
iteration: 187; return: -19.526; episode_length: 10.643
iteration: 188; return: -20.673; episode_length: 10.697
iteration: 189; return: -18.918; episode_length: 10.844
iteration: 190; return: -19.210; episode_length: 10.738
iteration: 191; return: -18.968; episode_length: 10.742
iteration: 192; return: -20.476; episode_length: 10.639
iteration: 193; return: -18.046; episode_length: 10.433
iteration: 194; return: -19.323; episode_length: 10.577
iteration: 195; return: -19.368; episode_length: 10.535
iteration: 196; return: -19.139; episode_length: 10.549
iteration: 197; return: -19.536; episode_length: 10.553
iteration: 198; return: -19.886; episode_length: 10.590
iteration: 199; return: -19.071; episode_length: 10.423
iteration: 200; return: -19.212; episode_length: 10.600
In [12]:
policy_array = [
    derive_deterministic_policy(mdp=highway_mdp, policy=model)
    for model in model_checkpoints
]
In [13]:
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=10, rows=4, policy_over_time=policy_array
)
In [14]:
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    plot_policy_step_grid_map(200)