Policy Gradient
In [1]:
Copied!
# Imports for the policy-gradient lecture demo: the grid-world MDP, the
# policy-gradient trainer, a categorical neural-network policy, and plotting
# helpers. (Rendered notebook: the cell input below appears twice, once per
# copy in the exported page.)
import os
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
GridMDP,
policy_gradient,
derive_deterministic_policy,
GRID_MDP_DICT,
HIGHWAY_MDP_DICT,
LC_RIGHT_ACTION,
STAY_IN_LANE_ACTION,
)
# NOTE(review): presumably this allows actions whose successor state is not in
# the available-state set -- confirm against the GridMDP implementation.
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
# (second rendered copy of the same cell)
import os
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
GridMDP,
policy_gradient,
derive_deterministic_policy,
GRID_MDP_DICT,
HIGHWAY_MDP_DICT,
LC_RIGHT_ACTION,
STAY_IN_LANE_ACTION,
)
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
TOY EXAMPLE
In [2]:
Copied!
# Instantiate the toy grid-world MDP from its predefined configuration dict.
# (Cell input rendered twice by the docs exporter; re-running is harmless.)
grid_mdp = GridMDP(**GRID_MDP_DICT)
grid_mdp = GridMDP(**GRID_MDP_DICT)
In [3]:
Copied!
# Categorical (softmax) policy network: input size = state dimension,
# one hidden layer of 32 units, output size = number of actions.
# NOTE(review): `sizes` presumably are layer widths of an MLP -- confirm in
# CategoricalPolicy.
policy = CategoricalPolicy(
sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
actions=list(grid_mdp.actions),
)
policy = CategoricalPolicy(
sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
actions=list(grid_mdp.actions),
)
In [4]:
Copied!
# Train the policy with policy gradient for 100 iterations. With
# return_history=True the trainer returns per-iteration model checkpoints
# (they are indexed by an iteration slider further below).
model_checkpoints = policy_gradient(
mdp=grid_mdp,
policy=policy,
iterations=100,
return_history=True,
)
model_checkpoints = policy_gradient(
mdp=grid_mdp,
policy=policy,
iterations=100,
return_history=True,
)
iteration: 1; return: -1.277; episode_length: 30.107
iteration: 2; return: -1.246; episode_length: 26.460
iteration: 3; return: -0.943; episode_length: 21.013
iteration: 4; return: -0.783; episode_length: 18.805
iteration: 5; return: -0.743; episode_length: 17.568
iteration: 6; return: -0.502; episode_length: 13.831
iteration: 7; return: -0.478; episode_length: 12.889
iteration: 8; return: -0.339; episode_length: 12.227
iteration: 9; return: -0.364; episode_length: 12.002
iteration: 10; return: -0.231; episode_length: 11.077
iteration: 11; return: -0.159; episode_length: 11.541
iteration: 12; return: -0.068; episode_length: 12.536
iteration: 13; return: -0.112; episode_length: 12.207
iteration: 14; return: -0.050; episode_length: 12.808
iteration: 15; return: 0.016; episode_length: 12.740
iteration: 16; return: 0.038; episode_length: 14.011
iteration: 17; return: 0.075; episode_length: 14.911
iteration: 18; return: -0.010; episode_length: 15.382
iteration: 19; return: 0.076; episode_length: 16.742
iteration: 20; return: 0.087; episode_length: 15.481
iteration: 21; return: 0.035; episode_length: 14.881
iteration: 22; return: 0.112; episode_length: 14.412
iteration: 23; return: 0.164; episode_length: 14.513
iteration: 24; return: 0.122; episode_length: 13.268
iteration: 25; return: 0.201; episode_length: 12.874
iteration: 26; return: 0.185; episode_length: 12.659
iteration: 27; return: 0.148; episode_length: 11.795
iteration: 28; return: 0.135; episode_length: 11.491
iteration: 29; return: 0.189; episode_length: 11.886
iteration: 30; return: 0.141; episode_length: 11.131
iteration: 31; return: 0.161; episode_length: 10.993
iteration: 32; return: 0.161; episode_length: 11.109
iteration: 33; return: 0.215; episode_length: 11.586
iteration: 34; return: 0.209; episode_length: 12.158
iteration: 35; return: 0.273; episode_length: 11.917
iteration: 36; return: 0.241; episode_length: 12.733
iteration: 37; return: 0.239; episode_length: 11.938
iteration: 38; return: 0.230; episode_length: 12.740
iteration: 39; return: 0.295; episode_length: 12.429
iteration: 40; return: 0.299; episode_length: 12.363
iteration: 41; return: 0.319; episode_length: 12.272
iteration: 42; return: 0.265; episode_length: 12.046
iteration: 43; return: 0.348; episode_length: 11.685
iteration: 44; return: 0.332; episode_length: 11.814
iteration: 45; return: 0.361; episode_length: 11.245
iteration: 46; return: 0.377; episode_length: 11.438
iteration: 47; return: 0.324; episode_length: 11.265
iteration: 48; return: 0.376; episode_length: 10.859
iteration: 49; return: 0.350; episode_length: 10.859
iteration: 50; return: 0.289; episode_length: 11.252
iteration: 51; return: 0.323; episode_length: 10.810
iteration: 52; return: 0.300; episode_length: 10.945
iteration: 53; return: 0.296; episode_length: 11.415
iteration: 54; return: 0.370; episode_length: 11.020
iteration: 55; return: 0.373; episode_length: 11.340
iteration: 56; return: 0.361; episode_length: 10.974
iteration: 57; return: 0.354; episode_length: 11.145
iteration: 58; return: 0.346; episode_length: 11.133
iteration: 59; return: 0.368; episode_length: 11.442
iteration: 60; return: 0.375; episode_length: 11.238
iteration: 61; return: 0.360; episode_length: 11.800
iteration: 62; return: 0.348; episode_length: 11.548
iteration: 63; return: 0.366; episode_length: 11.165
iteration: 64; return: 0.431; episode_length: 11.192
iteration: 65; return: 0.405; episode_length: 11.263
iteration: 66; return: 0.388; episode_length: 11.333
iteration: 67; return: 0.380; episode_length: 11.312
iteration: 68; return: 0.444; episode_length: 11.049
iteration: 69; return: 0.408; episode_length: 11.046
iteration: 70; return: 0.469; episode_length: 10.738
iteration: 71; return: 0.478; episode_length: 10.248
iteration: 72; return: 0.478; episode_length: 9.964
iteration: 73; return: 0.499; episode_length: 10.577
iteration: 74; return: 0.456; episode_length: 10.427
iteration: 75; return: 0.496; episode_length: 10.066
iteration: 76; return: 0.522; episode_length: 10.128
iteration: 77; return: 0.549; episode_length: 10.144
iteration: 78; return: 0.549; episode_length: 9.994
iteration: 79; return: 0.543; episode_length: 10.111
iteration: 80; return: 0.544; episode_length: 10.089
iteration: 81; return: 0.577; episode_length: 9.889
iteration: 82; return: 0.591; episode_length: 10.024
iteration: 83; return: 0.568; episode_length: 9.833
iteration: 84; return: 0.612; episode_length: 9.739
iteration: 85; return: 0.597; episode_length: 9.629
iteration: 86; return: 0.615; episode_length: 10.030
iteration: 87; return: 0.618; episode_length: 9.852
iteration: 88; return: 0.607; episode_length: 9.848
iteration: 89; return: 0.624; episode_length: 9.715
iteration: 90; return: 0.630; episode_length: 9.866
iteration: 91; return: 0.629; episode_length: 9.507
iteration: 92; return: 0.623; episode_length: 9.570
iteration: 93; return: 0.628; episode_length: 9.544
iteration: 94; return: 0.640; episode_length: 9.424
iteration: 95; return: 0.628; episode_length: 9.715
iteration: 96; return: 0.634; episode_length: 9.497
iteration: 97; return: 0.637; episode_length: 9.338
iteration: 98; return: 0.618; episode_length: 9.431
iteration: 99; return: 0.645; episode_length: 9.404
iteration: 100; return: 0.625; episode_length: 9.510
In [5]:
Copied!
# Convert every stochastic policy checkpoint into a deterministic per-state
# action map for plotting. NOTE(review): presumably the most probable action
# per state -- confirm in derive_deterministic_policy.
policy_array = [
derive_deterministic_policy(mdp=grid_mdp, policy=model)
for model in model_checkpoints
]
policy_array = [
derive_deterministic_policy(mdp=grid_mdp, policy=model)
for model in model_checkpoints
]
In [6]:
Copied!
# Build a plotting callback over the 4x3 toy grid; its `iteration` argument
# selects which checkpoint's policy to draw.
plot_policy_step_grid_map = make_plot_policy_step_function(
columns=4, rows=3, policy_over_time=policy_array
)
plot_policy_step_grid_map = make_plot_policy_step_function(
columns=4, rows=3, policy_over_time=policy_array
)
In [7]:
Copied!
# Browse the training history: interactive slider locally, static plot in CI.
# CI sets the environment variable CI=true and cannot render widgets.
interactive_widgets = os.getenv("CI") != "true"  # non-interactive in CI
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    # One slider position per recorded training checkpoint.
    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # Show the final (100th-iteration) policy only.
    # NOTE(review): assumes the history also contains the initial model, i.e.
    # 101 checkpoints -- confirm in policy_gradient.
    plot_policy_step_grid_map(100)
