Policy Gradient
In [1]:
Copied!
import os
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
GridMDP,
policy_gradient,
derive_deterministic_policy,
GRID_MDP_DICT,
HIGHWAY_MDP_DICT,
LC_RIGHT_ACTION,
STAY_IN_LANE_ACTION,
)
# Allow every action in every highway state, not only actions leading to
# available states -- NOTE(review): presumably needed so the stochastic
# policy can explore freely; confirm against GridMDP's documentation.
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
import os
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
GridMDP,
policy_gradient,
derive_deterministic_policy,
GRID_MDP_DICT,
HIGHWAY_MDP_DICT,
LC_RIGHT_ACTION,
STAY_IN_LANE_ACTION,
)
# Allow every action in every highway state, not only actions leading to
# available states -- NOTE(review): presumably needed so the stochastic
# policy can explore freely; confirm against GridMDP's documentation.
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
TOY EXAMPLE¶
In [2]:
Copied!
# Build the toy grid-world MDP from its predefined configuration dict.
grid_mdp = GridMDP(**GRID_MDP_DICT)
# Build the toy grid-world MDP from its predefined configuration dict.
grid_mdp = GridMDP(**GRID_MDP_DICT)
In [3]:
Copied!
# Categorical (stochastic) policy network: input width is the state-vector
# length, one hidden layer of 32 units, one output logit per action.
policy = CategoricalPolicy(
    actions=list(grid_mdp.actions),
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
)
# Categorical (stochastic) policy network: input width is the state-vector
# length, one hidden layer of 32 units, one output logit per action.
policy = CategoricalPolicy(
    actions=list(grid_mdp.actions),
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
)
In [4]:
Copied!
# Train the policy with policy gradient on the toy MDP; with
# return_history=True we get one policy snapshot per iteration,
# which the slider below steps through.
model_checkpoints = policy_gradient(
    policy=policy, mdp=grid_mdp, iterations=100, return_history=True
)
# Train the policy with policy gradient on the toy MDP; with
# return_history=True we get one policy snapshot per iteration,
# which the slider below steps through.
model_checkpoints = policy_gradient(
    policy=policy, mdp=grid_mdp, iterations=100, return_history=True
)
iteration: 1; return: -2.562; episode_length: 60.542
iteration: 2; return: -1.912; episode_length: 44.786
iteration: 3; return: -1.419; episode_length: 32.706
iteration: 4; return: -1.305; episode_length: 28.287
iteration: 5; return: -1.160; episode_length: 22.096
iteration: 6; return: -0.912; episode_length: 17.072
iteration: 7; return: -0.844; episode_length: 18.123
iteration: 8; return: -0.805; episode_length: 13.835
iteration: 9; return: -0.647; episode_length: 13.371
iteration: 10; return: -0.612; episode_length: 11.950
iteration: 11; return: -0.497; episode_length: 12.087
iteration: 12; return: -0.364; episode_length: 11.751
iteration: 13; return: -0.281; episode_length: 11.640
iteration: 14; return: -0.286; episode_length: 11.891
iteration: 15; return: -0.226; episode_length: 12.409
iteration: 16; return: -0.110; episode_length: 13.129
iteration: 17; return: -0.084; episode_length: 12.241
iteration: 18; return: 0.014; episode_length: 12.836
iteration: 19; return: 0.112; episode_length: 13.496
iteration: 20; return: 0.054; episode_length: 14.090
iteration: 21; return: 0.098; episode_length: 15.300
iteration: 22; return: 0.052; episode_length: 15.172
iteration: 23; return: 0.103; episode_length: 15.219
iteration: 24; return: 0.131; episode_length: 15.785
iteration: 25; return: 0.085; episode_length: 16.350
iteration: 26; return: 0.100; episode_length: 17.037
iteration: 27; return: 0.072; episode_length: 16.576
iteration: 28; return: 0.044; episode_length: 16.187
iteration: 29; return: 0.113; episode_length: 15.792
iteration: 30; return: 0.059; episode_length: 16.164
iteration: 31; return: 0.126; episode_length: 14.570
iteration: 32; return: 0.189; episode_length: 14.979
iteration: 33; return: 0.179; episode_length: 14.006
iteration: 34; return: 0.194; episode_length: 13.662
iteration: 35; return: 0.207; episode_length: 12.879
iteration: 36; return: 0.231; episode_length: 12.728
iteration: 37; return: 0.215; episode_length: 11.660
iteration: 38; return: 0.186; episode_length: 11.363
iteration: 39; return: 0.159; episode_length: 11.660
iteration: 40; return: 0.224; episode_length: 11.537
iteration: 41; return: 0.176; episode_length: 10.810
iteration: 42; return: 0.197; episode_length: 11.779
iteration: 43; return: 0.220; episode_length: 11.273
iteration: 44; return: 0.231; episode_length: 11.321
iteration: 45; return: 0.222; episode_length: 11.438
iteration: 46; return: 0.286; episode_length: 11.940
iteration: 47; return: 0.239; episode_length: 11.788
iteration: 48; return: 0.212; episode_length: 12.786
iteration: 49; return: 0.295; episode_length: 12.026
iteration: 50; return: 0.351; episode_length: 11.647
iteration: 51; return: 0.298; episode_length: 12.198
iteration: 52; return: 0.238; episode_length: 11.731
iteration: 53; return: 0.251; episode_length: 12.484
iteration: 54; return: 0.281; episode_length: 12.502
iteration: 55; return: 0.237; episode_length: 11.950
iteration: 56; return: 0.321; episode_length: 11.746
iteration: 57; return: 0.286; episode_length: 11.858
iteration: 58; return: 0.329; episode_length: 11.396
iteration: 59; return: 0.343; episode_length: 11.879
iteration: 60; return: 0.384; episode_length: 12.063
iteration: 61; return: 0.304; episode_length: 12.439
iteration: 62; return: 0.344; episode_length: 11.446
iteration: 63; return: 0.326; episode_length: 11.674
iteration: 64; return: 0.330; episode_length: 11.765
iteration: 65; return: 0.323; episode_length: 11.069
iteration: 66; return: 0.364; episode_length: 11.275
iteration: 67; return: 0.379; episode_length: 11.884
iteration: 68; return: 0.402; episode_length: 11.724
iteration: 69; return: 0.409; episode_length: 11.767
iteration: 70; return: 0.392; episode_length: 11.745
iteration: 71; return: 0.437; episode_length: 11.689
iteration: 72; return: 0.350; episode_length: 11.432
iteration: 73; return: 0.432; episode_length: 10.852
iteration: 74; return: 0.421; episode_length: 10.974
iteration: 75; return: 0.468; episode_length: 11.086
iteration: 76; return: 0.432; episode_length: 11.102
iteration: 77; return: 0.458; episode_length: 10.692
iteration: 78; return: 0.447; episode_length: 10.980
iteration: 79; return: 0.422; episode_length: 11.279
iteration: 80; return: 0.458; episode_length: 10.778
iteration: 81; return: 0.503; episode_length: 10.915
iteration: 82; return: 0.477; episode_length: 10.676
iteration: 83; return: 0.471; episode_length: 11.204
iteration: 84; return: 0.486; episode_length: 11.004
iteration: 85; return: 0.475; episode_length: 10.618
iteration: 86; return: 0.530; episode_length: 10.711
iteration: 87; return: 0.523; episode_length: 10.682
iteration: 88; return: 0.586; episode_length: 10.697
iteration: 89; return: 0.536; episode_length: 10.592
iteration: 90; return: 0.529; episode_length: 10.464
iteration: 91; return: 0.561; episode_length: 10.697
iteration: 92; return: 0.521; episode_length: 10.732
iteration: 93; return: 0.547; episode_length: 10.620
iteration: 94; return: 0.574; episode_length: 10.300
iteration: 95; return: 0.529; episode_length: 10.315
iteration: 96; return: 0.551; episode_length: 10.268
iteration: 97; return: 0.606; episode_length: 10.146
iteration: 98; return: 0.561; episode_length: 10.072
iteration: 99; return: 0.597; episode_length: 10.171
iteration: 100; return: 0.603; episode_length: 10.018
In [5]:
Copied!
# Derive a deterministic (greedy) policy from each training checkpoint.
policy_array = []
for checkpoint in model_checkpoints:
    policy_array.append(
        derive_deterministic_policy(mdp=grid_mdp, policy=checkpoint)
    )
