Policy Gradient
In [1]:
Copied!
import os
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
GridMDP,
policy_gradient,
derive_deterministic_policy,
GRID_MDP_DICT,
HIGHWAY_MDP_DICT,
LC_RIGHT_ACTION,
STAY_IN_LANE_ACTION,
)
# Override one entry of the imported highway configuration in place; the change
# is picked up when GridMDP(**HIGHWAY_MDP_DICT) is built further down.
# NOTE(review): this mutates the dict object imported from the mdp module; the
# exact effect of the flag is defined by GridMDP — confirm there.
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
import os
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
GridMDP,
policy_gradient,
derive_deterministic_policy,
GRID_MDP_DICT,
HIGHWAY_MDP_DICT,
LC_RIGHT_ACTION,
STAY_IN_LANE_ACTION,
)
# Override one entry of the imported highway configuration in place; the change
# is picked up when GridMDP(**HIGHWAY_MDP_DICT) is built further down.
# NOTE(review): this mutates the dict object imported from the mdp module; the
# exact effect of the flag is defined by GridMDP — confirm there.
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
TOY EXAMPLE
In [2]:
Copied!
# Build the MDP for the toy example by unpacking the imported GRID_MDP_DICT
# as keyword arguments of GridMDP.
grid_mdp = GridMDP(**GRID_MDP_DICT)
# Build the MDP for the toy example by unpacking the imported GRID_MDP_DICT
# as keyword arguments of GridMDP.
grid_mdp = GridMDP(**GRID_MDP_DICT)
In [3]:
Copied!
# Policy for the toy grid MDP. `sizes` starts with the length of a state and
# ends with the number of actions, with 32 in between.
# NOTE(review): presumably these are the layer widths of the policy network —
# confirm in CategoricalPolicy.
policy = CategoricalPolicy(
    actions=[*grid_mdp.actions],
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
)
# Policy for the toy grid MDP. `sizes` starts with the length of a state and
# ends with the number of actions, with 32 in between.
# NOTE(review): presumably these are the layer widths of the policy network —
# confirm in CategoricalPolicy.
policy = CategoricalPolicy(
    actions=[*grid_mdp.actions],
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
)
In [4]:
Copied!
# Run policy gradient on the toy grid MDP for 100 iterations. The call logs
# the return and episode length per iteration; with return_history=True its
# result is iterated over as a collection of policies in the following cells.
model_checkpoints = policy_gradient(
    mdp=grid_mdp, policy=policy, iterations=100, return_history=True
)
# Run policy gradient on the toy grid MDP for 100 iterations. The call logs
# the return and episode length per iteration; with return_history=True its
# result is iterated over as a collection of policies in the following cells.
model_checkpoints = policy_gradient(
    mdp=grid_mdp, policy=policy, iterations=100, return_history=True
)
iteration: 1; return: -1.794; episode_length: 31.488
iteration: 2; return: -1.522; episode_length: 24.645
iteration: 3; return: -1.320; episode_length: 18.337
iteration: 4; return: -1.214; episode_length: 14.388
iteration: 5; return: -1.101; episode_length: 12.464
iteration: 6; return: -1.004; episode_length: 10.876
iteration: 7; return: -0.965; episode_length: 10.552
iteration: 8; return: -0.882; episode_length: 10.294
iteration: 9; return: -0.721; episode_length: 9.808
iteration: 10; return: -0.705; episode_length: 10.383
iteration: 11; return: -0.557; episode_length: 9.913
iteration: 12; return: -0.463; episode_length: 9.976
iteration: 13; return: -0.424; episode_length: 9.909
iteration: 14; return: -0.277; episode_length: 10.565
iteration: 15; return: -0.144; episode_length: 11.580
iteration: 16; return: -0.190; episode_length: 12.034
iteration: 17; return: -0.152; episode_length: 13.262
iteration: 18; return: -0.048; episode_length: 13.967
iteration: 19; return: -0.067; episode_length: 16.687
iteration: 20; return: 0.035; episode_length: 16.776
iteration: 21; return: -0.015; episode_length: 18.585
iteration: 22; return: 0.015; episode_length: 16.824
iteration: 23; return: -0.031; episode_length: 18.599
iteration: 24; return: -0.008; episode_length: 17.717
iteration: 25; return: -0.041; episode_length: 18.270
iteration: 26; return: 0.047; episode_length: 18.148
iteration: 27; return: 0.130; episode_length: 16.253
iteration: 28; return: 0.157; episode_length: 15.876
iteration: 29; return: 0.166; episode_length: 15.072
iteration: 30; return: 0.143; episode_length: 14.221
iteration: 31; return: 0.175; episode_length: 13.357
iteration: 32; return: 0.209; episode_length: 12.467
iteration: 33; return: 0.222; episode_length: 12.554
iteration: 34; return: 0.244; episode_length: 11.872
iteration: 35; return: 0.267; episode_length: 11.510
iteration: 36; return: 0.225; episode_length: 11.460
iteration: 37; return: 0.212; episode_length: 11.142
iteration: 38; return: 0.227; episode_length: 11.044
iteration: 39; return: 0.246; episode_length: 11.040
iteration: 40; return: 0.241; episode_length: 12.335
iteration: 41; return: 0.270; episode_length: 11.403
iteration: 42; return: 0.275; episode_length: 11.747
iteration: 43; return: 0.271; episode_length: 12.486
iteration: 44; return: 0.256; episode_length: 11.539
iteration: 45; return: 0.324; episode_length: 12.279
iteration: 46; return: 0.343; episode_length: 12.549
iteration: 47; return: 0.314; episode_length: 12.182
iteration: 48; return: 0.283; episode_length: 12.956
iteration: 49; return: 0.356; episode_length: 12.226
iteration: 50; return: 0.300; episode_length: 12.427
iteration: 51; return: 0.346; episode_length: 12.406
iteration: 52; return: 0.334; episode_length: 12.271
iteration: 53; return: 0.391; episode_length: 12.207
iteration: 54; return: 0.352; episode_length: 12.109
iteration: 55; return: 0.372; episode_length: 12.022
iteration: 56; return: 0.460; episode_length: 12.085
iteration: 57; return: 0.422; episode_length: 12.265
iteration: 58; return: 0.419; episode_length: 11.588
iteration: 59; return: 0.406; episode_length: 12.300
iteration: 60; return: 0.432; episode_length: 11.443
iteration: 61; return: 0.429; episode_length: 11.581
iteration: 62; return: 0.463; episode_length: 11.837
iteration: 63; return: 0.454; episode_length: 11.731
iteration: 64; return: 0.407; episode_length: 11.721
iteration: 65; return: 0.489; episode_length: 11.026
iteration: 66; return: 0.493; episode_length: 11.393
iteration: 67; return: 0.485; episode_length: 11.460
iteration: 68; return: 0.473; episode_length: 10.908
iteration: 69; return: 0.494; episode_length: 11.201
iteration: 70; return: 0.489; episode_length: 11.201
iteration: 71; return: 0.512; episode_length: 10.829
iteration: 72; return: 0.498; episode_length: 11.286
iteration: 73; return: 0.550; episode_length: 10.738
iteration: 74; return: 0.522; episode_length: 10.525
iteration: 75; return: 0.515; episode_length: 10.759
iteration: 76; return: 0.553; episode_length: 10.328
iteration: 77; return: 0.567; episode_length: 10.475
iteration: 78; return: 0.559; episode_length: 10.466
iteration: 79; return: 0.554; episode_length: 10.380
iteration: 80; return: 0.576; episode_length: 10.275
iteration: 81; return: 0.604; episode_length: 9.994
iteration: 82; return: 0.582; episode_length: 10.036
iteration: 83; return: 0.615; episode_length: 9.937
iteration: 84; return: 0.602; episode_length: 9.683
iteration: 85; return: 0.607; episode_length: 9.584
iteration: 86; return: 0.616; episode_length: 9.642
iteration: 87; return: 0.602; episode_length: 9.696
iteration: 88; return: 0.612; episode_length: 9.553
iteration: 89; return: 0.628; episode_length: 9.373
iteration: 90; return: 0.625; episode_length: 9.442
iteration: 91; return: 0.649; episode_length: 9.479
iteration: 92; return: 0.649; episode_length: 9.392
iteration: 93; return: 0.599; episode_length: 9.274
iteration: 94; return: 0.638; episode_length: 9.227
iteration: 95; return: 0.616; episode_length: 9.457
iteration: 96; return: 0.625; episode_length: 9.359
iteration: 97; return: 0.622; episode_length: 9.424
iteration: 98; return: 0.626; episode_length: 9.246
iteration: 99; return: 0.621; episode_length: 9.269
iteration: 100; return: 0.653; episode_length: 9.126
In [5]:
Copied!
# One derived deterministic policy per stored checkpoint, in checkpoint order,
# so the plots below can step through training.
policy_array = [
    derive_deterministic_policy(policy=checkpoint, mdp=grid_mdp)
    for checkpoint in model_checkpoints
]
# One derived deterministic policy per stored checkpoint, in checkpoint order,
# so the plots below can step through training.
policy_array = [
    derive_deterministic_policy(policy=checkpoint, mdp=grid_mdp)
    for checkpoint in model_checkpoints
]
In [6]:
Copied!
# Build the plotting callback for the toy grid (4 columns x 3 rows); it is
# later called with an iteration index into `policy_array`.
plot_policy_step_grid_map = make_plot_policy_step_function(
    policy_over_time=policy_array,
    columns=4,
    rows=3,
)
# Build the plotting callback for the toy grid (4 columns x 3 rows); it is
# later called with an iteration index into `policy_array`.
plot_policy_step_grid_map = make_plot_policy_step_function(
    policy_over_time=policy_array,
    columns=4,
    rows=3,
)
In [7]:
Copied!
# Show an interactive slider unless the CI environment variable is exactly
# "true" (non-interactive in CI), in which case the plot function is called once.
interactive_widgets = os.getenv("CI") != "true"

