Policy Gradient
In [1]:
Copied!
# NOTE(review): each cell's source appears twice in this export ("Copied!"
# rendering artifact); both copies are kept verbatim below.
import os

# Lecture-package imports: the categorical policy network, the grid plotting
# helper, and the MDP machinery plus predefined config dicts/actions.
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
    make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
    GridMDP,
    policy_gradient,
    derive_deterministic_policy,
    GRID_MDP_DICT,
    HIGHWAY_MDP_DICT,
    LC_RIGHT_ACTION,
    STAY_IN_LANE_ACTION,
)

# Allow actions regardless of target-state availability in the highway MDP
# (setting is idempotent, so the duplicated execution is harmless).
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
import os
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
    make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
    GridMDP,
    policy_gradient,
    derive_deterministic_policy,
    GRID_MDP_DICT,
    HIGHWAY_MDP_DICT,
    LC_RIGHT_ACTION,
    STAY_IN_LANE_ACTION,
)
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
TOY EXAMPLE
In [2]:
Copied!
# Build the toy grid-world MDP from the predefined configuration dict.
grid_mdp = GridMDP(**GRID_MDP_DICT)
grid_mdp = GridMDP(**GRID_MDP_DICT)
In [3]:
Copied!
# Policy network for the toy grid: the input layer is sized to the state
# vector, one hidden layer of 32 units, and one categorical logit per action.
state_size = len(grid_mdp.initial_state)
action_list = list(grid_mdp.actions)
policy = CategoricalPolicy(
    sizes=[state_size, 32, len(action_list)],
    actions=action_list,
)
policy = CategoricalPolicy(
    sizes=[state_size, 32, len(action_list)],
    actions=action_list,
)
In [4]:
Copied!
# Train the policy with policy gradient on the toy MDP for 100 iterations.
# return_history=True makes the call return per-iteration model checkpoints,
# which are visualized below.
model_checkpoints = policy_gradient(
    mdp=grid_mdp,
    policy=policy,
    iterations=100,
    return_history=True,
)
model_checkpoints = policy_gradient(
    mdp=grid_mdp,
    policy=policy,
    iterations=100,
    return_history=True,
)
iteration: 1; return: -1.037; episode_length: 18.347
iteration: 2; return: -0.883; episode_length: 16.191
iteration: 3; return: -0.829; episode_length: 15.683
iteration: 4; return: -0.602; episode_length: 14.726
iteration: 5; return: -0.591; episode_length: 13.481
iteration: 6; return: -0.455; episode_length: 13.023
iteration: 7; return: -0.352; episode_length: 13.150
iteration: 8; return: -0.284; episode_length: 12.746
iteration: 9; return: -0.150; episode_length: 13.914
iteration: 10; return: -0.124; episode_length: 13.373
iteration: 11; return: -0.095; episode_length: 15.601
iteration: 12; return: -0.045; episode_length: 15.443
iteration: 13; return: -0.084; episode_length: 15.753
iteration: 14; return: -0.092; episode_length: 17.498
iteration: 15; return: -0.018; episode_length: 15.848
iteration: 16; return: -0.013; episode_length: 17.675
iteration: 17; return: 0.106; episode_length: 16.560
iteration: 18; return: 0.091; episode_length: 15.623
iteration: 19; return: 0.051; episode_length: 16.513
iteration: 20; return: 0.097; episode_length: 13.712
iteration: 21; return: 0.120; episode_length: 13.947
iteration: 22; return: 0.142; episode_length: 13.129
iteration: 23; return: 0.174; episode_length: 12.448
iteration: 24; return: 0.227; episode_length: 12.727
iteration: 25; return: 0.220; episode_length: 12.202
iteration: 26; return: 0.275; episode_length: 12.187
iteration: 27; return: 0.265; episode_length: 11.739
iteration: 28; return: 0.235; episode_length: 11.839
iteration: 29; return: 0.314; episode_length: 11.270
iteration: 30; return: 0.295; episode_length: 11.719
iteration: 31; return: 0.259; episode_length: 11.640
iteration: 32; return: 0.312; episode_length: 11.174
iteration: 33; return: 0.359; episode_length: 11.739
iteration: 34; return: 0.367; episode_length: 11.609
iteration: 35; return: 0.364; episode_length: 11.818
iteration: 36; return: 0.365; episode_length: 11.561
iteration: 37; return: 0.370; episode_length: 11.886
iteration: 38; return: 0.400; episode_length: 11.870
iteration: 39; return: 0.363; episode_length: 11.933
iteration: 40; return: 0.429; episode_length: 11.412
iteration: 41; return: 0.452; episode_length: 11.489
iteration: 42; return: 0.474; episode_length: 11.606
iteration: 43; return: 0.486; episode_length: 11.165
iteration: 44; return: 0.494; episode_length: 11.188
iteration: 45; return: 0.485; episode_length: 10.519
iteration: 46; return: 0.532; episode_length: 10.674
iteration: 47; return: 0.546; episode_length: 10.835
iteration: 48; return: 0.536; episode_length: 10.168
iteration: 49; return: 0.578; episode_length: 10.052
iteration: 50; return: 0.527; episode_length: 10.429
iteration: 51; return: 0.524; episode_length: 10.330
iteration: 52; return: 0.559; episode_length: 10.187
iteration: 53; return: 0.588; episode_length: 9.905
iteration: 54; return: 0.600; episode_length: 9.913
iteration: 55; return: 0.598; episode_length: 9.950
iteration: 56; return: 0.608; episode_length: 9.741
iteration: 57; return: 0.599; episode_length: 10.016
iteration: 58; return: 0.593; episode_length: 9.728
iteration: 59; return: 0.598; episode_length: 9.451
iteration: 60; return: 0.621; episode_length: 9.447
iteration: 61; return: 0.598; episode_length: 9.622
iteration: 62; return: 0.634; episode_length: 9.142
iteration: 63; return: 0.601; episode_length: 9.392
iteration: 64; return: 0.627; episode_length: 9.562
iteration: 65; return: 0.599; episode_length: 9.683
iteration: 66; return: 0.650; episode_length: 9.289
iteration: 67; return: 0.609; episode_length: 9.454
iteration: 68; return: 0.629; episode_length: 9.334
iteration: 69; return: 0.638; episode_length: 9.472
iteration: 70; return: 0.628; episode_length: 9.451
iteration: 71; return: 0.646; episode_length: 9.297
iteration: 72; return: 0.645; episode_length: 9.232
iteration: 73; return: 0.623; episode_length: 9.234
iteration: 74; return: 0.651; episode_length: 9.361
iteration: 75; return: 0.636; episode_length: 9.176
iteration: 76; return: 0.653; episode_length: 9.133
iteration: 77; return: 0.641; episode_length: 9.236
iteration: 78; return: 0.658; episode_length: 9.174
iteration: 79; return: 0.637; episode_length: 9.236
iteration: 80; return: 0.633; episode_length: 9.082
iteration: 81; return: 0.651; episode_length: 9.351
iteration: 82; return: 0.635; episode_length: 8.968
iteration: 83; return: 0.653; episode_length: 8.953
iteration: 84; return: 0.671; episode_length: 9.043
iteration: 85; return: 0.661; episode_length: 8.932
iteration: 86; return: 0.624; episode_length: 9.120
iteration: 87; return: 0.650; episode_length: 9.122
iteration: 88; return: 0.669; episode_length: 9.009
iteration: 89; return: 0.661; episode_length: 8.950
iteration: 90; return: 0.667; episode_length: 8.791
iteration: 91; return: 0.649; episode_length: 9.049
iteration: 92; return: 0.647; episode_length: 9.107
iteration: 93; return: 0.656; episode_length: 9.052
iteration: 94; return: 0.648; episode_length: 8.996
iteration: 95; return: 0.668; episode_length: 8.853
iteration: 96; return: 0.656; episode_length: 8.890
iteration: 97; return: 0.666; episode_length: 8.829
iteration: 98; return: 0.657; episode_length: 8.952
iteration: 99; return: 0.664; episode_length: 8.961
iteration: 100; return: 0.661; episode_length: 9.115
In [5]:
Copied!
# For each training checkpoint, derive a deterministic (greedy) policy
# over the grid states so it can be drawn as an arrow map.
policy_array = [
    derive_deterministic_policy(mdp=grid_mdp, policy=model)
    for model in model_checkpoints
]
policy_array = [
    derive_deterministic_policy(mdp=grid_mdp, policy=model)
    for model in model_checkpoints
]
In [6]:
Copied!
# Plotting callback: renders the derived policy at a chosen training
# iteration on the 4x3 toy grid.
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=4, rows=3, policy_over_time=policy_array
)
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=4, rows=3, policy_over_time=policy_array
)
In [7]:
Copied!
# Interactive slider locally; a single static plot on CI, where no widget
# frontend is available. CI runners are expected to set CI="true".
interactive_widgets = os.getenv("CI") != "true"  # non-interactive in CI
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    # Slider over all recorded checkpoints (0 .. last index).
    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # Static fallback: show the last checkpoint. Using the actual last index
    # (instead of a hard-coded 100) stays correct if the checkpoint count changes.
    plot_policy_step_grid_map(len(model_checkpoints) - 1)
