Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
zzmtsvv authored Sep 15, 2023
1 parent c4bef12 commit 15ac23b
Show file tree
Hide file tree
Showing 4 changed files with 555 additions and 0 deletions.
48 changes: 48 additions & 0 deletions td7/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from dataclasses import dataclass
import torch


@dataclass
class td7_config:
    """Hyperparameters for the TD7 agent and its LAP replay buffer."""
    # Experiment
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_name: str = "halfcheetah-medium-v2"
    seed: int = 42

    # total number of environment/training steps
    max_timesteps: int = int(1e6)

    # presumably sized for halfcheetah (obs dim 17, action dim 6) — verify per dataset
    state_dim: int = 17
    action_dim: int = 6

    batch_size: int = 256
    buffer_size: int = int(1e6)
    discount: float = 0.99
    # hard target-network swap period (in gradient steps)
    target_update_freq: int = 250
    exploration_noise: float = 0.1

    # TD3-style target policy smoothing
    policy_noise: float = 0.2
    noise_clip: float = 0.5
    # delayed policy updates: one actor step per `policy_freq` critic steps
    policy_freq: int = 2

    max_action: float = 1.0

    # LAP prioritization: exponent and floor for sampled priorities
    alpha: float = 0.4
    min_priority: float = 1.0

    lambda_coef: float = 0.1

    # encoder / actor / critic architecture and optimizer settings
    embedding_dim: int = 256
    hidden_dim: int = 256
    encoder_lr: float = 3e-4
    encoder_activation: str = "elu"
    actor_lr: float = 3e-4
    actor_activation: str = "relu"
    critic_lr: float = 3e-4
    critic_activation: str = "elu"

    normalize_actions: bool = True
    priority_buffer: bool = True

    # wandb-style experiment bookkeeping (evaluated once at class creation)
    project: str = "TD7"
    group: str = dataset_name
    name: str = dataset_name + "_" + str(seed)
154 changes: 154 additions & 0 deletions td7/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from math import sqrt
import torch
from torch import nn


class AvgL1Norm(nn.Module):
    """Scale each row so its mean absolute value is 1.

    Divides the input by the mean of |x| along the last dimension; the
    denominator is clamped from below by ``eps`` to avoid division by zero.
    (The name follows the TD7 paper.)
    """

    def forward(self, x: torch.Tensor, eps: float = 1e-8):
        denom = x.abs().mean(dim=-1, keepdim=True)
        return x / torch.clamp(denom, min=eps)