# Index of the last stored checkpoint. Derived from the data instead of being
# hard-coded, so it stays valid if the number of training iterations changes.
last_checkpoint = len(model_checkpoints) - 1

if interactive_widgets:
    # Imported only on the interactive path.
    # NOTE(review): presumably so CI does not need ipywidgets installed — confirm.
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=last_checkpoint, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # Non-interactive run: plot the final checkpoint only.
    plot_policy_step_grid_map(last_checkpoint)
# Show an interactive slider unless the CI environment variable is exactly
# "true" (non-interactive in CI), in which case the plot function is called once.
interactive_widgets = os.getenv("CI") != "true"

# Index of the last stored checkpoint. Derived from the data instead of being
# hard-coded, so it stays valid if the number of training iterations changes.
last_checkpoint = len(model_checkpoints) - 1

if interactive_widgets:
    # Imported only on the interactive path.
    # NOTE(review): presumably so CI does not need ipywidgets installed — confirm.
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=last_checkpoint, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # Non-interactive run: plot the final checkpoint only.
    plot_policy_step_grid_map(last_checkpoint)
HIGHWAY EXAMPLE
In [8]:
Copied!
# Lecture toggle, disabled by default: change the condition to True later on
# (and re-run the cells below) to see the effect.
if False:
    # New outcome list for LC_RIGHT_ACTION: 0.4 -> LC_RIGHT_ACTION,
    # 0.6 -> STAY_IN_LANE_ACTION.
    # NOTE(review): presumably (probability, resulting action) pairs — confirm
    # against GridMDP.
    lc_right_outcomes = [
        (0.4, LC_RIGHT_ACTION),
        (0.6, STAY_IN_LANE_ACTION),
    ]
    HIGHWAY_MDP_DICT["transition_probabilities_per_action"][
        LC_RIGHT_ACTION
    ] = lc_right_outcomes