# Derive a deterministic (greedy) policy from each training checkpoint.
policy_array = []
for checkpoint in model_checkpoints:
    policy_array.append(
        derive_deterministic_policy(mdp=grid_mdp, policy=checkpoint)
    )
In [6]:
Copied!
# Plotting callback for the 4x3 toy grid: renders the derived policy
# at a given training iteration (driven by the slider below).
plot_policy_step_grid_map = make_plot_policy_step_function(
columns=4, rows=3, policy_over_time=policy_array
)
# Plotting callback for the 4x3 toy grid: renders the derived policy
# at a given training iteration (driven by the slider below).
plot_policy_step_grid_map = make_plot_policy_step_function(
columns=4, rows=3, policy_over_time=policy_array
)
In [7]:
Copied!
# Interactive slider locally; a single static plot when running in CI.
# Idiom fix: `x != y` instead of `not x == y`; typo fix in the comment.
interactive_widgets = os.getenv("CI") != "true"  # non-interactive in CI
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # NOTE(review): index 100 assumes model_checkpoints holds at least 101
    # entries (history of a 100-iteration run) -- confirm it is in range.
    plot_policy_step_grid_map(100)
# Interactive slider locally; a single static plot when running in CI.
# Idiom fix: `x != y` instead of `not x == y`; typo fix in the comment.
interactive_widgets = os.getenv("CI") != "true"  # non-interactive in CI
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # NOTE(review): index 100 assumes model_checkpoints holds at least 101
    # entries (history of a 100-iteration run) -- confirm it is in range.
    plot_policy_step_grid_map(100)
