Policy Gradient
In [1]:
Copied!
# Imports and global configuration for the policy-gradient examples.
# (The export duplicated this cell; keep a single copy, stdlib first.)
import os

from behavior_generation_lecture_python.mdp.mdp import (
    GRID_MDP_DICT,
    HIGHWAY_MDP_DICT,
    LC_RIGHT_ACTION,
    STAY_IN_LANE_ACTION,
    GridMDP,
    derive_deterministic_policy,
    policy_gradient,
)
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
    make_plot_policy_step_function,
)

# Do not restrict actions to available states for the highway example
# (needed so that the lane-change transition tweak below is meaningful).
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
TOY EXAMPLE
In [2]:
Copied!
# Build the toy grid-world MDP from its predefined configuration.
# (Duplicate statement from the notebook export removed.)
grid_mdp = GridMDP(**GRID_MDP_DICT)
In [3]:
Copied!
# Categorical policy network: state -> action distribution, with one hidden
# layer of 32 units. (Duplicate statement from the notebook export removed.)
policy = CategoricalPolicy(
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
    actions=list(grid_mdp.actions),
)
In [4]:
Copied!
# Train the policy with vanilla policy gradient; return_history=True keeps a
# checkpoint per iteration so the learning progress can be visualized below.
# (Duplicate statement from the notebook export removed — it would otherwise
# rerun the full training a second time.)
model_checkpoints = policy_gradient(
    mdp=grid_mdp,
    policy=policy,
    iterations=100,
    return_history=True,
)
iteration: 1; return: -0.744; episode_length: 14.929
iteration: 2; return: -0.599; episode_length: 12.725
iteration: 3; return: -0.370; episode_length: 13.136
iteration: 4; return: -0.241; episode_length: 12.325
iteration: 5; return: -0.181; episode_length: 14.133
iteration: 6; return: -0.105; episode_length: 13.903
iteration: 7; return: -0.040; episode_length: 15.179
iteration: 8; return: 0.020; episode_length: 17.618
iteration: 9; return: 0.032; episode_length: 16.802
iteration: 10; return: 0.116; episode_length: 16.631
iteration: 11; return: -0.003; episode_length: 16.949
iteration: 12; return: 0.127; episode_length: 15.864
iteration: 13; return: 0.091; episode_length: 15.219
iteration: 14; return: 0.151; episode_length: 14.623
iteration: 15; return: 0.173; episode_length: 13.986
iteration: 16; return: 0.182; episode_length: 13.580
iteration: 17; return: 0.239; episode_length: 12.143
iteration: 18; return: 0.295; episode_length: 12.603
iteration: 19; return: 0.270; episode_length: 12.114
iteration: 20; return: 0.244; episode_length: 11.931
iteration: 21; return: 0.278; episode_length: 12.536
iteration: 22; return: 0.336; episode_length: 11.591
iteration: 23; return: 0.305; episode_length: 11.491
iteration: 24; return: 0.378; episode_length: 11.657
iteration: 25; return: 0.389; episode_length: 11.382
iteration: 26; return: 0.343; episode_length: 11.769
iteration: 27; return: 0.399; episode_length: 11.466
iteration: 28; return: 0.348; episode_length: 11.535
iteration: 29; return: 0.413; episode_length: 11.974
iteration: 30; return: 0.406; episode_length: 12.386
iteration: 31; return: 0.474; episode_length: 11.501
iteration: 32; return: 0.481; episode_length: 12.271
iteration: 33; return: 0.516; episode_length: 11.706
iteration: 34; return: 0.458; episode_length: 11.102
iteration: 35; return: 0.541; episode_length: 11.351
iteration: 36; return: 0.528; episode_length: 11.142
iteration: 37; return: 0.513; episode_length: 11.176
iteration: 38; return: 0.560; episode_length: 11.004
iteration: 39; return: 0.543; episode_length: 10.900
iteration: 40; return: 0.584; episode_length: 10.647
iteration: 41; return: 0.534; episode_length: 10.554
iteration: 42; return: 0.580; episode_length: 10.277
iteration: 43; return: 0.575; episode_length: 10.572
iteration: 44; return: 0.585; episode_length: 10.545
iteration: 45; return: 0.588; episode_length: 10.375
iteration: 46; return: 0.608; episode_length: 10.183
iteration: 47; return: 0.583; episode_length: 10.028
iteration: 48; return: 0.581; episode_length: 10.064
iteration: 49; return: 0.602; episode_length: 9.685
iteration: 50; return: 0.607; episode_length: 10.032
iteration: 51; return: 0.610; episode_length: 9.944
iteration: 52; return: 0.597; episode_length: 9.982
iteration: 53; return: 0.612; episode_length: 9.907
iteration: 54; return: 0.618; episode_length: 9.852
iteration: 55; return: 0.637; episode_length: 9.513
iteration: 56; return: 0.612; episode_length: 9.635
iteration: 57; return: 0.618; episode_length: 9.599
iteration: 58; return: 0.648; episode_length: 9.340
iteration: 59; return: 0.630; episode_length: 9.310
iteration: 60; return: 0.642; episode_length: 9.305
iteration: 61; return: 0.633; episode_length: 9.340
iteration: 62; return: 0.644; episode_length: 9.526
iteration: 63; return: 0.647; episode_length: 9.179
iteration: 64; return: 0.636; episode_length: 8.946
iteration: 65; return: 0.623; episode_length: 9.405
iteration: 66; return: 0.638; episode_length: 8.980
iteration: 67; return: 0.653; episode_length: 9.126
iteration: 68; return: 0.651; episode_length: 9.016
iteration: 69; return: 0.643; episode_length: 9.029
iteration: 70; return: 0.672; episode_length: 8.682
iteration: 71; return: 0.640; episode_length: 9.013
iteration: 72; return: 0.665; episode_length: 8.758
iteration: 73; return: 0.665; episode_length: 8.841
iteration: 74; return: 0.648; episode_length: 8.909
iteration: 75; return: 0.642; episode_length: 9.052
iteration: 76; return: 0.650; episode_length: 9.117
iteration: 77; return: 0.645; episode_length: 8.982
iteration: 78; return: 0.659; episode_length: 9.069
iteration: 79; return: 0.652; episode_length: 8.971
iteration: 80; return: 0.644; episode_length: 9.163
iteration: 81; return: 0.654; episode_length: 9.180
iteration: 82; return: 0.653; episode_length: 9.214
iteration: 83; return: 0.661; episode_length: 9.204
iteration: 84; return: 0.649; episode_length: 9.047
iteration: 85; return: 0.651; episode_length: 9.098
iteration: 86; return: 0.654; episode_length: 8.858
iteration: 87; return: 0.679; episode_length: 8.772
iteration: 88; return: 0.660; episode_length: 8.789
iteration: 89; return: 0.668; episode_length: 8.689
iteration: 90; return: 0.655; episode_length: 8.911
iteration: 91; return: 0.667; episode_length: 8.805
iteration: 92; return: 0.675; episode_length: 8.608
iteration: 93; return: 0.674; episode_length: 8.706
iteration: 94; return: 0.681; episode_length: 8.629
iteration: 95; return: 0.658; episode_length: 8.929
iteration: 96; return: 0.653; episode_length: 8.803
iteration: 97; return: 0.635; episode_length: 8.879
iteration: 98; return: 0.668; episode_length: 8.869
iteration: 99; return: 0.673; episode_length: 8.648
iteration: 100; return: 0.672; episode_length: 8.843
In [5]:
Copied!
# Convert each stochastic policy checkpoint into a deterministic policy
# (argmax action per state) for plotting. (Duplicate removed.)
policy_array = [
    derive_deterministic_policy(mdp=grid_mdp, policy=model)
    for model in model_checkpoints
]
In [6]:
Copied!
# Plotting callback over the training history for the 4x3 toy grid.
# (Duplicate statement from the notebook export removed.)
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=4, rows=3, policy_over_time=policy_array
)
In [7]:
Copied!
# Interactive slider locally; a single static plot in CI. (Duplicate removed.)
interactive_widgets = os.getenv("CI") != "true"  # non-interactive in CI
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # Show the final checkpoint; use len-1 (consistent with the slider's max)
    # instead of the hard-coded 100, which raises IndexError if the history
    # has exactly `iterations` entries.
    plot_policy_step_grid_map(len(model_checkpoints) - 1)