# Lecture toggle, disabled by default: change the condition to True later on
# (and re-run the cells below) to see the effect.
if False:
    # New outcome list for LC_RIGHT_ACTION: 0.4 -> LC_RIGHT_ACTION,
    # 0.6 -> STAY_IN_LANE_ACTION.
    # NOTE(review): presumably (probability, resulting action) pairs — confirm
    # against GridMDP.
    lc_right_outcomes = [
        (0.4, LC_RIGHT_ACTION),
        (0.6, STAY_IN_LANE_ACTION),
    ]
    HIGHWAY_MDP_DICT["transition_probabilities_per_action"][
        LC_RIGHT_ACTION
    ] = lc_right_outcomes
In [9]:
Copied!
# Build the highway MDP by unpacking HIGHWAY_MDP_DICT as keyword arguments;
# this includes any in-place overrides applied to the dict in earlier cells.
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
# Build the highway MDP by unpacking HIGHWAY_MDP_DICT as keyword arguments;
# this includes any in-place overrides applied to the dict in earlier cells.
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)
In [10]:
Copied!
# Fresh policy for the highway MDP (rebinds `policy` from the toy example).
# `sizes` starts with the length of a state and ends with the number of
# actions, with 32 in between.
# NOTE(review): presumably these are the layer widths of the policy network —
# confirm in CategoricalPolicy.
policy = CategoricalPolicy(
    actions=[*highway_mdp.actions],
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
)
# Fresh policy for the highway MDP (rebinds `policy` from the toy example).
# `sizes` starts with the length of a state and ends with the number of
# actions, with 32 in between.
# NOTE(review): presumably these are the layer widths of the policy network —
# confirm in CategoricalPolicy.
policy = CategoricalPolicy(
    actions=[*highway_mdp.actions],
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
)
In [11]:
Copied!
# Run policy gradient on the highway MDP for 200 iterations (rebinds
# `model_checkpoints` from the toy example). The call logs the return and
# episode length per iteration; with return_history=True its result is
# iterated over as a collection of policies in the following cells.
model_checkpoints = policy_gradient(
    mdp=highway_mdp, policy=policy, iterations=200, return_history=True
)
# Run policy gradient on the highway MDP for 200 iterations (rebinds
# `model_checkpoints` from the toy example). The call logs the return and
# episode length per iteration; with return_history=True its result is
# iterated over as a collection of policies in the following cells.
model_checkpoints = policy_gradient(
    mdp=highway_mdp, policy=policy, iterations=200, return_history=True
)
iteration: 1; return: -62.917; episode_length: 9.720
iteration: 2; return: -61.538; episode_length: 9.876
iteration: 3; return: -57.298; episode_length: 10.130
iteration: 4; return: -47.298; episode_length: 11.224
iteration: 5; return: -41.468; episode_length: 12.935
iteration: 6; return: -39.515; episode_length: 15.179
iteration: 7; return: -41.747; episode_length: 17.851
iteration: 8; return: -44.227; episode_length: 18.344
iteration: 9; return: -44.331; episode_length: 19.250
iteration: 10; return: -44.106; episode_length: 18.962
iteration: 11; return: -44.240; episode_length: 19.415
iteration: 12; return: -44.564; episode_length: 19.394
iteration: 13; return: -41.601; episode_length: 18.355
iteration: 14; return: -40.887; episode_length: 18.193
iteration: 15; return: -41.864; episode_length: 18.333
iteration: 16; return: -42.188; episode_length: 18.423
iteration: 17; return: -38.146; episode_length: 16.628
iteration: 18; return: -37.031; episode_length: 15.505
iteration: 19; return: -36.000; episode_length: 14.418
iteration: 20; return: -33.105; episode_length: 12.871
iteration: 21; return: -32.947; episode_length: 11.418
iteration: 22; return: -34.978; episode_length: 10.998
iteration: 23; return: -36.744; episode_length: 10.797
iteration: 24; return: -34.830; episode_length: 11.053
iteration: 25; return: -33.509; episode_length: 11.391
iteration: 26; return: -32.570; episode_length: 11.936
iteration: 27; return: -32.131; episode_length: 12.615
iteration: 28; return: -33.005; episode_length: 13.284
iteration: 29; return: -32.722; episode_length: 13.531
iteration: 30; return: -33.328; episode_length: 13.705
iteration: 31; return: -31.395; episode_length: 13.281
iteration: 32; return: -29.819; episode_length: 12.399
iteration: 33; return: -29.758; episode_length: 11.660
iteration: 34; return: -30.529; episode_length: 11.676
iteration: 35; return: -29.985; episode_length: 10.859
iteration: 36; return: -32.002; episode_length: 11.305
iteration: 37; return: -29.539; episode_length: 11.377
iteration: 38; return: -29.656; episode_length: 11.642
iteration: 39; return: -28.407; episode_length: 12.053
iteration: 40; return: -28.533; episode_length: 12.292
iteration: 41; return: -29.078; episode_length: 12.515
iteration: 42; return: -29.273; episode_length: 12.312
iteration: 43; return: -28.232; episode_length: 12.099
iteration: 44; return: -27.892; episode_length: 11.724
iteration: 45; return: -27.153; episode_length: 11.258
iteration: 46; return: -28.418; episode_length: 10.732
iteration: 47; return: -28.074; episode_length: 10.850
iteration: 48; return: -28.242; episode_length: 10.819
iteration: 49; return: -27.431; episode_length: 11.394
iteration: 50; return: -28.317; episode_length: 11.910
iteration: 51; return: -27.893; episode_length: 12.484
iteration: 52; return: -27.322; episode_length: 12.481
iteration: 53; return: -28.648; episode_length: 12.783
iteration: 54; return: -29.639; episode_length: 13.128
iteration: 55; return: -26.791; episode_length: 12.325
iteration: 56; return: -27.113; episode_length: 11.800
iteration: 57; return: -27.101; episode_length: 11.245
iteration: 58; return: -25.609; episode_length: 10.887
iteration: 59; return: -26.805; episode_length: 10.709
iteration: 60; return: -26.178; episode_length: 10.736
iteration: 61; return: -26.894; episode_length: 10.577
iteration: 62; return: -26.129; episode_length: 10.789
iteration: 63; return: -24.922; episode_length: 10.838
iteration: 64; return: -27.066; episode_length: 11.319
iteration: 65; return: -25.235; episode_length: 11.396
iteration: 66; return: -24.930; episode_length: 11.789
iteration: 67; return: -24.473; episode_length: 11.676
iteration: 68; return: -25.687; episode_length: 11.432
iteration: 69; return: -24.316; episode_length: 11.309
iteration: 70; return: -25.059; episode_length: 10.969
iteration: 71; return: -24.366; episode_length: 11.046
iteration: 72; return: -24.464; episode_length: 10.915
iteration: 73; return: -24.308; episode_length: 11.007
iteration: 74; return: -24.266; episode_length: 10.833
iteration: 75; return: -23.865; episode_length: 10.553
iteration: 76; return: -23.666; episode_length: 10.726
iteration: 77; return: -23.715; episode_length: 10.711
iteration: 78; return: -24.226; episode_length: 10.996
iteration: 79; return: -23.806; episode_length: 11.060
iteration: 80; return: -22.719; episode_length: 11.444
iteration: 81; return: -23.954; episode_length: 11.508
iteration: 82; return: -22.082; episode_length: 11.102
iteration: 83; return: -23.683; episode_length: 10.876
iteration: 84; return: -23.026; episode_length: 10.789
iteration: 85; return: -23.423; episode_length: 10.575
iteration: 86; return: -22.719; episode_length: 10.555
iteration: 87; return: -23.713; episode_length: 10.568
iteration: 88; return: -21.966; episode_length: 10.519
iteration: 89; return: -22.577; episode_length: 10.431
iteration: 90; return: -22.662; episode_length: 10.555
iteration: 91; return: -21.053; episode_length: 10.586
iteration: 92; return: -20.977; episode_length: 10.486
iteration: 93; return: -21.279; episode_length: 10.657
iteration: 94; return: -21.132; episode_length: 10.846
iteration: 95; return: -21.761; episode_length: 10.866
iteration: 96; return: -21.706; episode_length: 10.825
iteration: 97; return: -20.998; episode_length: 10.763
iteration: 98; return: -22.194; episode_length: 10.672
iteration: 99; return: -21.345; episode_length: 10.791
iteration: 100; return: -21.343; episode_length: 10.541
iteration: 101; return: -21.122; episode_length: 10.491
iteration: 102; return: -21.294; episode_length: 10.521
iteration: 103; return: -21.591; episode_length: 10.755
iteration: 104; return: -21.582; episode_length: 10.678
iteration: 105; return: -21.544; episode_length: 10.633
iteration: 106; return: -21.896; episode_length: 10.833
iteration: 107; return: -20.780; episode_length: 10.885
iteration: 108; return: -21.558; episode_length: 10.825
iteration: 109; return: -21.172; episode_length: 10.891
iteration: 110; return: -21.699; episode_length: 10.898
iteration: 111; return: -21.105; episode_length: 10.996
iteration: 112; return: -20.384; episode_length: 10.806
iteration: 113; return: -21.614; episode_length: 10.874
iteration: 114; return: -21.067; episode_length: 10.857
iteration: 115; return: -20.809; episode_length: 10.761
iteration: 116; return: -20.372; episode_length: 10.699
iteration: 117; return: -21.534; episode_length: 10.686
iteration: 118; return: -20.761; episode_length: 10.663
iteration: 119; return: -19.975; episode_length: 10.604
iteration: 120; return: -20.722; episode_length: 10.626
iteration: 121; return: -21.335; episode_length: 10.612
iteration: 122; return: -20.794; episode_length: 10.717
iteration: 123; return: -19.665; episode_length: 10.444
iteration: 124; return: -20.612; episode_length: 10.667
iteration: 125; return: -20.224; episode_length: 10.686
iteration: 126; return: -20.763; episode_length: 10.690
iteration: 127; return: -19.887; episode_length: 10.670
iteration: 128; return: -19.748; episode_length: 10.667
iteration: 129; return: -21.810; episode_length: 10.701
iteration: 130; return: -21.442; episode_length: 10.827
iteration: 131; return: -20.092; episode_length: 10.763
iteration: 132; return: -20.224; episode_length: 10.795
iteration: 133; return: -20.441; episode_length: 10.517
iteration: 134; return: -20.667; episode_length: 10.674
iteration: 135; return: -19.976; episode_length: 10.835
iteration: 136; return: -21.274; episode_length: 10.900
iteration: 137; return: -19.654; episode_length: 10.759
iteration: 138; return: -19.377; episode_length: 10.606
iteration: 139; return: -20.379; episode_length: 10.469
iteration: 140; return: -20.198; episode_length: 10.672
iteration: 141; return: -20.611; episode_length: 10.647
iteration: 142; return: -20.572; episode_length: 10.553
iteration: 143; return: -21.215; episode_length: 10.738
iteration: 144; return: -20.343; episode_length: 10.734
iteration: 145; return: -20.712; episode_length: 10.840
iteration: 146; return: -20.238; episode_length: 10.846
iteration: 147; return: -19.650; episode_length: 10.736
iteration: 148; return: -20.139; episode_length: 10.755
iteration: 149; return: -20.495; episode_length: 10.676
iteration: 150; return: -20.560; episode_length: 10.705
iteration: 151; return: -20.857; episode_length: 10.876
iteration: 152; return: -19.612; episode_length: 10.913
iteration: 153; return: -19.637; episode_length: 10.996
iteration: 154; return: -19.583; episode_length: 10.932
iteration: 155; return: -20.387; episode_length: 10.943
iteration: 156; return: -19.896; episode_length: 10.840
iteration: 157; return: -19.819; episode_length: 10.819
iteration: 158; return: -19.361; episode_length: 10.755
iteration: 159; return: -19.491; episode_length: 10.745
iteration: 160; return: -20.521; episode_length: 10.904
iteration: 161; return: -20.918; episode_length: 10.859
iteration: 162; return: -21.231; episode_length: 10.722
iteration: 163; return: -19.714; episode_length: 10.614
iteration: 164; return: -19.876; episode_length: 10.722
iteration: 165; return: -19.760; episode_length: 10.902
iteration: 166; return: -20.255; episode_length: 10.842
iteration: 167; return: -19.910; episode_length: 10.928
iteration: 168; return: -21.079; episode_length: 10.974
iteration: 169; return: -20.208; episode_length: 10.943
iteration: 170; return: -20.704; episode_length: 10.974
iteration: 171; return: -19.822; episode_length: 10.900
iteration: 172; return: -20.456; episode_length: 10.855
iteration: 173; return: -20.013; episode_length: 10.711
iteration: 174; return: -19.541; episode_length: 10.745
iteration: 175; return: -20.371; episode_length: 10.663
iteration: 176; return: -19.275; episode_length: 10.579
iteration: 177; return: -20.633; episode_length: 10.676
iteration: 178; return: -21.207; episode_length: 10.793
iteration: 179; return: -21.222; episode_length: 10.701
iteration: 180; return: -19.521; episode_length: 10.697
iteration: 181; return: -20.507; episode_length: 10.885
iteration: 182; return: -20.078; episode_length: 10.808
iteration: 183; return: -19.422; episode_length: 10.831
iteration: 184; return: -19.169; episode_length: 10.712
iteration: 185; return: -20.415; episode_length: 10.690
iteration: 186; return: -20.321; episode_length: 10.722
iteration: 187; return: -19.364; episode_length: 10.795
iteration: 188; return: -20.079; episode_length: 10.624
iteration: 189; return: -20.257; episode_length: 10.902
iteration: 190; return: -20.452; episode_length: 10.878
iteration: 191; return: -20.752; episode_length: 10.904
iteration: 192; return: -19.927; episode_length: 10.816
iteration: 193; return: -20.109; episode_length: 10.928
iteration: 194; return: -19.050; episode_length: 10.816
iteration: 195; return: -20.428; episode_length: 10.734
iteration: 196; return: -19.434; episode_length: 10.902
iteration: 197; return: -20.747; episode_length: 10.565
iteration: 198; return: -19.563; episode_length: 10.770
iteration: 199; return: -19.368; episode_length: 10.913
iteration: 200; return: -19.569; episode_length: 10.782
In [12]:
Copied!
# One derived deterministic policy per stored highway checkpoint, in
# checkpoint order, so the plots below can step through training.
policy_array = [
    derive_deterministic_policy(policy=checkpoint, mdp=highway_mdp)
    for checkpoint in model_checkpoints
]
# One derived deterministic policy per stored highway checkpoint, in
# checkpoint order, so the plots below can step through training.
policy_array = [
    derive_deterministic_policy(policy=checkpoint, mdp=highway_mdp)
    for checkpoint in model_checkpoints
]
In [13]:
Copied!
# Build the plotting callback for the highway grid (10 columns x 4 rows); it
# is later called with an iteration index into `policy_array`.
plot_policy_step_grid_map = make_plot_policy_step_function(
    policy_over_time=policy_array,
    columns=10,
    rows=4,
)
# Build the plotting callback for the highway grid (10 columns x 4 rows); it
# is later called with an iteration index into `policy_array`.
plot_policy_step_grid_map = make_plot_policy_step_function(
    policy_over_time=policy_array,
    columns=10,
    rows=4,
)
In [14]:
Copied!
# Reuses the `interactive_widgets` flag computed in an earlier cell.
# Index of the last stored checkpoint. Derived from the data instead of being
# hard-coded, so it stays valid if the number of training iterations changes.
last_checkpoint = len(model_checkpoints) - 1

if interactive_widgets:
    # Imported only on the interactive path.
    # NOTE(review): presumably so CI does not need ipywidgets installed — confirm.
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=last_checkpoint, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # Non-interactive run: plot the final checkpoint only.
    plot_policy_step_grid_map(last_checkpoint)
# Reuses the `interactive_widgets` flag computed in an earlier cell.
# Index of the last stored checkpoint. Derived from the data instead of being
# hard-coded, so it stays valid if the number of training iterations changes.
last_checkpoint = len(model_checkpoints) - 1

if interactive_widgets:
    # Imported only on the interactive path.
    # NOTE(review): presumably so CI does not need ipywidgets installed — confirm.
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=last_checkpoint, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)
else:
    # Non-interactive run: plot the final checkpoint only.
    plot_policy_step_grid_map(last_checkpoint)