HIGHWAY EXAMPLE¶
In [8]:
Copied!
# Deliberately disabled tweak: make the lane-change-right action stochastic
# (40% success, 60% stay in lane). Flip the guard to True later in the
# lecture to observe the effect on the learned policy.
if False:
# we will change this to true later on, to see the effect
HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
(0.4, LC_RIGHT_ACTION),
(0.6, STAY_IN_LANE_ACTION),
]
# Deliberately disabled tweak: make the lane-change-right action stochastic
# (40% success, 60% stay in lane). Flip the guard to True later in the
# lecture to observe the effect on the learned policy.
if False:
# we will change this to true later on, to see the effect
HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
(0.4, LC_RIGHT_ACTION),
(0.6, STAY_IN_LANE_ACTION),
]
In [9]:
Copied!
# Build the highway MDP from its (possibly modified) configuration dict.
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
# Build the highway MDP from its (possibly modified) configuration dict.
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
In [10]:
Copied!
# Fresh categorical policy for the highway MDP: state-vector-sized input,
# a 32-unit hidden layer, one output logit per highway action.
policy = CategoricalPolicy(
    actions=list(highway_mdp.actions),
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
)
# Fresh categorical policy for the highway MDP: state-vector-sized input,
# a 32-unit hidden layer, one output logit per highway action.
policy = CategoricalPolicy(
    actions=list(highway_mdp.actions),
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
)
In [11]:
Copied!
# Train on the highway MDP; the harder task gets 200 iterations, and the
# per-iteration snapshots are kept for the visualization below.
model_checkpoints = policy_gradient(
    policy=policy, mdp=highway_mdp, iterations=200, return_history=True
)
# Train on the highway MDP; the harder task gets 200 iterations, and the
# per-iteration snapshots are kept for the visualization below.
model_checkpoints = policy_gradient(
    policy=policy, mdp=highway_mdp, iterations=200, return_history=True
)
iteration: 1; return: -62.362; episode_length: 9.679
iteration: 2; return: -58.430; episode_length: 9.519
iteration: 3; return: -49.317; episode_length: 9.795
iteration: 4; return: -39.321; episode_length: 10.705
iteration: 5; return: -36.183; episode_length: 13.115
iteration: 6; return: -36.389; episode_length: 15.207
iteration: 7; return: -39.605; episode_length: 16.671
iteration: 8; return: -42.462; episode_length: 18.337
iteration: 9; return: -42.561; episode_length: 18.625
iteration: 10; return: -42.015; episode_length: 18.461
iteration: 11; return: -44.822; episode_length: 19.407
iteration: 12; return: -43.844; episode_length: 19.118
iteration: 13; return: -43.720; episode_length: 19.184
iteration: 14; return: -40.071; episode_length: 17.840
iteration: 15; return: -39.697; episode_length: 17.655
iteration: 16; return: -37.197; episode_length: 16.680
iteration: 17; return: -35.910; episode_length: 16.090
iteration: 18; return: -34.976; episode_length: 15.084
iteration: 19; return: -32.932; episode_length: 13.667
iteration: 20; return: -32.494; episode_length: 12.437
iteration: 21; return: -32.332; episode_length: 11.393
iteration: 22; return: -31.978; episode_length: 11.042
iteration: 23; return: -32.538; episode_length: 10.707
iteration: 24; return: -34.079; episode_length: 10.749
iteration: 25; return: -32.624; episode_length: 11.116
iteration: 26; return: -31.758; episode_length: 11.532
iteration: 27; return: -30.487; episode_length: 12.595
iteration: 28; return: -33.190; episode_length: 13.606
iteration: 29; return: -31.304; episode_length: 13.556
iteration: 30; return: -31.414; episode_length: 13.465
iteration: 31; return: -30.449; episode_length: 13.401
iteration: 32; return: -30.135; episode_length: 12.537
iteration: 33; return: -30.474; episode_length: 12.180
iteration: 34; return: -29.757; episode_length: 12.055
iteration: 35; return: -29.348; episode_length: 11.391
iteration: 36; return: -29.511; episode_length: 10.971
iteration: 37; return: -31.120; episode_length: 10.887
iteration: 38; return: -30.677; episode_length: 10.789
iteration: 39; return: -30.018; episode_length: 11.088
iteration: 40; return: -28.755; episode_length: 11.147
iteration: 41; return: -30.305; episode_length: 11.314
iteration: 42; return: -29.083; episode_length: 11.816
iteration: 43; return: -28.446; episode_length: 12.058
iteration: 44; return: -29.339; episode_length: 12.302
iteration: 45; return: -26.583; episode_length: 11.477
iteration: 46; return: -27.229; episode_length: 11.215
iteration: 47; return: -28.384; episode_length: 11.302
iteration: 48; return: -27.955; episode_length: 10.810
iteration: 49; return: -27.513; episode_length: 11.024
iteration: 50; return: -27.414; episode_length: 10.971
iteration: 51; return: -27.680; episode_length: 10.766
iteration: 52; return: -27.897; episode_length: 10.565
iteration: 53; return: -28.590; episode_length: 10.924
iteration: 54; return: -27.437; episode_length: 11.158
iteration: 55; return: -27.009; episode_length: 11.713
iteration: 56; return: -27.734; episode_length: 12.220
iteration: 57; return: -27.174; episode_length: 12.324
iteration: 58; return: -27.473; episode_length: 12.101
iteration: 59; return: -26.347; episode_length: 11.964
iteration: 60; return: -25.825; episode_length: 11.345
iteration: 61; return: -25.305; episode_length: 10.778
iteration: 62; return: -26.443; episode_length: 10.538
iteration: 63; return: -26.572; episode_length: 10.343
iteration: 64; return: -25.240; episode_length: 10.022
iteration: 65; return: -27.543; episode_length: 10.105
iteration: 66; return: -25.486; episode_length: 10.132
iteration: 67; return: -24.145; episode_length: 10.241
iteration: 68; return: -24.304; episode_length: 10.649
iteration: 69; return: -24.379; episode_length: 10.531
iteration: 70; return: -24.030; episode_length: 10.612
iteration: 71; return: -25.269; episode_length: 10.937
iteration: 72; return: -24.768; episode_length: 10.761
iteration: 73; return: -24.457; episode_length: 10.815
iteration: 74; return: -23.706; episode_length: 10.300
iteration: 75; return: -24.018; episode_length: 10.132
iteration: 76; return: -23.764; episode_length: 10.115
iteration: 77; return: -23.717; episode_length: 10.254
iteration: 78; return: -23.528; episode_length: 10.271
iteration: 79; return: -24.246; episode_length: 10.595
iteration: 80; return: -24.038; episode_length: 10.610
iteration: 81; return: -23.320; episode_length: 10.621
iteration: 82; return: -23.561; episode_length: 10.271
iteration: 83; return: -23.540; episode_length: 10.016
iteration: 84; return: -22.838; episode_length: 9.897
iteration: 85; return: -23.285; episode_length: 9.702
iteration: 86; return: -23.166; episode_length: 9.636
iteration: 87; return: -23.551; episode_length: 9.715
iteration: 88; return: -22.113; episode_length: 9.764
iteration: 89; return: -21.755; episode_length: 9.796
iteration: 90; return: -22.744; episode_length: 9.942
iteration: 91; return: -22.276; episode_length: 9.868
iteration: 92; return: -22.528; episode_length: 9.964
iteration: 93; return: -22.794; episode_length: 9.818
iteration: 94; return: -22.098; episode_length: 9.856
iteration: 95; return: -21.797; episode_length: 9.864
iteration: 96; return: -22.828; episode_length: 9.664
iteration: 97; return: -21.548; episode_length: 9.510
iteration: 98; return: -22.524; episode_length: 9.568
iteration: 99; return: -21.521; episode_length: 9.561
iteration: 100; return: -21.757; episode_length: 9.576
iteration: 101; return: -22.065; episode_length: 9.627
iteration: 102; return: -20.705; episode_length: 9.592
iteration: 103; return: -21.215; episode_length: 9.603
iteration: 104; return: -22.488; episode_length: 9.493
iteration: 105; return: -21.967; episode_length: 9.749
iteration: 106; return: -21.475; episode_length: 9.866
iteration: 107; return: -22.545; episode_length: 9.990
iteration: 108; return: -22.373; episode_length: 10.010
iteration: 109; return: -21.994; episode_length: 9.938
iteration: 110; return: -22.191; episode_length: 10.048
iteration: 111; return: -21.971; episode_length: 9.730
iteration: 112; return: -21.671; episode_length: 9.789
iteration: 113; return: -21.879; episode_length: 9.758
iteration: 114; return: -20.696; episode_length: 9.739
iteration: 115; return: -21.025; episode_length: 9.610
iteration: 116; return: -21.148; episode_length: 9.614
iteration: 117; return: -21.422; episode_length: 9.559
iteration: 118; return: -21.210; episode_length: 9.664
iteration: 119; return: -20.387; episode_length: 9.683
iteration: 120; return: -20.693; episode_length: 9.768
iteration: 121; return: -21.155; episode_length: 9.581
iteration: 122; return: -20.934; episode_length: 9.670
iteration: 123; return: -21.033; episode_length: 9.572
iteration: 124; return: -20.300; episode_length: 9.491
iteration: 125; return: -20.099; episode_length: 9.544
iteration: 126; return: -20.077; episode_length: 9.418
iteration: 127; return: -19.869; episode_length: 9.359
iteration: 128; return: -20.968; episode_length: 9.433
iteration: 129; return: -20.759; episode_length: 9.418
iteration: 130; return: -20.585; episode_length: 9.415
iteration: 131; return: -20.362; episode_length: 9.438
iteration: 132; return: -20.544; episode_length: 9.324
iteration: 133; return: -21.403; episode_length: 9.418
iteration: 134; return: -20.338; episode_length: 9.513
iteration: 135; return: -20.108; episode_length: 9.475
iteration: 136; return: -20.082; episode_length: 9.515
iteration: 137; return: -20.554; episode_length: 9.539
iteration: 138; return: -19.642; episode_length: 9.530
iteration: 139; return: -20.348; episode_length: 9.485
iteration: 140; return: -20.161; episode_length: 9.363
iteration: 141; return: -19.614; episode_length: 9.330
iteration: 142; return: -19.782; episode_length: 9.259
iteration: 143; return: -20.470; episode_length: 9.253
iteration: 144; return: -19.833; episode_length: 9.305
iteration: 145; return: -20.030; episode_length: 9.293
iteration: 146; return: -20.051; episode_length: 9.385
iteration: 147; return: -19.558; episode_length: 9.505
iteration: 148; return: -20.065; episode_length: 9.526
iteration: 149; return: -19.763; episode_length: 9.250
iteration: 150; return: -19.622; episode_length: 9.263
iteration: 151; return: -19.914; episode_length: 9.204
iteration: 152; return: -19.273; episode_length: 9.227
iteration: 153; return: -19.657; episode_length: 9.214
iteration: 154; return: -20.359; episode_length: 9.133
iteration: 155; return: -21.189; episode_length: 9.118
iteration: 156; return: -20.230; episode_length: 9.130
iteration: 157; return: -19.406; episode_length: 9.206
iteration: 158; return: -19.757; episode_length: 9.267
iteration: 159; return: -19.235; episode_length: 9.259
iteration: 160; return: -19.976; episode_length: 9.307
iteration: 161; return: -19.606; episode_length: 9.420
iteration: 162; return: -19.924; episode_length: 9.566
iteration: 163; return: -19.646; episode_length: 9.483
iteration: 164; return: -19.605; episode_length: 9.404
iteration: 165; return: -19.842; episode_length: 9.421
iteration: 166; return: -19.253; episode_length: 9.305
iteration: 167; return: -19.211; episode_length: 9.272
iteration: 168; return: -19.044; episode_length: 9.170
iteration: 169; return: -19.487; episode_length: 9.139
iteration: 170; return: -19.684; episode_length: 9.087
iteration: 171; return: -20.512; episode_length: 9.078
iteration: 172; return: -19.858; episode_length: 9.133
iteration: 173; return: -18.812; episode_length: 9.111
iteration: 174; return: -19.363; episode_length: 9.130
iteration: 175; return: -19.187; episode_length: 9.269
iteration: 176; return: -19.397; episode_length: 9.340
iteration: 177; return: -19.841; episode_length: 9.396
iteration: 178; return: -20.004; episode_length: 9.715
iteration: 179; return: -19.662; episode_length: 9.601
iteration: 180; return: -19.469; episode_length: 9.670
iteration: 181; return: -19.425; episode_length: 9.365
iteration: 182; return: -20.020; episode_length: 9.234
iteration: 183; return: -19.579; episode_length: 9.176
iteration: 184; return: -18.816; episode_length: 9.131
iteration: 185; return: -20.105; episode_length: 9.082
iteration: 186; return: -20.056; episode_length: 9.083
iteration: 187; return: -19.860; episode_length: 9.124
iteration: 188; return: -19.293; episode_length: 9.113
iteration: 189; return: -18.894; episode_length: 9.159
iteration: 190; return: -18.696; episode_length: 9.165
iteration: 191; return: -19.750; episode_length: 9.157
iteration: 192; return: -19.245; episode_length: 9.141
iteration: 193; return: -19.120; episode_length: 9.248
iteration: 194; return: -18.634; episode_length: 9.296
iteration: 195; return: -20.045; episode_length: 9.490
iteration: 196; return: -19.313; episode_length: 9.326
iteration: 197; return: -19.455; episode_length: 9.371
iteration: 198; return: -18.892; episode_length: 9.170
iteration: 199; return: -19.294; episode_length: 9.155
iteration: 200; return: -19.149; episode_length: 9.120
In [12]:
Copied!
# Derive a deterministic (greedy) policy from each highway checkpoint.
policy_array = []
for checkpoint in model_checkpoints:
    policy_array.append(
        derive_deterministic_policy(mdp=highway_mdp, policy=checkpoint)
    )
# Derive a deterministic (greedy) policy from each highway checkpoint.
policy_array = []
for checkpoint in model_checkpoints:
    policy_array.append(
        derive_deterministic_policy(mdp=highway_mdp, policy=checkpoint)
    )
In [13]:
Copied!
# Plotting callback for the 10x4 highway grid: renders the derived policy
# at a given training iteration (driven by the slider below).
plot_policy_step_grid_map = make_plot_policy_step_function(
columns=10, rows=4, policy_over_time=policy_array
)
# Plotting callback for the 10x4 highway grid: renders the derived policy
# at a given training iteration (driven by the slider below).
plot_policy_step_grid_map = make_plot_policy_step_function(
columns=10, rows=4, policy_over_time=policy_array
)
In [14]:
Copied!
# Step through the highway policy evolution: slider locally, static in CI.
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=slider)
    display(w)
else:
    plot_policy_step_grid_map(200)
# Step through the highway policy evolution: slider locally, static in CI.
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=slider)
    display(w)
else:
    plot_policy_step_grid_map(200)