HIGHWAY EXAMPLE
In [8]:
Copied!
# Optional: make lane changes to the right stochastic (40% success,
# 60% stay in lane). Disabled by default; flip to True later in the lecture
# to observe the effect on the learned policy. (Duplicate removed.)
if False:
    # we will change this to true later on, to see the effect
    HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
        (0.4, LC_RIGHT_ACTION),
        (0.6, STAY_IN_LANE_ACTION),
    ]
In [9]:
Copied!
# Build the highway MDP from its (possibly modified) configuration.
# (Duplicate statement from the notebook export removed.)
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
In [10]:
Copied!
# Fresh policy network for the highway MDP, same architecture as the toy
# example (one hidden layer, 32 units). (Duplicate removed.)
policy = CategoricalPolicy(
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
    actions=list(highway_mdp.actions),
)
In [11]:
Copied!
# Train on the highway MDP; more iterations than the toy example since the
# task is harder. (Duplicate removed — it would rerun the full training.)
model_checkpoints = policy_gradient(
    mdp=highway_mdp,
    policy=policy,
    iterations=200,
    return_history=True,
)
iteration: 1; return: -63.293; episode_length: 10.258
iteration: 2; return: -55.902; episode_length: 10.403
iteration: 3; return: -55.274; episode_length: 10.724
iteration: 4; return: -50.029; episode_length: 10.982
iteration: 5; return: -43.916; episode_length: 11.660
iteration: 6; return: -41.417; episode_length: 12.409
iteration: 7; return: -39.890; episode_length: 12.776
iteration: 8; return: -39.499; episode_length: 14.843
iteration: 9; return: -37.733; episode_length: 15.246
iteration: 10; return: -38.640; episode_length: 14.893
iteration: 11; return: -35.902; episode_length: 14.902
iteration: 12; return: -36.639; episode_length: 14.976
iteration: 13; return: -35.359; episode_length: 13.961
iteration: 14; return: -34.960; episode_length: 13.442
iteration: 15; return: -35.034; episode_length: 13.000
iteration: 16; return: -35.471; episode_length: 12.541
iteration: 17; return: -34.768; episode_length: 12.337
iteration: 18; return: -35.778; episode_length: 12.200
iteration: 19; return: -33.545; episode_length: 12.879
iteration: 20; return: -34.258; episode_length: 12.541
iteration: 21; return: -33.967; episode_length: 12.726
iteration: 22; return: -32.276; episode_length: 12.671
iteration: 23; return: -31.556; episode_length: 13.099
iteration: 24; return: -32.542; episode_length: 13.569
iteration: 25; return: -32.540; episode_length: 13.873
iteration: 26; return: -32.446; episode_length: 13.171
iteration: 27; return: -32.488; episode_length: 13.273
iteration: 28; return: -31.024; episode_length: 13.166
iteration: 29; return: -31.902; episode_length: 12.874
iteration: 30; return: -29.904; episode_length: 13.047
iteration: 31; return: -30.327; episode_length: 12.816
iteration: 32; return: -29.148; episode_length: 12.128
iteration: 33; return: -29.862; episode_length: 11.756
iteration: 34; return: -30.872; episode_length: 11.671
iteration: 35; return: -30.586; episode_length: 11.160
iteration: 36; return: -29.418; episode_length: 11.368
iteration: 37; return: -29.966; episode_length: 11.229
iteration: 38; return: -27.421; episode_length: 11.392
iteration: 39; return: -28.030; episode_length: 11.495
iteration: 40; return: -28.600; episode_length: 11.969
iteration: 41; return: -28.140; episode_length: 11.910
iteration: 42; return: -28.426; episode_length: 12.133
iteration: 43; return: -27.299; episode_length: 12.065
iteration: 44; return: -28.061; episode_length: 12.305
iteration: 45; return: -28.426; episode_length: 12.330
iteration: 46; return: -27.926; episode_length: 12.417
iteration: 47; return: -28.249; episode_length: 12.227
iteration: 48; return: -27.581; episode_length: 12.318
iteration: 49; return: -27.012; episode_length: 12.092
iteration: 50; return: -26.854; episode_length: 11.291
iteration: 51; return: -24.321; episode_length: 10.998
iteration: 52; return: -24.173; episode_length: 10.831
iteration: 53; return: -26.578; episode_length: 10.667
iteration: 54; return: -24.419; episode_length: 10.535
iteration: 55; return: -24.573; episode_length: 10.360
iteration: 56; return: -25.333; episode_length: 10.462
iteration: 57; return: -24.289; episode_length: 10.651
iteration: 58; return: -25.042; episode_length: 11.118
iteration: 59; return: -23.946; episode_length: 11.210
iteration: 60; return: -24.345; episode_length: 11.506
iteration: 61; return: -24.855; episode_length: 11.858
iteration: 62; return: -23.613; episode_length: 11.478
iteration: 63; return: -23.668; episode_length: 11.380
iteration: 64; return: -22.704; episode_length: 11.049
iteration: 65; return: -23.535; episode_length: 11.064
iteration: 66; return: -22.633; episode_length: 10.672
iteration: 67; return: -23.957; episode_length: 10.740
iteration: 68; return: -25.578; episode_length: 10.557
iteration: 69; return: -22.243; episode_length: 10.665
iteration: 70; return: -23.290; episode_length: 10.678
iteration: 71; return: -23.596; episode_length: 10.987
iteration: 72; return: -23.546; episode_length: 11.345
iteration: 73; return: -24.547; episode_length: 11.564
iteration: 74; return: -22.949; episode_length: 11.653
iteration: 75; return: -22.193; episode_length: 11.245
iteration: 76; return: -22.944; episode_length: 11.233
iteration: 77; return: -22.635; episode_length: 11.064
iteration: 78; return: -21.252; episode_length: 10.876
iteration: 79; return: -22.631; episode_length: 10.663
iteration: 80; return: -22.416; episode_length: 10.838
iteration: 81; return: -22.069; episode_length: 10.745
iteration: 82; return: -21.616; episode_length: 10.665
iteration: 83; return: -21.713; episode_length: 10.620
iteration: 84; return: -21.422; episode_length: 10.667
iteration: 85; return: -22.334; episode_length: 10.579
iteration: 86; return: -21.879; episode_length: 10.800
iteration: 87; return: -20.855; episode_length: 10.701
iteration: 88; return: -20.247; episode_length: 10.738
iteration: 89; return: -21.292; episode_length: 10.766
iteration: 90; return: -20.994; episode_length: 10.812
iteration: 91; return: -20.985; episode_length: 10.780
iteration: 92; return: -21.064; episode_length: 11.004
iteration: 93; return: -21.720; episode_length: 10.861
iteration: 94; return: -20.582; episode_length: 10.825
iteration: 95; return: -22.360; episode_length: 11.060
iteration: 96; return: -21.166; episode_length: 10.954
iteration: 97; return: -20.906; episode_length: 10.758
iteration: 98; return: -20.943; episode_length: 10.647
iteration: 99; return: -21.673; episode_length: 10.770
iteration: 100; return: -21.459; episode_length: 11.049
iteration: 101; return: -20.591; episode_length: 10.730
iteration: 102; return: -20.594; episode_length: 10.857
iteration: 103; return: -21.664; episode_length: 10.904
iteration: 104; return: -20.536; episode_length: 11.042
iteration: 105; return: -20.262; episode_length: 10.759
iteration: 106; return: -20.969; episode_length: 10.919
iteration: 107; return: -21.049; episode_length: 11.240
iteration: 108; return: -21.768; episode_length: 11.181
iteration: 109; return: -19.741; episode_length: 11.089
iteration: 110; return: -20.478; episode_length: 10.989
iteration: 111; return: -20.901; episode_length: 11.022
iteration: 112; return: -20.400; episode_length: 10.781
iteration: 113; return: -21.319; episode_length: 10.863
iteration: 114; return: -20.953; episode_length: 10.715
iteration: 115; return: -20.493; episode_length: 10.680
iteration: 116; return: -19.543; episode_length: 10.699
iteration: 117; return: -20.640; episode_length: 10.604
iteration: 118; return: -20.279; episode_length: 10.806
iteration: 119; return: -20.838; episode_length: 10.663
iteration: 120; return: -20.528; episode_length: 10.666
iteration: 121; return: -20.380; episode_length: 10.635
iteration: 122; return: -20.126; episode_length: 10.505
iteration: 123; return: -19.887; episode_length: 10.433
iteration: 124; return: -20.515; episode_length: 10.425
iteration: 125; return: -20.678; episode_length: 10.477
iteration: 126; return: -20.052; episode_length: 10.410
iteration: 127; return: -21.219; episode_length: 10.354
iteration: 128; return: -20.176; episode_length: 10.633
iteration: 129; return: -19.921; episode_length: 10.475
iteration: 130; return: -20.850; episode_length: 10.608
iteration: 131; return: -20.546; episode_length: 10.663
iteration: 132; return: -21.327; episode_length: 10.844
iteration: 133; return: -20.251; episode_length: 10.713
iteration: 134; return: -20.874; episode_length: 10.863
iteration: 135; return: -19.704; episode_length: 10.717
iteration: 136; return: -19.800; episode_length: 10.651
iteration: 137; return: -19.680; episode_length: 10.604
iteration: 138; return: -20.267; episode_length: 10.665
iteration: 139; return: -19.663; episode_length: 10.462
iteration: 140; return: -20.565; episode_length: 10.568
iteration: 141; return: -20.354; episode_length: 10.819
iteration: 142; return: -19.530; episode_length: 10.491
iteration: 143; return: -19.726; episode_length: 10.555
iteration: 144; return: -19.670; episode_length: 10.513
iteration: 145; return: -20.004; episode_length: 10.579
iteration: 146; return: -20.718; episode_length: 10.515
iteration: 147; return: -20.081; episode_length: 10.624
iteration: 148; return: -19.120; episode_length: 10.561
iteration: 149; return: -19.302; episode_length: 10.643
iteration: 150; return: -19.878; episode_length: 10.553
iteration: 151; return: -20.687; episode_length: 10.749
iteration: 152; return: -19.676; episode_length: 10.506
iteration: 153; return: -21.254; episode_length: 10.667
iteration: 154; return: -20.009; episode_length: 10.703
iteration: 155; return: -20.953; episode_length: 10.649
iteration: 156; return: -19.511; episode_length: 10.602
iteration: 157; return: -20.664; episode_length: 10.640
iteration: 158; return: -20.268; episode_length: 10.394
iteration: 159; return: -20.483; episode_length: 10.655
iteration: 160; return: -19.646; episode_length: 10.505
iteration: 161; return: -19.613; episode_length: 10.590
iteration: 162; return: -19.874; episode_length: 10.528
iteration: 163; return: -20.616; episode_length: 10.555
iteration: 164; return: -19.847; episode_length: 10.354
iteration: 165; return: -20.165; episode_length: 10.421
iteration: 166; return: -20.568; episode_length: 10.305
iteration: 167; return: -20.992; episode_length: 10.366
iteration: 168; return: -19.042; episode_length: 10.401
iteration: 169; return: -19.625; episode_length: 10.362
iteration: 170; return: -20.440; episode_length: 10.573
iteration: 171; return: -20.208; episode_length: 10.444
iteration: 172; return: -19.744; episode_length: 10.597
iteration: 173; return: -19.333; episode_length: 10.535
iteration: 174; return: -19.049; episode_length: 10.647
iteration: 175; return: -20.425; episode_length: 10.789
iteration: 176; return: -19.393; episode_length: 10.848
iteration: 177; return: -20.523; episode_length: 10.908
iteration: 178; return: -21.020; episode_length: 10.967
iteration: 179; return: -20.275; episode_length: 10.908
iteration: 180; return: -19.856; episode_length: 10.738
iteration: 181; return: -19.318; episode_length: 10.763
iteration: 182; return: -20.545; episode_length: 10.954
iteration: 183; return: -19.504; episode_length: 10.627
iteration: 184; return: -19.790; episode_length: 10.515
iteration: 185; return: -19.957; episode_length: 10.674
iteration: 186; return: -18.892; episode_length: 10.551
iteration: 187; return: -20.314; episode_length: 10.491
iteration: 188; return: -19.899; episode_length: 10.584
iteration: 189; return: -19.449; episode_length: 10.499
iteration: 190; return: -19.575; episode_length: 10.535
iteration: 191; return: -20.466; episode_length: 10.612
iteration: 192; return: -19.296; episode_length: 10.592
iteration: 193; return: -19.472; episode_length: 10.330
iteration: 194; return: -19.681; episode_length: 10.563
iteration: 195; return: -19.243; episode_length: 10.405
iteration: 196; return: -19.628; episode_length: 10.577
iteration: 197; return: -20.532; episode_length: 10.647
iteration: 198; return: -18.709; episode_length: 10.357
iteration: 199; return: -20.233; episode_length: 10.573
iteration: 200; return: -19.646; episode_length: 10.619
In [12]:
Copied!
# Deterministic (argmax) policy per checkpoint for visualization.
# (Duplicate statement from the notebook export removed.)
policy_array = [
    derive_deterministic_policy(mdp=highway_mdp, policy=model)
    for model in model_checkpoints
]
In [13]:
Copied!
# Plotting callback for the 10x4 highway grid. (Duplicate removed.)
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=10, rows=4, policy_over_time=policy_array
)
In [14]:
Copied!
# Interactive slider locally; a single static plot in CI. (Duplicate removed.)
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # Show the final checkpoint; use len-1 (consistent with the slider's max)
    # instead of the hard-coded 200, which raises IndexError if the history
    # has exactly `iterations` entries.
    plot_policy_step_grid_map(len(model_checkpoints) - 1)