class TD7Encoder(nn.Module):
    """State / state-action encoder pair from TD7.

    ``f`` embeds a state into zs (finishing with AvgL1Norm), and ``g``
    embeds a (zs, action) pair into zsa.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 embedding_dim: int = 256,
                 hidden_dim: int = 256,
                 activation: nn.Module = nn.ELU) -> None:
        super().__init__()

        # state encoder f: state -> zs
        f_modules = [
            nn.Linear(state_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, embedding_dim),
            AvgL1Norm(),
        ]
        self.f_layers = nn.Sequential(*f_modules)

        # state-action encoder g: (zs, action) -> zsa
        g_modules = [
            nn.Linear(embedding_dim + action_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, embedding_dim),
        ]
        self.g_layers = nn.Sequential(*g_modules)

    def f(self, states: torch.Tensor) -> torch.Tensor:
        """Embed states into zs."""
        return self.f_layers(states)

    def g(self,
          embeddings: torch.Tensor,
          actions: torch.Tensor) -> torch.Tensor:
        """Embed (zs, action) pairs into zsa."""
        return self.g_layers(torch.cat([embeddings, actions], dim=-1))


class TD7Actor(nn.Module):
    """TD7 policy network.

    States are projected and AvgL1Norm-ed, concatenated with the state
    embedding zs, and mapped through an MLP ending in Tanh, so actions
    lie in (-1, 1).
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 embedding_dim: int = 256,
                 hidden_dim: int = 256,
                 activation: nn.Module = nn.ReLU) -> None:
        super().__init__()

        # state branch: linear projection followed by AvgL1Norm
        self.state_layers = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            AvgL1Norm()
        )
        # fused (state features, zs) -> action head; Tanh bounds the output
        self.layers = nn.Sequential(
            nn.Linear(embedding_dim + hidden_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, hidden_dim),
            activation(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh()
        )

    def forward(self,
                states: torch.Tensor,
                embeddings: torch.Tensor) -> torch.Tensor:
        """Return actions in (-1, 1) for `states` and their embeddings zs."""
        out = self.state_layers(states)
        out = torch.cat([out, embeddings], dim=-1)
        return self.layers(out)

    def sample(self,
               states: torch.Tensor,
               embeddings: torch.Tensor) -> torch.Tensor:
        """Deterministic action for (states, zs).

        Bug fix: forward() already ends with nn.Tanh(), so the previous
        extra torch.tanh squashed actions into (-tanh 1, tanh 1) ~ (-0.76, 0.76),
        making boundary actions unreachable.
        """
        return self.forward(states, embeddings)


class EnsembledLinear(nn.Module):
    """A stack of independent linear layers evaluated with one batched matmul.

    weight has shape (ensemble_size, in_features, out_features) and bias
    (ensemble_size, 1, out_features); a 2-D input (B, in) broadcasts against
    the weight, yielding an output of shape (ensemble_size, B, out).
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 ensemble_size: int) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        weight_shape = (ensemble_size, in_features, out_features)
        self.weight = nn.Parameter(torch.empty(*weight_shape))
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weights per member, uniform bias (mirrors nn.Linear)."""
        for member in self.weight:
            nn.init.kaiming_uniform_(member, a=sqrt(5))

        # all members share the same fan-in, so computing it once suffices
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / sqrt(fan_in) if fan_in > 0 else 0

        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every ensemble member; broadcasting adds the leading ensemble dim."""
        return torch.matmul(x, self.weight) + self.bias


class TD7Critic(nn.Module):
    """Ensemble of Q-networks evaluated jointly via EnsembledLinear.

    Each member receives the raw (state, action) features plus the two
    learned embeddings zsa and zs, and outputs one Q-value; the forward
    pass returns all members' outputs stacked along a leading dim.
    """
    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 embedding_dim: int = 256,
                 hidden_dim: int = 256,
                 num_critics: int = 2,
                 activation: nn.Module = nn.ELU) -> None:
        super().__init__()

        self.num_critics = num_critics

        # (state, action) branch; AvgL1Norm rescales each row's features
        self.state_action_layers = nn.Sequential(
            EnsembledLinear(state_dim + action_dim, hidden_dim, ensemble_size=num_critics),
            AvgL1Norm()
        )
        # head over [state-action features, zsa, zs] -> scalar Q per member
        self.layers = nn.Sequential(
            EnsembledLinear(2 * embedding_dim + hidden_dim, hidden_dim, ensemble_size=num_critics),
            activation(),
            EnsembledLinear(hidden_dim, hidden_dim, ensemble_size=num_critics),
            activation(),
            EnsembledLinear(hidden_dim, 1, ensemble_size=num_critics)
        )

    def forward(self,
                states: torch.Tensor,
                actions: torch.Tensor,
                zsa: torch.Tensor,
                zs: torch.Tensor) -> torch.Tensor:
        # assumes states/actions/zsa/zs are 2-D (batch, feat) — TODO confirm at call sites
        state_action = torch.cat([states, actions], dim=-1)
        # EnsembledLinear broadcasts (B, in) against its (C, in, out) weight,
        # so `out` gains a leading ensemble dim: (num_critics, B, hidden)
        out = self.state_action_layers(state_action)
        # repeat with one extra leading entry prepends a dim of size num_critics,
        # tiling zsa/zs to (num_critics, B, embedding) so they cat with `out`
        out = torch.cat([
            out,
            zsa.repeat([self.num_critics] + [1] * len(zsa.shape)),
            zs.repeat([self.num_critics] + [1] * len(zs.shape))
        ], dim=-1)
        out = self.layers(out)
        return out
154 changes: 154 additions & 0 deletions td7/replay_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import os
from typing import List, Dict, Tuple
import numpy as np
import torch


class LAP:
    '''
    Loss-Adjusted Prioritized Experience Replay
    https://arxiv.org/abs/2007.06049

    Ring buffer of transitions stored in preallocated float32 tensors on
    ``device``. When ``with_priority`` is set, ``sample`` draws transitions
    proportionally to their stored priority; otherwise uniformly.
    '''
    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 device: str = "cpu",
                 buffer_size: int = 1_000_000,
                 max_action: float = 1.0,
                 normalize_actions: bool = True,
                 with_priority: bool = True) -> None:

        self.buffer_size = buffer_size
        self.device = device

        # ring-buffer cursor and current fill level
        self.pointer = 0
        self.size = 0

        self.states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
        self.actions = torch.zeros((buffer_size, action_dim), dtype=torch.float32, device=device)
        self.rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self.next_states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
        self.dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)

        # NOTE(review): the misspelled names "with_priortiy"/"priortiy" are kept
        # on purpose — external code may already reference these attributes.
        self.with_priortiy = with_priority

        if with_priority:
            self.priortiy = torch.zeros(buffer_size, device=device)
            self.max_priority = 1.0

        # actions are divided by this factor on insertion
        self.normalizing_factor = max_action if normalize_actions else 1.0

    @staticmethod
    def to_tensor(data: np.ndarray, device=None) -> torch.Tensor:
        """Convert array-like data to a float32 tensor on ``device``
        (defaults to cuda when available, else cpu)."""
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        return torch.tensor(data, dtype=torch.float32, device=device)

    def add_transition(self,
                       state: torch.Tensor,
                       action: torch.Tensor,
                       reward: torch.Tensor,
                       next_state: torch.Tensor,
                       done: torch.Tensor):
        """Insert one transition, overwriting the oldest slot when full.

        New transitions get the current ``max_priority`` so they are
        sampled at least once before their priority is adjusted.
        """
        # assumes either all five inputs are tensors or none are — TODO confirm callers
        if not isinstance(state, torch.Tensor):
            state = self.to_tensor(state, self.device)
            action = self.to_tensor(action, self.device)
            reward = self.to_tensor(reward, self.device)
            next_state = self.to_tensor(next_state, self.device)
            done = self.to_tensor(done, self.device)

        self.states[self.pointer] = state
        self.actions[self.pointer] = action / self.normalizing_factor
        self.rewards[self.pointer] = reward
        self.next_states[self.pointer] = next_state
        self.dones[self.pointer] = done

        if self.with_priortiy:
            self.priortiy[self.pointer] = self.max_priority

        self.pointer = (self.pointer + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def add_batch(self,
                  states: List[torch.Tensor],
                  actions: List[torch.Tensor],
                  rewards: List[torch.Tensor],
                  next_states: List[torch.Tensor],
                  dones: List[torch.Tensor]):
        """Insert a sequence of transitions one by one."""
        for state, action, reward, next_state, done in zip(states, actions, rewards, next_states, dones):
            self.add_transition(state, action, reward, next_state, done)

    def sample(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` transitions.

        The sampled indexes are remembered in ``self.indexes`` so that a
        subsequent ``update_priority`` call can address exactly this batch.

        Returns: (states, actions, rewards, next_states, dones).
        """
        if self.with_priortiy:
            # inverse-CDF sampling proportional to the stored priorities
            cumsum = torch.cumsum(self.priortiy[:self.size], dim=0)
            value = torch.rand(size=(batch_size,), device=self.device) * cumsum[-1]
            self.indexes: np.ndarray = torch.searchsorted(cumsum, value).cpu().data.numpy()
        else:
            self.indexes = np.random.randint(0, self.size, size=batch_size)

        return (
            self.states[self.indexes],
            self.actions[self.indexes],
            self.rewards[self.indexes],
            self.next_states[self.indexes],
            self.dones[self.indexes]
        )

    def update_priority(self, priority: torch.Tensor):
        """Overwrite the priorities of the most recently sampled batch."""
        self.priortiy[self.indexes] = priority.reshape(-1).detach()
        self.max_priority = max(float(priority.max()), self.max_priority)

    def update_max_priority(self):
        """Recompute ``max_priority`` from the currently stored priorities."""
        self.max_priority = float(self.priortiy[:self.size].max())

    def from_d4rl(self, dataset: Dict[str, np.ndarray]):
        """Load a D4RL-style dict of arrays, keeping the newest transitions
        when the dataset exceeds the buffer capacity."""
        if self.size:
            print("Warning: loading data into non-empty buffer")
        n_transitions = dataset["observations"].shape[0]

        if n_transitions < self.buffer_size:
            self.states[:n_transitions] = self.to_tensor(dataset["observations"][-n_transitions:], self.device)
            self.actions[:n_transitions] = self.to_tensor(dataset["actions"][-n_transitions:], self.device)
            self.next_states[:n_transitions] = self.to_tensor(dataset["next_observations"][-n_transitions:], self.device)
            self.rewards[:n_transitions] = self.to_tensor(dataset["rewards"][-n_transitions:].reshape(-1, 1), self.device)
            self.dones[:n_transitions] = self.to_tensor(dataset["terminals"][-n_transitions:].reshape(-1, 1), self.device)

        else:
            # dataset is larger than capacity: grow the buffer to fit it exactly
            self.buffer_size = n_transitions

            self.states = self.to_tensor(dataset["observations"][-n_transitions:], self.device)
            # bug fix: the device argument was previously omitted here, so the
            # actions tensor could land on cuda while everything else honored self.device
            self.actions = self.to_tensor(dataset["actions"][-n_transitions:], self.device)
            self.next_states = self.to_tensor(dataset["next_observations"][-n_transitions:], self.device)
            self.rewards = self.to_tensor(dataset["rewards"][-n_transitions:].reshape(-1, 1), self.device)
            self.dones = self.to_tensor(dataset["terminals"][-n_transitions:].reshape(-1, 1), self.device)

        self.size = n_transitions
        self.pointer = n_transitions % self.buffer_size

        if self.with_priortiy:
            # loaded transitions start with uniform priority 1
            self.priortiy = torch.ones(self.size).to(self.device)

    def from_json(self, json_file: str):
        """Load a dataset stored as JSON under ``json_datasets/`` and feed it
        through ``from_d4rl``; non-terminal arrays are cast to float32."""
        import json

        if not json_file.endswith('.json'):
            json_file = json_file + '.json'

        json_file = os.path.join("json_datasets", json_file)
        output = dict()

        with open(json_file) as f:
            dataset = json.load(f)

        for k, v in dataset.items():
            v = np.array(v)
            if k != "terminals":
                v = v.astype(np.float32)

            output[k] = v

        self.from_d4rl(output)
Loading

0 comments on commit 15ac23b

Please sign in to comment.