interactive_widgets = os.getenv("CI") != "true"  # non-interactive in CI
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    plot_policy_step_grid_map(len(model_checkpoints) - 1)
HIGHWAY EXAMPLE
In [8]:
Copied!
# Optional stochastic lane-change dynamics: the right lane change succeeds
# with probability 0.4, otherwise the vehicle stays in lane. Disabled here
# (if False); flip to True later to see the effect on the learned policy.
if False:
    # we will change this to true later on, to see the effect
    HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
        (0.4, LC_RIGHT_ACTION),
        (0.6, STAY_IN_LANE_ACTION),
    ]
if False:
    # we will change this to true later on, to see the effect
    HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
        (0.4, LC_RIGHT_ACTION),
        (0.6, STAY_IN_LANE_ACTION),
    ]
In [9]:
Copied!
# Build the highway MDP from the (possibly modified) configuration dict.
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
In [10]:
Copied!
# Fresh policy network for the highway MDP: state-sized input, one hidden
# layer of 32 units, one output logit per action.
policy = CategoricalPolicy(
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
    actions=list(highway_mdp.actions),
)
policy = CategoricalPolicy(
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
    actions=list(highway_mdp.actions),
)
In [11]:
Copied!
# Train on the highway MDP; 200 iterations (the larger state space needs
# more training than the toy example). Checkpoints are recorded per
# iteration for the visualization below.
model_checkpoints = policy_gradient(
    mdp=highway_mdp,
    policy=policy,
    iterations=200,
    return_history=True,
)
model_checkpoints = policy_gradient(
    mdp=highway_mdp,
    policy=policy,
    iterations=200,
    return_history=True,
)
iteration: 1; return: -65.620; episode_length: 10.850
iteration: 2; return: -65.883; episode_length: 11.971
iteration: 3; return: -65.885; episode_length: 13.365
iteration: 4; return: -65.623; episode_length: 13.556
iteration: 5; return: -65.709; episode_length: 14.320
iteration: 6; return: -65.766; episode_length: 13.944
iteration: 7; return: -65.997; episode_length: 14.176
iteration: 8; return: -66.269; episode_length: 14.309
iteration: 9; return: -66.230; episode_length: 14.555
iteration: 10; return: -65.893; episode_length: 14.435
iteration: 11; return: -66.272; episode_length: 15.109
iteration: 12; return: -65.962; episode_length: 14.695
iteration: 13; return: -66.153; episode_length: 14.779
iteration: 14; return: -66.142; episode_length: 15.409
iteration: 15; return: -65.453; episode_length: 14.629
iteration: 16; return: -66.295; episode_length: 15.096
iteration: 17; return: -66.403; episode_length: 15.641
iteration: 18; return: -65.878; episode_length: 15.277
iteration: 19; return: -65.982; episode_length: 15.228
iteration: 20; return: -66.609; episode_length: 15.804
iteration: 21; return: -65.876; episode_length: 15.133
iteration: 22; return: -65.865; episode_length: 15.021
iteration: 23; return: -66.317; episode_length: 15.543
iteration: 24; return: -65.621; episode_length: 15.197
iteration: 25; return: -65.585; episode_length: 15.293
iteration: 26; return: -66.006; episode_length: 15.586
iteration: 27; return: -65.865; episode_length: 15.431
iteration: 28; return: -65.874; episode_length: 15.362
iteration: 29; return: -66.025; episode_length: 15.466
iteration: 30; return: -66.207; episode_length: 15.943
iteration: 31; return: -65.862; episode_length: 15.677
iteration: 32; return: -65.994; episode_length: 15.811
iteration: 33; return: -65.637; episode_length: 15.428
iteration: 34; return: -66.157; episode_length: 15.777
iteration: 35; return: -65.823; episode_length: 15.531
iteration: 36; return: -65.909; episode_length: 15.808
iteration: 37; return: -65.757; episode_length: 15.394
iteration: 38; return: -65.483; episode_length: 15.391
iteration: 39; return: -65.953; episode_length: 15.607
iteration: 40; return: -65.794; episode_length: 15.650
iteration: 41; return: -66.003; episode_length: 15.785
iteration: 42; return: -65.469; episode_length: 15.347
iteration: 43; return: -65.874; episode_length: 15.833
iteration: 44; return: -66.105; episode_length: 15.981
iteration: 45; return: -66.107; episode_length: 15.820
iteration: 46; return: -65.429; episode_length: 15.139
iteration: 47; return: -65.710; episode_length: 15.435
iteration: 48; return: -65.224; episode_length: 15.139
iteration: 49; return: -64.479; episode_length: 14.248
iteration: 50; return: -64.478; episode_length: 14.065
iteration: 51; return: -63.887; episode_length: 13.485
iteration: 52; return: -63.183; episode_length: 12.730
iteration: 53; return: -63.229; episode_length: 12.725
iteration: 54; return: -62.890; episode_length: 12.202
iteration: 55; return: -62.449; episode_length: 11.891
iteration: 56; return: -62.062; episode_length: 11.444
iteration: 57; return: -61.766; episode_length: 11.254
iteration: 58; return: -61.779; episode_length: 10.724
iteration: 59; return: -61.537; episode_length: 10.398
iteration: 60; return: -61.503; episode_length: 10.127
iteration: 61; return: -61.573; episode_length: 9.942
iteration: 62; return: -61.365; episode_length: 9.806
iteration: 63; return: -61.201; episode_length: 9.679
iteration: 64; return: -61.274; episode_length: 9.592
iteration: 65; return: -61.625; episode_length: 9.576
iteration: 66; return: -61.278; episode_length: 9.543
iteration: 67; return: -61.473; episode_length: 9.544
iteration: 68; return: -61.483; episode_length: 9.631
iteration: 69; return: -61.277; episode_length: 9.683
iteration: 70; return: -61.419; episode_length: 9.707
iteration: 71; return: -61.187; episode_length: 9.668
iteration: 72; return: -61.318; episode_length: 9.707
iteration: 73; return: -61.310; episode_length: 9.579
iteration: 74; return: -61.050; episode_length: 9.588
iteration: 75; return: -61.269; episode_length: 9.475
iteration: 76; return: -61.303; episode_length: 9.400
iteration: 77; return: -61.968; episode_length: 9.385
iteration: 78; return: -61.371; episode_length: 9.517
iteration: 79; return: -61.606; episode_length: 9.537
iteration: 80; return: -61.307; episode_length: 9.435
iteration: 81; return: -61.333; episode_length: 9.463
iteration: 82; return: -61.051; episode_length: 9.501
iteration: 83; return: -61.368; episode_length: 9.392
iteration: 84; return: -61.278; episode_length: 9.481
iteration: 85; return: -61.457; episode_length: 9.456
iteration: 86; return: -61.331; episode_length: 9.463
iteration: 87; return: -61.375; episode_length: 9.369
iteration: 88; return: -61.404; episode_length: 9.353
iteration: 89; return: -61.595; episode_length: 9.338
iteration: 90; return: -61.370; episode_length: 9.404
iteration: 91; return: -61.295; episode_length: 9.350
iteration: 92; return: -61.375; episode_length: 9.336
iteration: 93; return: -61.301; episode_length: 9.375
iteration: 94; return: -61.535; episode_length: 9.383
iteration: 95; return: -61.542; episode_length: 9.426
iteration: 96; return: -61.225; episode_length: 9.386
iteration: 97; return: -61.453; episode_length: 9.330
iteration: 98; return: -61.290; episode_length: 9.236
iteration: 99; return: -61.555; episode_length: 9.204
iteration: 100; return: -61.893; episode_length: 9.195
iteration: 101; return: -61.592; episode_length: 9.229
iteration: 102; return: -61.600; episode_length: 9.212
iteration: 103; return: -61.786; episode_length: 9.238
iteration: 104; return: -61.670; episode_length: 9.178
iteration: 105; return: -61.943; episode_length: 9.176
iteration: 106; return: -61.455; episode_length: 9.178
iteration: 107; return: -61.691; episode_length: 9.223
iteration: 108; return: -61.705; episode_length: 9.214
iteration: 109; return: -61.531; episode_length: 9.280
iteration: 110; return: -61.319; episode_length: 9.332
iteration: 111; return: -61.557; episode_length: 9.297
iteration: 112; return: -61.435; episode_length: 9.336
iteration: 113; return: -61.066; episode_length: 9.392
iteration: 114; return: -61.319; episode_length: 9.443
iteration: 115; return: -61.162; episode_length: 9.418
iteration: 116; return: -61.357; episode_length: 9.402
iteration: 117; return: -61.310; episode_length: 9.519
iteration: 118; return: -61.240; episode_length: 9.523
iteration: 119; return: -61.092; episode_length: 9.588
iteration: 120; return: -61.002; episode_length: 9.546
iteration: 121; return: -61.157; episode_length: 9.854
iteration: 122; return: -60.864; episode_length: 9.885
iteration: 123; return: -61.150; episode_length: 9.775
iteration: 124; return: -60.896; episode_length: 10.026
iteration: 125; return: -60.865; episode_length: 10.087
iteration: 126; return: -60.972; episode_length: 10.177
iteration: 127; return: -61.031; episode_length: 10.394
iteration: 128; return: -61.160; episode_length: 10.535
iteration: 129; return: -61.124; episode_length: 10.484
iteration: 130; return: -61.419; episode_length: 10.594
iteration: 131; return: -61.238; episode_length: 10.479
iteration: 132; return: -61.205; episode_length: 10.687
iteration: 133; return: -61.113; episode_length: 10.705
iteration: 134; return: -61.366; episode_length: 10.778
iteration: 135; return: -61.249; episode_length: 10.736
iteration: 136; return: -61.181; episode_length: 10.816
iteration: 137; return: -61.287; episode_length: 10.982
iteration: 138; return: -61.252; episode_length: 10.872
iteration: 139; return: -61.413; episode_length: 10.937
iteration: 140; return: -61.195; episode_length: 10.755
iteration: 141; return: -61.450; episode_length: 10.791
iteration: 142; return: -61.324; episode_length: 10.896
iteration: 143; return: -61.098; episode_length: 10.665
iteration: 144; return: -61.231; episode_length: 10.625
iteration: 145; return: -61.379; episode_length: 10.904
iteration: 146; return: -61.541; episode_length: 11.100
iteration: 147; return: -61.113; episode_length: 10.859
iteration: 148; return: -61.461; episode_length: 10.928
iteration: 149; return: -61.227; episode_length: 10.747
iteration: 150; return: -61.033; episode_length: 10.464
iteration: 151; return: -61.368; episode_length: 10.511
iteration: 152; return: -61.168; episode_length: 10.384
iteration: 153; return: -61.093; episode_length: 10.388
iteration: 154; return: -61.085; episode_length: 10.167
iteration: 155; return: -61.000; episode_length: 10.260
iteration: 156; return: -61.037; episode_length: 10.224
iteration: 157; return: -61.248; episode_length: 10.089
iteration: 158; return: -60.811; episode_length: 10.050
iteration: 159; return: -60.953; episode_length: 10.258
iteration: 160; return: -61.252; episode_length: 10.162
iteration: 161; return: -61.258; episode_length: 10.351
iteration: 162; return: -61.045; episode_length: 10.227
iteration: 163; return: -61.042; episode_length: 10.414
iteration: 164; return: -61.142; episode_length: 10.283
iteration: 165; return: -60.865; episode_length: 10.233
iteration: 166; return: -61.246; episode_length: 10.351
iteration: 167; return: -61.326; episode_length: 10.462
iteration: 168; return: -60.977; episode_length: 10.300
iteration: 169; return: -61.081; episode_length: 10.433
iteration: 170; return: -61.209; episode_length: 10.565
iteration: 171; return: -60.960; episode_length: 10.403
iteration: 172; return: -61.055; episode_length: 10.528
iteration: 173; return: -61.161; episode_length: 10.497
iteration: 174; return: -61.138; episode_length: 10.349
iteration: 175; return: -61.089; episode_length: 10.339
iteration: 176; return: -60.953; episode_length: 10.193
iteration: 177; return: -60.981; episode_length: 10.366
iteration: 178; return: -61.099; episode_length: 10.360
iteration: 179; return: -61.085; episode_length: 10.165
iteration: 180; return: -61.294; episode_length: 10.224
iteration: 181; return: -60.988; episode_length: 10.136
iteration: 182; return: -61.197; episode_length: 10.177
iteration: 183; return: -61.099; episode_length: 10.093
iteration: 184; return: -61.160; episode_length: 9.889
iteration: 185; return: -60.998; episode_length: 9.777
iteration: 186; return: -61.181; episode_length: 9.664
iteration: 187; return: -61.112; episode_length: 9.633
iteration: 188; return: -60.799; episode_length: 9.577
iteration: 189; return: -60.912; episode_length: 9.539
iteration: 190; return: -61.298; episode_length: 9.523
iteration: 191; return: -61.085; episode_length: 9.491
iteration: 192; return: -61.186; episode_length: 9.433
iteration: 193; return: -61.110; episode_length: 9.493
iteration: 194; return: -61.250; episode_length: 9.473
iteration: 195; return: -61.027; episode_length: 9.491
iteration: 196; return: -61.313; episode_length: 9.440
iteration: 197; return: -61.168; episode_length: 9.426
iteration: 198; return: -61.212; episode_length: 9.487
iteration: 199; return: -61.206; episode_length: 9.537
iteration: 200; return: -61.382; episode_length: 9.519
In [12]:
Copied!
# Derive a deterministic (greedy) policy per checkpoint for plotting.
policy_array = [
    derive_deterministic_policy(mdp=highway_mdp, policy=model)
    for model in model_checkpoints
]
policy_array = [
    derive_deterministic_policy(mdp=highway_mdp, policy=model)
    for model in model_checkpoints
]
In [13]:
Copied!
# Plotting callback for the 10x4 highway grid; rebinds the name used by
# the widget cell below to the highway policies.
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=10, rows=4, policy_over_time=policy_array
)
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=10, rows=4, policy_over_time=policy_array
)
In [14]:
Copied!
# Same widget/CI switch as for the toy example; `interactive_widgets` was
# computed in the earlier cell from the CI environment variable.
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    # Slider over all recorded checkpoints (0 .. last index).
    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # Static fallback: show the last checkpoint. Using the actual last index
    # (instead of a hard-coded 200) stays correct if the checkpoint count changes.
    plot_policy_step_grid_map(len(model_checkpoints) - 1)
if interactive_widgets:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    plot_policy_step_grid_map(len(model_checkpoints) - 1)