Policy Gradient
In [1]:
import os

from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
    make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
    GridMDP,
    policy_gradient,
    derive_deterministic_policy,
    GRID_MDP_DICT,
    HIGHWAY_MDP_DICT,
    LC_RIGHT_ACTION,
    STAY_IN_LANE_ACTION,
)
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
TOY EXAMPLE
In [2]:
grid_mdp = GridMDP(**GRID_MDP_DICT)
In [3]:
policy = CategoricalPolicy(
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
    actions=list(grid_mdp.actions),
)
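The policy is a small neural network that maps a state vector to a categorical distribution over the MDP's actions: the sizes argument lists the layer widths (the state dimension as input, one hidden layer of 32 units, and one output logit per action). As a rough illustration of the idea, the sketch below builds an analogous policy in PyTorch; the class and method names are hypothetical, and the actual CategoricalPolicy may be implemented differently.

import torch
import torch.nn as nn


class TinyCategoricalPolicy(nn.Module):
    """Illustrative stand-in for CategoricalPolicy: an MLP producing action logits."""

    def __init__(self, sizes, actions):
        super().__init__()
        layers = []
        for in_dim, out_dim in zip(sizes[:-1], sizes[1:]):
            layers += [nn.Linear(in_dim, out_dim), nn.Tanh()]
        self.net = nn.Sequential(*layers[:-1])  # no activation on the output logits
        self.actions = list(actions)

    def distribution(self, state):
        logits = self.net(torch.as_tensor(state, dtype=torch.float32))
        return torch.distributions.Categorical(logits=logits)

    def sample_action(self, state):
        return self.actions[self.distribution(state).sample().item()]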
In [4]:
model_checkpoints = policy_gradient(
    mdp=grid_mdp,
    policy=policy,
    iterations=100,
    return_history=True,
)
iteration: 1; return: -1.163; episode_length: 17.353
iteration: 2; return: -1.081; episode_length: 16.325
iteration: 3; return: -0.955; episode_length: 13.255
iteration: 4; return: -0.810; episode_length: 12.956
iteration: 5; return: -0.706; episode_length: 10.427
iteration: 6; return: -0.648; episode_length: 9.278
iteration: 7; return: -0.468; episode_length: 9.458
iteration: 8; return: -0.390; episode_length: 9.265
iteration: 9; return: -0.339; episode_length: 10.252
iteration: 10; return: -0.240; episode_length: 10.397
iteration: 11; return: -0.237; episode_length: 11.247
iteration: 12; return: -0.064; episode_length: 12.012
iteration: 13; return: -0.088; episode_length: 13.189
iteration: 14; return: -0.128; episode_length: 14.911
iteration: 15; return: -0.017; episode_length: 16.132
iteration: 16; return: -0.041; episode_length: 15.444
iteration: 17; return: -0.096; episode_length: 15.623
iteration: 18; return: -0.060; episode_length: 16.395
iteration: 19; return: 0.019; episode_length: 15.021
iteration: 20; return: -0.002; episode_length: 15.879
iteration: 21; return: 0.047; episode_length: 15.636
iteration: 22; return: 0.036; episode_length: 14.522
iteration: 23; return: 0.034; episode_length: 14.133
iteration: 24; return: 0.061; episode_length: 13.097
iteration: 25; return: 0.068; episode_length: 14.051
iteration: 26; return: 0.156; episode_length: 11.950
iteration: 27; return: 0.152; episode_length: 12.289
iteration: 28; return: 0.154; episode_length: 12.379
iteration: 29; return: 0.123; episode_length: 11.412
iteration: 30; return: 0.169; episode_length: 11.417
iteration: 31; return: 0.138; episode_length: 11.746
iteration: 32; return: 0.161; episode_length: 11.570
iteration: 33; return: 0.125; episode_length: 11.503
iteration: 34; return: 0.200; episode_length: 11.367
iteration: 35; return: 0.194; episode_length: 11.368
iteration: 36; return: 0.202; episode_length: 11.642
iteration: 37; return: 0.202; episode_length: 12.325
iteration: 38; return: 0.149; episode_length: 11.548
iteration: 39; return: 0.221; episode_length: 11.776
iteration: 40; return: 0.175; episode_length: 11.955
iteration: 41; return: 0.218; episode_length: 12.246
iteration: 42; return: 0.261; episode_length: 12.423
iteration: 43; return: 0.275; episode_length: 13.018
iteration: 44; return: 0.211; episode_length: 12.636
iteration: 45; return: 0.228; episode_length: 13.036
iteration: 46; return: 0.173; episode_length: 12.540
iteration: 47; return: 0.274; episode_length: 13.039
iteration: 48; return: 0.263; episode_length: 11.936
iteration: 49; return: 0.210; episode_length: 12.595
iteration: 50; return: 0.250; episode_length: 12.111
iteration: 51; return: 0.262; episode_length: 12.022
iteration: 52; return: 0.306; episode_length: 11.610
iteration: 53; return: 0.224; episode_length: 11.874
iteration: 54; return: 0.262; episode_length: 11.938
iteration: 55; return: 0.233; episode_length: 12.012
iteration: 56; return: 0.251; episode_length: 11.477
iteration: 57; return: 0.326; episode_length: 10.963
iteration: 58; return: 0.280; episode_length: 10.926
iteration: 59; return: 0.241; episode_length: 10.987
iteration: 60; return: 0.330; episode_length: 10.565
iteration: 61; return: 0.329; episode_length: 11.118
iteration: 62; return: 0.333; episode_length: 11.348
iteration: 63; return: 0.313; episode_length: 11.366
iteration: 64; return: 0.364; episode_length: 10.846
iteration: 65; return: 0.361; episode_length: 11.724
iteration: 66; return: 0.338; episode_length: 12.353
iteration: 67; return: 0.324; episode_length: 11.475
iteration: 68; return: 0.324; episode_length: 11.776
iteration: 69; return: 0.382; episode_length: 11.282
iteration: 70; return: 0.315; episode_length: 11.408
iteration: 71; return: 0.372; episode_length: 11.438
iteration: 72; return: 0.331; episode_length: 11.600
iteration: 73; return: 0.331; episode_length: 10.874
iteration: 74; return: 0.382; episode_length: 10.833
iteration: 75; return: 0.359; episode_length: 11.066
iteration: 76; return: 0.325; episode_length: 11.317
iteration: 77; return: 0.349; episode_length: 11.073
iteration: 78; return: 0.391; episode_length: 10.841
iteration: 79; return: 0.383; episode_length: 10.745
iteration: 80; return: 0.337; episode_length: 11.199
iteration: 81; return: 0.405; episode_length: 10.663
iteration: 82; return: 0.380; episode_length: 11.233
iteration: 83; return: 0.431; episode_length: 10.658
iteration: 84; return: 0.390; episode_length: 10.692
iteration: 85; return: 0.371; episode_length: 10.390
iteration: 86; return: 0.369; episode_length: 10.296
iteration: 87; return: 0.451; episode_length: 10.380
iteration: 88; return: 0.453; episode_length: 10.337
iteration: 89; return: 0.441; episode_length: 10.324
iteration: 90; return: 0.399; episode_length: 10.418
iteration: 91; return: 0.427; episode_length: 10.374
iteration: 92; return: 0.440; episode_length: 9.938
iteration: 93; return: 0.445; episode_length: 10.185
iteration: 94; return: 0.456; episode_length: 10.423
iteration: 95; return: 0.485; episode_length: 10.279
iteration: 96; return: 0.474; episode_length: 10.177
iteration: 97; return: 0.470; episode_length: 10.050
iteration: 98; return: 0.531; episode_length: 9.783
iteration: 99; return: 0.487; episode_length: 10.024
iteration: 100; return: 0.533; episode_length: 9.990
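The average return improves from roughly -1.2 in the first iteration to above 0.5 after 100 iterations, so the stochastic policy is steadily getting better on the toy grid. Conceptually, policy_gradient rolls out episodes with the current policy and then takes a gradient step on a surrogate objective in which the log-probabilities of the sampled actions are weighted by the observed returns (vanilla policy gradient / REINFORCE). The sketch below shows one common form of that loss, reusing the hypothetical distribution helper from above; it is an illustrative assumption, not necessarily how the lecture code implements the update.

import torch


def reinforce_loss(policy, states, action_indices, weights):
    """Negative policy-gradient surrogate for a batch of sampled transitions.

    weights would typically be the episode returns (or rewards-to-go) belonging
    to each (state, action) pair; minimizing this loss with an optimizer
    ascends the policy-gradient direction.
    """
    log_probs = torch.stack(
        [
            policy.distribution(state).log_prob(torch.as_tensor(action_index))
            for state, action_index in zip(states, action_indices)
        ]
    )
    return -(log_probs * torch.as_tensor(weights, dtype=torch.float32)).mean()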
In [5]:
policy_array = [
    derive_deterministic_policy(mdp=grid_mdp, policy=model)
    for model in model_checkpoints
]
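To visualize what has been learned, each stored checkpoint is turned into a deterministic policy. A natural way to do this is to pick, for every non-terminal state, the action to which the stochastic policy assigns the highest probability; the sketch below illustrates the idea with hypothetical helpers (mdp.states, mdp.is_terminal, policy.distribution), which need not match the actual derive_deterministic_policy implementation.

def greedy_policy_from(mdp, policy):
    """Map each non-terminal state to its most probable action under the policy."""
    actions = list(mdp.actions)
    deterministic = {}
    for state in mdp.states:
        if mdp.is_terminal(state):
            continue
        action_probabilities = policy.distribution(state).probs
        deterministic[state] = actions[int(action_probabilities.argmax())]
    return deterministic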
In [6]:
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=4, rows=3, policy_over_time=policy_array
)
In [7]:
interactive_widgets = os.getenv("CI") != "true"  # non-interactive in CI
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    plot_policy_step_grid_map(100)
HIGHWAY EXAMPLE
In [8]:
if False:
    # we will change this to True later on, to see the effect
    HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
        (0.4, LC_RIGHT_ACTION),
        (0.6, STAY_IN_LANE_ACTION),
    ]
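If the block above is enabled later on (by changing False to True), a commanded lane change to the right only succeeds with probability 0.4 and the vehicle otherwise stays in its lane, so the learned policy has to account for the risk of the maneuver failing. Purely to illustrate the semantics of such a (probability, action) list, here is a hypothetical sampling helper; it is not part of the lecture package.

import random


def sample_executed_action(transition_probabilities_for_action):
    """Draw the action that actually gets executed from a [(prob, action), ...] list."""
    probabilities, outcomes = zip(*transition_probabilities_for_action)
    return random.choices(outcomes, weights=probabilities, k=1)[0]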
In [9]:
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
In [10]:
policy = CategoricalPolicy(
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
    actions=list(highway_mdp.actions),
)
In [11]:
model_checkpoints = policy_gradient(
    mdp=highway_mdp,
    policy=policy,
    iterations=200,
    return_history=True,
)
iteration: 1; return: -49.091; episode_length: 9.862
iteration: 2; return: -35.374; episode_length: 10.883
iteration: 3; return: -33.875; episode_length: 12.237
iteration: 4; return: -33.777; episode_length: 13.662
iteration: 5; return: -36.944; episode_length: 15.675
iteration: 6; return: -38.374; episode_length: 15.767
iteration: 7; return: -40.931; episode_length: 17.279
iteration: 8; return: -39.780; episode_length: 16.946
iteration: 9; return: -40.364; episode_length: 17.017
iteration: 10; return: -39.980; episode_length: 17.037
iteration: 11; return: -37.926; episode_length: 16.132
iteration: 12; return: -34.840; episode_length: 14.858
iteration: 13; return: -32.209; episode_length: 13.413
iteration: 14; return: -31.856; episode_length: 12.856
iteration: 15; return: -30.891; episode_length: 11.881
iteration: 16; return: -33.857; episode_length: 11.004
iteration: 17; return: -33.594; episode_length: 10.649
iteration: 18; return: -32.876; episode_length: 10.876
iteration: 19; return: -31.449; episode_length: 11.307
iteration: 20; return: -31.181; episode_length: 11.749
iteration: 21; return: -31.218; episode_length: 12.160
iteration: 22; return: -30.230; episode_length: 12.232
iteration: 23; return: -29.900; episode_length: 12.177
iteration: 24; return: -30.771; episode_length: 12.465
iteration: 25; return: -30.003; episode_length: 12.575
iteration: 26; return: -31.088; episode_length: 12.995
iteration: 27; return: -30.472; episode_length: 12.716
iteration: 28; return: -27.886; episode_length: 11.678
iteration: 29; return: -29.200; episode_length: 11.118
iteration: 30; return: -28.811; episode_length: 11.009
iteration: 31; return: -29.494; episode_length: 10.523
iteration: 32; return: -28.898; episode_length: 10.622
iteration: 33; return: -28.212; episode_length: 10.610
iteration: 34; return: -29.302; episode_length: 10.810
iteration: 35; return: -28.829; episode_length: 10.861
iteration: 36; return: -27.279; episode_length: 11.635
iteration: 37; return: -27.730; episode_length: 12.072
iteration: 38; return: -27.933; episode_length: 11.924
iteration: 39; return: -27.108; episode_length: 12.290
iteration: 40; return: -27.385; episode_length: 12.200
iteration: 41; return: -27.720; episode_length: 11.699
iteration: 42; return: -25.924; episode_length: 11.590
iteration: 43; return: -26.240; episode_length: 11.335
iteration: 44; return: -26.031; episode_length: 11.165
iteration: 45; return: -27.110; episode_length: 11.208
iteration: 46; return: -25.972; episode_length: 10.945
iteration: 47; return: -24.474; episode_length: 10.831
iteration: 48; return: -24.950; episode_length: 10.829
iteration: 49; return: -25.239; episode_length: 10.694
iteration: 50; return: -25.000; episode_length: 10.572
iteration: 51; return: -25.743; episode_length: 10.382
iteration: 52; return: -25.294; episode_length: 10.497
iteration: 53; return: -23.858; episode_length: 10.610
iteration: 54; return: -23.648; episode_length: 10.676
iteration: 55; return: -23.869; episode_length: 11.145
iteration: 56; return: -24.241; episode_length: 11.475
iteration: 57; return: -25.103; episode_length: 11.470
iteration: 58; return: -24.088; episode_length: 10.998
iteration: 59; return: -23.251; episode_length: 10.996
iteration: 60; return: -24.642; episode_length: 10.928
iteration: 61; return: -24.168; episode_length: 10.651
iteration: 62; return: -24.331; episode_length: 10.568
iteration: 63; return: -23.146; episode_length: 10.604
iteration: 64; return: -23.551; episode_length: 10.606
iteration: 65; return: -22.869; episode_length: 10.604
iteration: 66; return: -23.702; episode_length: 10.513
iteration: 67; return: -24.065; episode_length: 10.499
iteration: 68; return: -23.151; episode_length: 10.678
iteration: 69; return: -22.625; episode_length: 10.855
iteration: 70; return: -23.174; episode_length: 11.165
iteration: 71; return: -22.777; episode_length: 11.399
iteration: 72; return: -23.466; episode_length: 11.486
iteration: 73; return: -22.679; episode_length: 11.328
iteration: 74; return: -22.459; episode_length: 11.368
iteration: 75; return: -22.344; episode_length: 11.174
iteration: 76; return: -22.343; episode_length: 10.891
iteration: 77; return: -21.929; episode_length: 11.129
iteration: 78; return: -22.917; episode_length: 10.913
iteration: 79; return: -22.405; episode_length: 11.026
iteration: 80; return: -22.678; episode_length: 10.872
iteration: 81; return: -22.568; episode_length: 10.866
iteration: 82; return: -23.009; episode_length: 10.956
iteration: 83; return: -21.186; episode_length: 10.831
iteration: 84; return: -21.261; episode_length: 10.891
iteration: 85; return: -21.263; episode_length: 10.954
iteration: 86; return: -21.046; episode_length: 10.848
iteration: 87; return: -21.203; episode_length: 11.029
iteration: 88; return: -21.430; episode_length: 10.850
iteration: 89; return: -20.454; episode_length: 10.967
iteration: 90; return: -20.551; episode_length: 10.915
iteration: 91; return: -20.517; episode_length: 10.694
iteration: 92; return: -20.097; episode_length: 10.759
iteration: 93; return: -21.600; episode_length: 10.772
iteration: 94; return: -20.356; episode_length: 10.680
iteration: 95; return: -20.525; episode_length: 10.852
iteration: 96; return: -21.024; episode_length: 10.730
iteration: 97; return: -20.616; episode_length: 10.793
iteration: 98; return: -20.544; episode_length: 10.755
iteration: 99; return: -21.190; episode_length: 10.954
iteration: 100; return: -21.453; episode_length: 10.697
iteration: 101; return: -21.541; episode_length: 10.833
iteration: 102; return: -20.545; episode_length: 10.686
iteration: 103; return: -21.538; episode_length: 10.852
iteration: 104; return: -20.618; episode_length: 10.667
iteration: 105; return: -20.950; episode_length: 10.950
iteration: 106; return: -21.562; episode_length: 10.738
iteration: 107; return: -21.071; episode_length: 10.834
iteration: 108; return: -19.554; episode_length: 10.732
iteration: 109; return: -20.913; episode_length: 10.852
iteration: 110; return: -19.229; episode_length: 10.705
iteration: 111; return: -20.347; episode_length: 10.612
iteration: 112; return: -20.310; episode_length: 10.791
iteration: 113; return: -19.544; episode_length: 10.770
iteration: 114; return: -20.376; episode_length: 10.924
iteration: 115; return: -21.765; episode_length: 10.985
iteration: 116; return: -21.092; episode_length: 10.993
iteration: 117; return: -21.459; episode_length: 11.060
iteration: 118; return: -20.741; episode_length: 11.106
iteration: 119; return: -20.579; episode_length: 11.115
iteration: 120; return: -20.682; episode_length: 10.736
iteration: 121; return: -21.872; episode_length: 11.022
iteration: 122; return: -19.702; episode_length: 10.878
iteration: 123; return: -20.833; episode_length: 10.874
iteration: 124; return: -19.609; episode_length: 10.714
iteration: 125; return: -20.777; episode_length: 10.859
iteration: 126; return: -21.762; episode_length: 10.640
iteration: 127; return: -20.261; episode_length: 10.635
iteration: 128; return: -20.318; episode_length: 10.676
iteration: 129; return: -20.867; episode_length: 10.539
iteration: 130; return: -20.278; episode_length: 10.694
iteration: 131; return: -20.161; episode_length: 10.588
iteration: 132; return: -21.028; episode_length: 10.791
iteration: 133; return: -20.583; episode_length: 10.694
iteration: 134; return: -20.067; episode_length: 10.734
iteration: 135; return: -20.243; episode_length: 10.866
iteration: 136; return: -21.045; episode_length: 10.766
iteration: 137; return: -20.121; episode_length: 10.588
iteration: 138; return: -19.682; episode_length: 10.622
iteration: 139; return: -19.859; episode_length: 10.889
iteration: 140; return: -20.596; episode_length: 10.668
iteration: 141; return: -19.652; episode_length: 10.493
iteration: 142; return: -21.456; episode_length: 10.755
iteration: 143; return: -19.225; episode_length: 10.545
iteration: 144; return: -19.495; episode_length: 10.665
iteration: 145; return: -20.770; episode_length: 10.753
iteration: 146; return: -19.680; episode_length: 10.801
iteration: 147; return: -20.382; episode_length: 10.689
iteration: 148; return: -20.244; episode_length: 10.904
iteration: 149; return: -20.253; episode_length: 10.803
iteration: 150; return: -20.517; episode_length: 10.640
iteration: 151; return: -20.437; episode_length: 10.724
iteration: 152; return: -19.013; episode_length: 10.487
iteration: 153; return: -20.055; episode_length: 10.684
iteration: 154; return: -20.757; episode_length: 10.982
iteration: 155; return: -20.949; episode_length: 10.686
iteration: 156; return: -19.408; episode_length: 10.541
iteration: 157; return: -20.277; episode_length: 10.680
iteration: 158; return: -20.107; episode_length: 10.676
iteration: 159; return: -19.678; episode_length: 10.665
iteration: 160; return: -20.660; episode_length: 10.755
iteration: 161; return: -20.133; episode_length: 10.772
iteration: 162; return: -19.733; episode_length: 10.795
iteration: 163; return: -19.479; episode_length: 10.651
iteration: 164; return: -19.687; episode_length: 10.441
iteration: 165; return: -19.650; episode_length: 10.501
iteration: 166; return: -19.985; episode_length: 10.801
iteration: 167; return: -20.011; episode_length: 10.825
iteration: 168; return: -19.662; episode_length: 10.863
iteration: 169; return: -20.378; episode_length: 10.768
iteration: 170; return: -20.004; episode_length: 10.891
iteration: 171; return: -19.740; episode_length: 10.734
iteration: 172; return: -20.192; episode_length: 10.778
iteration: 173; return: -20.623; episode_length: 10.795
iteration: 174; return: -20.317; episode_length: 10.857
iteration: 175; return: -20.628; episode_length: 10.844
iteration: 176; return: -22.828; episode_length: 10.655
iteration: 177; return: -20.440; episode_length: 10.547
iteration: 178; return: -20.161; episode_length: 10.581
iteration: 179; return: -20.326; episode_length: 10.575
iteration: 180; return: -19.767; episode_length: 10.816
iteration: 181; return: -19.686; episode_length: 10.697
iteration: 182; return: -18.970; episode_length: 10.713
iteration: 183; return: -19.038; episode_length: 10.674
iteration: 184; return: -19.638; episode_length: 10.680
iteration: 185; return: -19.910; episode_length: 10.770
iteration: 186; return: -19.616; episode_length: 10.758
iteration: 187; return: -19.176; episode_length: 10.891
iteration: 188; return: -19.606; episode_length: 10.719
iteration: 189; return: -19.685; episode_length: 10.732
iteration: 190; return: -19.673; episode_length: 10.840
iteration: 191; return: -20.462; episode_length: 10.657
iteration: 192; return: -19.893; episode_length: 10.732
iteration: 193; return: -19.953; episode_length: 10.643
iteration: 194; return: -21.379; episode_length: 10.438
iteration: 195; return: -19.939; episode_length: 10.547
iteration: 196; return: -19.866; episode_length: 10.455
iteration: 197; return: -19.389; episode_length: 10.600
iteration: 198; return: -19.344; episode_length: 10.493
iteration: 199; return: -20.036; episode_length: 10.584
iteration: 200; return: -19.609; episode_length: 10.714
In [12]:
policy_array = [
    derive_deterministic_policy(mdp=highway_mdp, policy=model)
    for model in model_checkpoints
]
In [13]:
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=10, rows=4, policy_over_time=policy_array
)
In [14]:
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    plot_policy_step_grid_map(200)