# (second rendered copy of the same cell)
interactive_widgets = os.getenv("CI") != "true"  # non-interactive in CI
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    plot_policy_step_grid_map(100)
HIGHWAY EXAMPLE
In [8]:
Copied!
# Deliberately disabled branch: the lecture flips this to True later to make
# right lane changes stochastic. NOTE(review): the tuples presumably are
# (probability, resulting action) pairs -- confirm in GridMDP.
if False:
# we will change this to true later on, to see the effect
HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
(0.4, LC_RIGHT_ACTION),
(0.6, STAY_IN_LANE_ACTION),
]
if False:
# we will change this to true later on, to see the effect
HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
(0.4, LC_RIGHT_ACTION),
(0.6, STAY_IN_LANE_ACTION),
]
In [9]:
Copied!
# Instantiate the highway MDP from its (possibly modified) configuration dict.
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
In [10]:
Copied!
# Fresh categorical policy for the highway MDP (same architecture as the toy
# example: one 32-unit hidden layer, softmax over the highway action set).
policy = CategoricalPolicy(
sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
actions=list(highway_mdp.actions),
)
policy = CategoricalPolicy(
sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
actions=list(highway_mdp.actions),
)
In [11]:
Copied!
# Train on the highway MDP; 200 iterations (it is a harder problem than the
# toy grid). return_history=True again yields per-iteration checkpoints.
model_checkpoints = policy_gradient(
mdp=highway_mdp,
policy=policy,
iterations=200,
return_history=True,
)
model_checkpoints = policy_gradient(
mdp=highway_mdp,
policy=policy,
iterations=200,
return_history=True,
)
iteration: 1; return: -56.606; episode_length: 9.844
iteration: 2; return: -52.329; episode_length: 10.347
iteration: 3; return: -43.888; episode_length: 11.220
iteration: 4; return: -41.127; episode_length: 12.455
iteration: 5; return: -37.115; episode_length: 13.385
iteration: 6; return: -38.499; episode_length: 14.170
iteration: 7; return: -35.788; episode_length: 14.558
iteration: 8; return: -38.409; episode_length: 15.446
iteration: 9; return: -38.872; episode_length: 15.216
iteration: 10; return: -35.352; episode_length: 14.210
iteration: 11; return: -35.431; episode_length: 13.894
iteration: 12; return: -34.837; episode_length: 13.782
iteration: 13; return: -33.391; episode_length: 12.701
iteration: 14; return: -32.643; episode_length: 11.868
iteration: 15; return: -35.838; episode_length: 11.425
iteration: 16; return: -34.383; episode_length: 11.077
iteration: 17; return: -34.009; episode_length: 11.347
iteration: 18; return: -32.090; episode_length: 11.321
iteration: 19; return: -33.756; episode_length: 11.619
iteration: 20; return: -32.325; episode_length: 11.807
iteration: 21; return: -32.330; episode_length: 12.545
iteration: 22; return: -31.633; episode_length: 13.144
iteration: 23; return: -31.939; episode_length: 13.233
iteration: 24; return: -33.192; episode_length: 13.357
iteration: 25; return: -32.051; episode_length: 13.372
iteration: 26; return: -32.522; episode_length: 13.003
iteration: 27; return: -30.355; episode_length: 12.765
iteration: 28; return: -31.939; episode_length: 12.803
iteration: 29; return: -31.516; episode_length: 12.348
iteration: 30; return: -30.510; episode_length: 12.041
iteration: 31; return: -29.864; episode_length: 11.735
iteration: 32; return: -30.433; episode_length: 11.340
iteration: 33; return: -30.062; episode_length: 11.026
iteration: 34; return: -31.157; episode_length: 10.772
iteration: 35; return: -29.128; episode_length: 10.880
iteration: 36; return: -29.481; episode_length: 11.188
iteration: 37; return: -28.940; episode_length: 11.912
iteration: 38; return: -30.291; episode_length: 11.962
iteration: 39; return: -28.798; episode_length: 12.360
iteration: 40; return: -27.993; episode_length: 11.900
iteration: 41; return: -28.148; episode_length: 11.926
iteration: 42; return: -27.776; episode_length: 11.550
iteration: 43; return: -27.124; episode_length: 11.694
iteration: 44; return: -27.330; episode_length: 11.326
iteration: 45; return: -27.236; episode_length: 10.952
iteration: 46; return: -27.237; episode_length: 10.595
iteration: 47; return: -26.013; episode_length: 10.612
iteration: 48; return: -26.156; episode_length: 10.575
iteration: 49; return: -26.898; episode_length: 10.857
iteration: 50; return: -25.573; episode_length: 10.866
iteration: 51; return: -25.940; episode_length: 10.742
iteration: 52; return: -26.357; episode_length: 10.588
iteration: 53; return: -26.035; episode_length: 10.431
iteration: 54; return: -25.645; episode_length: 10.347
iteration: 55; return: -24.113; episode_length: 10.273
iteration: 56; return: -25.112; episode_length: 10.233
iteration: 57; return: -24.958; episode_length: 10.464
iteration: 58; return: -24.213; episode_length: 10.674
iteration: 59; return: -24.645; episode_length: 10.657
iteration: 60; return: -24.160; episode_length: 10.651
iteration: 61; return: -24.553; episode_length: 10.606
iteration: 62; return: -24.019; episode_length: 10.376
iteration: 63; return: -23.727; episode_length: 10.038
iteration: 64; return: -24.434; episode_length: 9.878
iteration: 65; return: -25.479; episode_length: 9.747
iteration: 66; return: -22.849; episode_length: 9.808
iteration: 67; return: -24.233; episode_length: 9.812
iteration: 68; return: -23.870; episode_length: 9.994
iteration: 69; return: -22.201; episode_length: 10.175
iteration: 70; return: -21.946; episode_length: 10.375
iteration: 71; return: -22.701; episode_length: 10.252
iteration: 72; return: -23.018; episode_length: 10.229
iteration: 73; return: -23.350; episode_length: 10.264
iteration: 74; return: -22.646; episode_length: 10.167
iteration: 75; return: -22.266; episode_length: 9.796
iteration: 76; return: -22.100; episode_length: 9.608
iteration: 77; return: -23.721; episode_length: 9.644
iteration: 78; return: -22.527; episode_length: 9.594
iteration: 79; return: -22.358; episode_length: 9.675
iteration: 80; return: -21.473; episode_length: 9.664
iteration: 81; return: -21.741; episode_length: 9.891
iteration: 82; return: -21.971; episode_length: 9.687
iteration: 83; return: -22.273; episode_length: 9.750
iteration: 84; return: -21.740; episode_length: 9.732
iteration: 85; return: -21.640; episode_length: 9.741
iteration: 86; return: -21.366; episode_length: 9.758
iteration: 87; return: -21.724; episode_length: 9.717
iteration: 88; return: -21.869; episode_length: 9.508
iteration: 89; return: -21.245; episode_length: 9.499
iteration: 90; return: -20.990; episode_length: 9.557
iteration: 91; return: -21.136; episode_length: 9.584
iteration: 92; return: -20.749; episode_length: 9.530
iteration: 93; return: -21.547; episode_length: 9.614
iteration: 94; return: -20.613; episode_length: 9.350
iteration: 95; return: -21.048; episode_length: 9.629
iteration: 96; return: -20.693; episode_length: 9.444
iteration: 97; return: -20.535; episode_length: 9.501
iteration: 98; return: -20.942; episode_length: 9.601
iteration: 99; return: -20.497; episode_length: 9.574
iteration: 100; return: -20.957; episode_length: 9.673
iteration: 101; return: -20.558; episode_length: 9.463
iteration: 102; return: -21.403; episode_length: 9.473
iteration: 103; return: -20.872; episode_length: 9.400
iteration: 104; return: -20.728; episode_length: 9.336
iteration: 105; return: -20.878; episode_length: 9.388
iteration: 106; return: -21.075; episode_length: 9.424
iteration: 107; return: -20.611; episode_length: 9.442
iteration: 108; return: -20.192; episode_length: 9.442
iteration: 109; return: -20.060; episode_length: 9.402
iteration: 110; return: -20.641; episode_length: 9.412
iteration: 111; return: -20.026; episode_length: 9.469
iteration: 112; return: -20.294; episode_length: 9.475
iteration: 113; return: -20.709; episode_length: 9.508
iteration: 114; return: -20.106; episode_length: 9.479
iteration: 115; return: -20.330; episode_length: 9.365
iteration: 116; return: -20.448; episode_length: 9.303
iteration: 117; return: -20.094; episode_length: 9.231
iteration: 118; return: -20.523; episode_length: 9.315
iteration: 119; return: -20.192; episode_length: 9.340
iteration: 120; return: -19.851; episode_length: 9.330
iteration: 121; return: -20.303; episode_length: 9.361
iteration: 122; return: -20.034; episode_length: 9.383
iteration: 123; return: -19.909; episode_length: 9.526
iteration: 124; return: -20.138; episode_length: 9.570
iteration: 125; return: -20.498; episode_length: 9.670
iteration: 126; return: -20.169; episode_length: 9.412
iteration: 127; return: -19.336; episode_length: 9.284
iteration: 128; return: -20.232; episode_length: 9.208
iteration: 129; return: -19.951; episode_length: 9.131
iteration: 130; return: -21.332; episode_length: 9.111
iteration: 131; return: -19.748; episode_length: 9.128
iteration: 132; return: -20.063; episode_length: 9.065
iteration: 133; return: -20.189; episode_length: 9.095
iteration: 134; return: -19.960; episode_length: 9.165
iteration: 135; return: -19.224; episode_length: 9.202
iteration: 136; return: -19.875; episode_length: 9.214
iteration: 137; return: -19.579; episode_length: 9.380
iteration: 138; return: -20.762; episode_length: 9.779
iteration: 139; return: -20.464; episode_length: 9.679
iteration: 140; return: -20.131; episode_length: 9.762
iteration: 141; return: -19.781; episode_length: 9.685
iteration: 142; return: -19.768; episode_length: 9.598
iteration: 143; return: -19.393; episode_length: 9.461
iteration: 144; return: -18.827; episode_length: 9.305
iteration: 145; return: -19.840; episode_length: 9.324
iteration: 146; return: -19.304; episode_length: 9.172
iteration: 147; return: -19.956; episode_length: 9.143
iteration: 148; return: -19.217; episode_length: 9.058
iteration: 149; return: -20.149; episode_length: 9.091
iteration: 150; return: -20.365; episode_length: 9.091
iteration: 151; return: -19.774; episode_length: 9.052
iteration: 152; return: -19.857; episode_length: 9.049
iteration: 153; return: -19.926; episode_length: 9.087
iteration: 154; return: -19.447; episode_length: 9.060
iteration: 155; return: -19.336; episode_length: 9.100
iteration: 156; return: -19.690; episode_length: 9.142
iteration: 157; return: -18.832; episode_length: 9.214
iteration: 158; return: -19.064; episode_length: 9.217
iteration: 159; return: -19.130; episode_length: 9.191
iteration: 160; return: -19.147; episode_length: 9.195
iteration: 161; return: -19.028; episode_length: 9.369
iteration: 162; return: -19.118; episode_length: 9.369
iteration: 163; return: -19.983; episode_length: 9.511
iteration: 164; return: -20.190; episode_length: 9.584
iteration: 165; return: -19.547; episode_length: 9.365
iteration: 166; return: -19.398; episode_length: 9.404
iteration: 167; return: -19.064; episode_length: 9.418
iteration: 168; return: -20.123; episode_length: 9.459
iteration: 169; return: -19.215; episode_length: 9.371
iteration: 170; return: -19.495; episode_length: 9.174
iteration: 171; return: -18.473; episode_length: 9.157
iteration: 172; return: -18.575; episode_length: 9.246
iteration: 173; return: -19.378; episode_length: 9.130
iteration: 174; return: -18.260; episode_length: 9.165
iteration: 175; return: -18.503; episode_length: 9.118
iteration: 176; return: -18.505; episode_length: 9.072
iteration: 177; return: -18.196; episode_length: 9.067
iteration: 178; return: -18.778; episode_length: 9.159
iteration: 179; return: -18.905; episode_length: 9.170
iteration: 180; return: -18.538; episode_length: 9.174
iteration: 181; return: -18.895; episode_length: 9.225
iteration: 182; return: -19.015; episode_length: 9.200
iteration: 183; return: -19.241; episode_length: 9.278
iteration: 184; return: -18.858; episode_length: 9.200
iteration: 185; return: -19.155; episode_length: 9.139
iteration: 186; return: -18.903; episode_length: 9.137
iteration: 187; return: -19.410; episode_length: 9.118
iteration: 188; return: -18.793; episode_length: 9.085
iteration: 189; return: -18.818; episode_length: 9.180
iteration: 190; return: -19.146; episode_length: 9.130
iteration: 191; return: -19.217; episode_length: 9.272
iteration: 192; return: -19.125; episode_length: 9.219
iteration: 193; return: -19.358; episode_length: 9.236
iteration: 194; return: -18.524; episode_length: 9.204
iteration: 195; return: -18.803; episode_length: 9.221
iteration: 196; return: -19.242; episode_length: 9.185
iteration: 197; return: -18.607; episode_length: 9.098
iteration: 198; return: -19.192; episode_length: 9.052
iteration: 199; return: -19.038; episode_length: 9.085
iteration: 200; return: -18.816; episode_length: 9.115
In [12]:
Copied!
# Deterministic per-state action map for each highway training checkpoint,
# used by the plotting helper below.
policy_array = [
derive_deterministic_policy(mdp=highway_mdp, policy=model)
for model in model_checkpoints
]
policy_array = [
derive_deterministic_policy(mdp=highway_mdp, policy=model)
for model in model_checkpoints
]
In [13]:
Copied!
# Plotting callback over the 10x4 highway grid; `iteration` selects the
# checkpoint to visualize.
plot_policy_step_grid_map = make_plot_policy_step_function(
columns=10, rows=4, policy_over_time=policy_array
)
plot_policy_step_grid_map = make_plot_policy_step_function(
columns=10, rows=4, policy_over_time=policy_array
)
In [14]:
Copied!
# Browse the highway training history: a checkpoint slider in interactive
# sessions, otherwise a single static snapshot at iteration 200 (for CI).
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    browser = ipywidgets.interactive(plot_policy_step_grid_map, iteration=slider)
    display(browser)
else:
    plot_policy_step_grid_map(200)
# (second rendered copy of the same cell)
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    browser = ipywidgets.interactive(plot_policy_step_grid_map, iteration=slider)
    display(browser)
else:
    plot_policy_step_grid_map(200)