policy
This module contains the CategoricalPolicy implementation.
CategoricalPolicy
A categorical policy parameterized by a neural network.
Source code in src/behavior_generation_lecture_python/mdp/policy.py
__init__(sizes, actions, seed=None)
Initialize the categorical policy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `sizes` | `List[int]` | List of layer sizes for the MLP. | required |
| `actions` | `List[Any]` | List of available actions. | required |
| `seed` | `Optional[int]` | Random seed for reproducibility (default: None). | `None` |
get_action(state, deterministic=False)
Returns an action for the given state: a sample from the policy by default, or the most likely action if deterministic is True.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state` | `Tensor` | The current state tensor. | required |
| `deterministic` | `bool` | If True, return the most likely action. | `False` |
Returns:
| Type | Description |
|---|---|
| `Any` | The selected action. |
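The difference between sampling and the deterministic choice can be illustrated with a minimal plain-Python sketch. It does not reproduce the module's PyTorch code: `softmax`, `select_action`, and the example logits are hypothetical names and values, and the sketch assumes the network outputs one raw score (logit) per action.

```python
import math
import random


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits, actions, deterministic=False, rng=random):
    """Pick an action from hypothetical per-action scores (logits)."""
    probs = softmax(logits)
    if deterministic:
        # Greedy choice: the action with the highest probability.
        index = max(range(len(probs)), key=probs.__getitem__)
    else:
        # Stochastic choice: sample an index in proportion to its probability.
        index = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return actions[index]


actions = ["left", "stay", "right"]
logits = [0.1, 2.0, -1.0]  # made-up scores standing in for a network output

greedy = select_action(logits, actions, deterministic=True)
sampled = select_action(logits, actions, rng=random.Random(0))
```

Here `greedy` is always the highest-scoring action, while `sampled` can be any of the three actions, with frequencies that follow the softmax probabilities.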
get_log_prob(states, actions)
Returns the log-probabilities of the given actions in the corresponding states.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `states` | `Tensor` | Batch of state tensors. | required |
| `actions` | `Tensor` | Batch of action tensors. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Log-probabilities of the actions. |
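As a rough illustration of what a batched log-probability computes, the following plain-Python sketch applies a log-softmax to each row of scores and picks out the entry of the action that was taken. It is not the module's PyTorch code: `log_softmax`, `log_prob_batch`, and the sample values are hypothetical, and actions are assumed to be given as integer indices.

```python
import math


def log_softmax(logits):
    """Log of the softmax probabilities, computed in a numerically stable way."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_norm for x in logits]


def log_prob_batch(batch_logits, action_indices):
    """Log-probability of each chosen action index under its row of logits."""
    return [log_softmax(row)[a] for row, a in zip(batch_logits, action_indices)]


# Two hypothetical states with two actions each.
batch_logits = [[0.0, 0.0], [1.0, 3.0]]
action_indices = [0, 1]

log_probs = log_prob_batch(batch_logits, action_indices)
```

The first entry equals `log(0.5)` because both actions have the same score. Every log-probability is at most zero.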
multi_layer_perceptron(sizes, activation=nn.ReLU, output_activation=nn.Identity)
Returns a multi-layer perceptron with the given layer sizes.
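The function body is not shown here, so the sketch below only illustrates a common convention for helpers with this signature. It assumes that consecutive entries of `sizes` form the input and output widths of each layer, that `activation` follows every hidden layer, and that `output_activation` follows the last one. The name `layer_plan` is hypothetical, and plain strings stand in for the `nn` activation classes.

```python
def layer_plan(sizes, activation="ReLU", output_activation="Identity"):
    """List (in_features, out_features, activation) for each layer (assumed convention)."""
    plan = []
    for i in range(len(sizes) - 1):
        is_last = i == len(sizes) - 2
        # Hidden layers get `activation`; the final layer gets `output_activation`.
        act = output_activation if is_last else activation
        plan.append((sizes[i], sizes[i + 1], act))
    return plan


# Hypothetical network: 4 state features, two hidden layers of 32 units, 3 actions.
plan = layer_plan([4, 32, 32, 3])
```

Under this assumption, a list of n sizes yields n - 1 layers, and for a policy the last entry of `sizes` would match the number of available actions.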