policy

This module contains the CategorialPolicy implementation.

CategorialPolicy

Source code in src/behavior_generation_lecture_python/mdp/policy.py
class CategorialPolicy:
    def __init__(self, sizes: List[int], actions: List):
        assert sizes[-1] == len(actions)
        torch.manual_seed(1337)  # fixed seed for reproducible weight initialization
        self.net = multi_layer_perceptron(sizes=sizes)
        self.actions = actions
        # Store the available actions as rows of a 2D long tensor, so that
        # action indices can be recovered by comparison later on.
        self._actions_tensor = torch.tensor(actions, dtype=torch.long).view(
            len(actions), -1
        )

    def _get_distribution(self, state: torch.Tensor):
        """Calls the model and returns a categorial distribution over the actions."""
        logits = self.net(state)
        return Categorical(logits=logits)

    def get_action(self, state: torch.Tensor, deterministic: bool = False):
        """Returns an action sample for the given state"""
        policy = self._get_distribution(state)
        if deterministic:
            return self.actions[policy.mode.item()]
        return self.actions[policy.sample().item()]

    def get_log_prob(self, states: torch.Tensor, actions: torch.Tensor):
        """Returns the log-probability for taking the action, when being the given state"""
        return self._get_distribution(states).log_prob(
            self._get_action_id_from_action(actions)
        )

    def _get_action_id_from_action(self, actions: torch.Tensor):
        """Returns the indices of the passed actions in self.actions."""
        # Broadcast both tensors to shape (batch, num_actions, action_dim)
        # so that every passed action is compared against every known action.
        reshaped_actions = actions.unsqueeze(1).expand(
            -1, self._actions_tensor.size(0), -1
        )
        reshaped_actions_tensor = self._actions_tensor.unsqueeze(0).expand(
            actions.size(0), -1, -1
        )
        # An action matches where all of its components agree; the second
        # index returned by torch.where is the position within self.actions.
        return torch.where(
            torch.all(reshaped_actions == reshaped_actions_tensor, dim=-1)
        )[1]

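A minimal usage sketch (the state dimension of 4, the network sizes, and the three scalar actions below are illustrative assumptions, not part of the module):

import torch

# Hypothetical toy setup: 4-dimensional states, three scalar actions.
policy = CategorialPolicy(sizes=[4, 32, 3], actions=[-1, 0, 1])
state = torch.zeros(4)
action = policy.get_action(state)  # sampled from the categorical distribution
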
get_action(state, deterministic=False)

Returns an action for the given state, sampled from the policy unless deterministic is True.

Source code in src/behavior_generation_lecture_python/mdp/policy.py
def get_action(self, state: torch.Tensor, deterministic: bool = False):
    """Returns an action sample for the given state"""
    policy = self._get_distribution(state)
    if deterministic:
        return self.actions[policy.mode.item()]
    return self.actions[policy.sample().item()]

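To illustrate the deterministic flag, a short hedged example reusing the toy policy from the sketch above: repeated stochastic calls may return different actions, while the deterministic call always returns the distribution's mode.

state = torch.zeros(4)
samples = [policy.get_action(state) for _ in range(5)]  # may vary between calls
greedy = policy.get_action(state, deterministic=True)   # always the most probable action
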
get_log_prob(states, actions)

Returns the log-probability of taking the given actions in the given states.

Source code in src/behavior_generation_lecture_python/mdp/policy.py
def get_log_prob(self, states: torch.Tensor, actions: torch.Tensor):
    """Returns the log-probability for taking the action, when being the given state"""
    return self._get_distribution(states).log_prob(
        self._get_action_id_from_action(actions)
    )

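A hedged sketch of how this method is typically used, e.g. as the building block of a policy-gradient loss, reusing the toy policy from the first sketch; the batch contents and returns below are placeholder assumptions. Note that the actions tensor needs shape (batch size, action dimension), since the indices are recovered by component-wise comparison against self.actions.

states = torch.zeros(5, 4)                           # batch of 5 toy states
actions = torch.tensor([[-1], [0], [1], [0], [-1]])  # taken actions, shape (5, 1)
returns = torch.ones(5)                              # placeholder returns
log_probs = policy.get_log_prob(states, actions)     # shape (5,)
loss = -(log_probs * returns).mean()                 # REINFORCE-style objective
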
multi_layer_perceptron(sizes, activation=nn.ReLU, output_activation=nn.Identity)

Returns a multi-layer perceptron.

Source code in src/behavior_generation_lecture_python/mdp/policy.py
def multi_layer_perceptron(
    sizes: List[int],
    activation: Type[nn.Module] = nn.ReLU,
    output_activation: Type[nn.Module] = nn.Identity,
):
    """Returns a multi-layer perceptron"""
    mlp = nn.Sequential()
    for i in range(len(sizes) - 1):
        mlp.append(nn.Linear(sizes[i], sizes[i + 1]))
        if i < len(sizes) - 2:
            mlp.append(activation())
        else:
            mlp.append(output_activation())
    return mlp
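
A quick sketch of what this helper builds for an assumed sizes list: each pair of consecutive sizes becomes a linear layer, with the activation between hidden layers and the output activation after the last layer.

mlp = multi_layer_perceptron(sizes=[4, 32, 3])
print(mlp)
# Sequential(
#   (0): Linear(in_features=4, out_features=32, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=32, out_features=3, bias=True)
#   (3): Identity()
# )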