-
Notifications
You must be signed in to change notification settings - Fork 0
/
modules.py
114 lines (93 loc) · 3.76 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
from math import sqrt
from typing import Tuple, Optional
import torch
from torch import nn
from torch.distributions import Normal
class DeterministicActor(nn.Module):
    """Deterministic policy: an MLP whose output is squashed by ``tanh`` and scaled.

    The network has three ReLU hidden layers of width ``hidden_dim`` followed by a
    linear output head of width ``action_dim``. ``forward`` returns
    ``max_action * tanh(head_output)``, so every action component lies in
    ``[-max_action, max_action]``.

    Args:
        state_dim: Size of the last dimension of the state input.
        action_dim: Number of action components produced.
        hidden_dim: Width of each hidden layer.
        edac_init: If True, initialise as in the EDAC paper: hidden-layer biases are
            set to 0.1 and the output head (weight and bias) is drawn from
            U(-1e-3, 1e-3). If False, PyTorch's default ``nn.Linear`` init is kept.
        max_action: Scale applied after ``tanh``.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_dim: int = 256,
                 edac_init: bool = False,
                 max_action: float = 1.0) -> None:
        super().__init__()
        self.action_dim = action_dim
        self.max_action = max_action
        self.trunk = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )
        if edac_init:
            # Init as in the EDAC paper. The trunk alternates Linear/ReLU, so a step
            # of 2 selects the Linear layers; stopping before the last element
            # excludes the output head, which must not get the 0.1 bias.
            for layer in self.trunk[:-1:2]:
                nn.init.constant_(layer.bias, 0.1)
            # Small uniform init for the head keeps initial pre-tanh outputs small
            # instead of offsetting every action dimension by a 0.1 bias.
            head = self.trunk[-1]
            nn.init.uniform_(head.weight, -1e-3, 1e-3)
            nn.init.uniform_(head.bias, -1e-3, 1e-3)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return ``max_action * tanh(trunk(state))``.

        ``state``'s last dimension must equal ``state_dim``; leading dimensions
        (if any) are preserved and the last dimension becomes ``action_dim``.
        """
        out = self.trunk(state)
        out = torch.tanh(out)
        return self.max_action * out

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str) -> np.ndarray:
        """Compute the action for ``state`` without tracking gradients.

        ``state`` is converted to a float32 tensor on ``device``. The resulting
        action is moved back to the CPU and returned as a NumPy array.
        """
        # as_tensor avoids a redundant copy (and the warning torch.tensor emits
        # when handed an existing tensor); the input is never mutated here.
        state = torch.as_tensor(state, device=device, dtype=torch.float32)
        return self(state).cpu().numpy()
class EnsembledLinear(nn.Module):
    """A stack of ``ensemble_size`` independent linear layers applied in one matmul.

    ``weight`` has shape ``(ensemble_size, in_features, out_features)`` and ``bias``
    has shape ``(ensemble_size, 1, out_features)``. ``forward`` computes
    ``x @ weight + bias``, so ensemble member ``i`` uses ``weight[i]`` and ``bias[i]``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        ensemble_size: Number of independent layers in the stack.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 ensemble_size: int) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise every member with ``nn.Linear``'s default distribution.

        ``nn.Linear`` draws weights with ``kaiming_uniform_(a=sqrt(5))``, whose bound
        reduces to ``1/sqrt(fan_in)``, and biases from U(-1/sqrt(fan_in),
        1/sqrt(fan_in)), where ``fan_in`` is the number of input features.
        """
        # The weight is stored as (in, out), transposed relative to nn.Linear's
        # (out, in). torch.nn.init's fan-in helpers read dim 1 as fan-in and would
        # pick out_features here (e.g. U(-1, 1) for a 1-unit head), so the bound is
        # computed from in_features explicitly.
        bound = 1 / sqrt(self.in_features) if self.in_features > 0 else 0.0
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each member's affine map to its slice of ``x``.

        Presumably ``x`` is ``(ensemble_size, batch, in_features)``, giving
        ``(ensemble_size, batch, out_features)``. A 2-D ``x`` of shape
        ``(batch, in_features)`` also works via ``torch.matmul`` batch
        broadcasting and is fed to every member.
        """
        out = x @ self.weight + self.bias
        return out
class EnsembledCritic(nn.Module):
    """An ensemble of ``num_critics`` Q-networks evaluated with batched matmuls.

    Each member maps ``state_dim + action_dim`` -> ``hidden_dim`` -> ``hidden_dim`` -> 1
    through ``EnsembledLinear`` layers. Each hidden linear layer is followed by an
    optional ``nn.LayerNorm`` and then a ReLU.

    NOTE(review): each ``nn.LayerNorm`` is one module applied to the whole stacked
    ``(num_critics, ...)`` tensor, so its affine parameters are shared by all
    ensemble members. Confirm that this coupling is intended.

    Args:
        state_dim: Size of the last dimension of ``state``.
        action_dim: Size of the last dimension of ``action``.
        hidden_dim: Width of the two hidden layers.
        num_critics: Number of ensemble members.
        layer_norm: If True, insert ``nn.LayerNorm(hidden_dim)`` after each hidden
            linear layer; otherwise use ``nn.Identity()`` in that slot.
        edac_init: If True, initialise as in the EDAC paper: hidden-layer biases are
            set to 0.1 and the output layer (weight and bias) is drawn from
            U(-3e-3, 3e-3).
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_dim: int = 256,
                 num_critics: int = 2,
                 layer_norm: bool = True,
                 edac_init: bool = True) -> None:
        super().__init__()
        self.num_critics = num_critics

        def make_norm() -> nn.Module:
            # A fresh module per call, so the two normalisation slots never share
            # a module with each other.
            return nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity()

        layers = [
            EnsembledLinear(state_dim + action_dim, hidden_dim, num_critics),
            make_norm(),
            nn.ReLU(),
            EnsembledLinear(hidden_dim, hidden_dim, num_critics),
            make_norm(),
            nn.ReLU(),
            EnsembledLinear(hidden_dim, 1, num_critics),
        ]
        self.critic = nn.Sequential(*layers)

        if edac_init:
            # Hidden EnsembledLinear layers sit at indices 0 and 3 (every third
            # module before the head); their biases start at 0.1.
            for hidden in self.critic[:-1:3]:
                nn.init.constant_(hidden.bias, 0.1)
            # The output head gets a small uniform init for weight and bias.
            head = self.critic[-1]
            nn.init.uniform_(head.weight, -3e-3, 3e-3)
            nn.init.uniform_(head.bias, -3e-3, 3e-3)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Evaluate every critic on the same ``(state, action)`` input.

        ``state`` and ``action`` are concatenated along their last dimension. The
        result is copied once per ensemble member along a new leading dimension,
        so for 2-D inputs ``(batch, ·)`` the output is ``(num_critics, batch)``.
        """
        state_action = torch.cat((state, action), dim=-1)
        stacked = state_action.unsqueeze(0).repeat_interleave(self.num_critics, dim=0)
        # The head emits one value per sample; drop that trailing size-1 dim.
        return self.critic(stacked).squeeze(-1)