first try
Алексей Земцов authored and committed Aug 20, 2023
0 parents commit 75f0aa7
Showing 11 changed files with 608 additions and 0 deletions.
27 changes: 27 additions & 0 deletions .gitignore
@@ -0,0 +1,27 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

weights

runs
video
wandb

# Jupyter Notebook
.ipynb_checkpoints

workshop.ipynb

# pyenv
.python-version

# vscode
.vscode

# mac
.DS_Store

train.ipynb
json_datasets
3 changes: 3 additions & 0 deletions README.md
@@ -0,0 +1,3 @@
# RORL Implementation

Hi there. This is my [RORL](https://arxiv.org/abs/2206.02829) implementation (admittedly not an ideal one) in PyTorch. Feel free to tune the hyperparameters to achieve the best results, and contact me about any mistakes you find :)
39 changes: 39 additions & 0 deletions config.py
@@ -0,0 +1,39 @@
import torch
from typing import Optional
from dataclasses import dataclass


@dataclass
class rorl_config:
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_name: str = "walker2d-medium-v2"
    seed: int = 42
    eval_seed: int = 0
    eval_freq: int = int(1e3)
    num_episodes: int = 10
    max_timesteps: int = int(3e5)

    max_action: float = 1.0

    # SAC-N
    num_critics: int = 20
    buffer_size: int = 1_000_000  # Replay buffer size
    actor_learning_rate: float = 3e-4
    critic_learning_rate: float = 3e-4
    alpha_learning_rate: float = 3e-4
    hidden_dim: int = 256
    batch_size: int = 256  # Batch size for all networks
    discount: float = 0.99  # Discount factor
    tau: float = 0.005  # Target network update rate

    # RORL
    epsilon: float = 0.01
    tau_rorl: float = 0.2
    beta_smooth: float = 1e-4
    beta_ood: float = 0.1
    beta_divergence: float = 1.0

    # Wandb logging
    project: str = "RORL"
    group: str = dataset_name
    name: str = dataset_name + "_" + str(seed)
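One gotcha worth noting (a usage sketch, not part of this commit): the `group` and `name` defaults reference the class-level values of `dataset_name` and `seed`, so they are computed once at class definition. Overriding `dataset_name` or `seed` in the constructor does not regenerate them; pass them explicitly. Assuming the repository root is on PYTHONPATH:

from config import rorl_config

# defaults were baked in at class-definition time
cfg = rorl_config(dataset_name="hopper-medium-v2", seed=7)
print(cfg.group, cfg.name)  # still "walker2d-medium-v2", "walker2d-medium-v2_42"

# override the wandb fields explicitly when the dataset or seed changes
cfg = rorl_config(
    dataset_name="hopper-medium-v2",
    seed=7,
    group="hopper-medium-v2",
    name="hopper-medium-v2_7",
)
print(cfg.device, cfg.num_critics, cfg.beta_ood)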
164 changes: 164 additions & 0 deletions dataset.py
@@ -0,0 +1,164 @@
import torch
import numpy as np
from typing import List, Tuple
import os


class ReplayBuffer:
    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 buffer_size: int = 1000000) -> None:

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.buffer_size = buffer_size
        self.pointer = 0
        self.size = 0

        device = "cpu"
        self.device = device

        self.states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
        self.actions = torch.zeros((buffer_size, action_dim), dtype=torch.float32, device=device)
        self.rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self.next_states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
        self.dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)

        # i/o order: state, action, reward, next_state, done

    def from_json(self, json_file: str):
        import json

        if not json_file.endswith('.json'):
            json_file = json_file + '.json'

        json_file = os.path.join("json_datasets", json_file)
        output = dict()

        with open(json_file) as f:
            dataset = json.load(f)

        for k, v in dataset.items():
            v = np.array(v)
            if k != "terminals":
                v = v.astype(np.float32)

            output[k] = v

        self.from_d4rl(output)

    def get_moments(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        state_mean, state_std = self.states.mean(dim=0), self.states.std(dim=0)
        action_mean, action_std = self.actions.mean(dim=0), self.actions.std(dim=0)

        return (state_mean, state_std), (action_mean, action_std)

    @staticmethod
    def to_tensor(data: np.ndarray, device=None) -> torch.Tensor:
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        return torch.tensor(data, dtype=torch.float32, device=device)

    def sample(self, batch_size: int):
        indexes = np.random.randint(0, self.size, size=batch_size)

        return (
            self.states[indexes],
            self.actions[indexes],
            self.rewards[indexes],
            self.next_states[indexes],
            self.dones[indexes]
        )

    def from_d4rl(self, dataset):
        if self.size:
            print("Warning: loading data into non-empty buffer")
        n_transitions = dataset["observations"].shape[0]

        if n_transitions < self.buffer_size:
            self.states[:n_transitions] = self.to_tensor(dataset["observations"][-n_transitions:], self.device)
            self.actions[:n_transitions] = self.to_tensor(dataset["actions"][-n_transitions:], self.device)
            self.next_states[:n_transitions] = self.to_tensor(dataset["next_observations"][-n_transitions:], self.device)
            self.rewards[:n_transitions] = self.to_tensor(dataset["rewards"][-n_transitions:].reshape(-1, 1), self.device)
            self.dones[:n_transitions] = self.to_tensor(dataset["terminals"][-n_transitions:].reshape(-1, 1), self.device)

        else:
            self.buffer_size = n_transitions

            self.states = self.to_tensor(dataset["observations"][-n_transitions:], self.device)
            self.actions = self.to_tensor(dataset["actions"][-n_transitions:], self.device)
            self.next_states = self.to_tensor(dataset["next_observations"][-n_transitions:], self.device)
            self.rewards = self.to_tensor(dataset["rewards"][-n_transitions:].reshape(-1, 1), self.device)
            self.dones = self.to_tensor(dataset["terminals"][-n_transitions:].reshape(-1, 1), self.device)

        self.size = n_transitions
        self.pointer = n_transitions % self.buffer_size

    def from_d4rl_finetune(self, dataset):
        raise NotImplementedError()

    def normalize_states(self, eps=1e-3):
        mean = self.states.mean(0, keepdim=True)
        std = self.states.std(0, keepdim=True) + eps
        self.states = (self.states - mean) / std
        self.next_states = (self.next_states - mean) / std
        return mean, std

    def clip(self, eps=1e-5):
        self.actions = torch.clip(self.actions, -1 + eps, 1 - eps)

    def add_transition(self,
                       state: torch.Tensor,
                       action: torch.Tensor,
                       reward: torch.Tensor,
                       next_state: torch.Tensor,
                       done: torch.Tensor):
        if not isinstance(state, torch.Tensor):
            state = self.to_tensor(state, self.device)
            action = self.to_tensor(action, self.device)
            reward = self.to_tensor(reward, self.device)
            next_state = self.to_tensor(next_state, self.device)
            done = self.to_tensor(done, self.device)

        self.states[self.pointer] = state
        self.actions[self.pointer] = action
        self.rewards[self.pointer] = reward
        self.next_states[self.pointer] = next_state
        self.dones[self.pointer] = done

        self.pointer = (self.pointer + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def add_batch(self,
                  states: List[torch.Tensor],
                  actions: List[torch.Tensor],
                  rewards: List[torch.Tensor],
                  next_states: List[torch.Tensor],
                  dones: List[torch.Tensor]):
        for state, action, reward, next_state, done in zip(states, actions, rewards, next_states, dones):
            self.add_transition(state, action, reward, next_state, done)

    @staticmethod
    def dataset_stats(dataset):
        episode_returns = []
        returns = 0
        episode_length = 0

        for reward, done in zip(dataset["rewards"], dataset["terminals"]):
            # count the current step's reward before checking for episode end,
            # so the reward of terminal transitions is not dropped
            returns += reward
            episode_length += 1

            if done or episode_length == 1000:
                episode_returns.append(returns)
                returns = 0
                episode_length = 0

        episode_returns = np.array(episode_returns)
        return episode_returns.mean(), episode_returns.std()
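A minimal sketch of how ReplayBuffer is meant to be fed and sampled (not part of this commit). The json_datasets files are gitignored, so a tiny synthetic D4RL-style dict with the keys from_d4rl expects stands in here, and the buffer is sized to the data so the statistics are computed only over real transitions:

import numpy as np
from dataset import ReplayBuffer

state_dim, action_dim, n = 17, 6, 1000  # walker2d-like dimensions
fake_dataset = {
    "observations": np.random.randn(n, state_dim).astype(np.float32),
    "actions": np.random.uniform(-1, 1, (n, action_dim)).astype(np.float32),
    "next_observations": np.random.randn(n, state_dim).astype(np.float32),
    "rewards": np.random.randn(n).astype(np.float32),
    "terminals": np.zeros(n, dtype=bool),
}

buffer = ReplayBuffer(state_dim, action_dim, buffer_size=n)  # sized to the dataset
buffer.from_d4rl(fake_dataset)
buffer.normalize_states()
buffer.clip()

states, actions, rewards, next_states, dones = buffer.sample(batch_size=256)
print(states.shape, actions.shape)  # torch.Size([256, 17]) torch.Size([256, 6])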
143 changes: 143 additions & 0 deletions modules.py
@@ -0,0 +1,143 @@
from math import sqrt
from typing import Optional, Tuple
import torch
from torch import nn
from torch.distributions import Normal
import numpy as np


class Actor(nn.Module):
    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_dim: int = 256,
                 edac_init: bool = False,
                 max_action: float = 1.0) -> None:
        super().__init__()
        self.action_dim = action_dim
        self.max_action = max_action

        self.trunk = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        self.mu = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

        if edac_init:
            # init as in the EDAC paper
            for layer in self.trunk[::2]:
                nn.init.constant_(layer.bias, 0.1)

            nn.init.uniform_(self.mu.weight, -1e-3, 1e-3)
            nn.init.uniform_(self.mu.bias, -1e-3, 1e-3)
            nn.init.uniform_(self.log_std.weight, -1e-3, 1e-3)
            nn.init.uniform_(self.log_std.bias, -1e-3, 1e-3)

    def forward(self,
                state: torch.Tensor,
                deterministic: bool = False,
                need_log_prob: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden = self.trunk(state)
        mu, log_std = self.mu(hidden), self.log_std(hidden)

        log_std = torch.clip(log_std, -20, 2)  # EDAC clipping would be torch.clip(log_std, -5, 2)
        policy_distribution = Normal(mu, torch.exp(log_std))

        if deterministic:
            action = mu
        else:
            action = policy_distribution.rsample()

        tanh_action, log_prob = torch.tanh(action), None
        if need_log_prob:
            log_prob = policy_distribution.log_prob(action).sum(-1)
            log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum(-1)
            # shape [batch_size,]

        return tanh_action * self.max_action, log_prob

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str) -> np.ndarray:
        deterministic = not self.training
        state = torch.tensor(state, device=device, dtype=torch.float32)
        action = self(state, deterministic=deterministic)[0].cpu().numpy()
        return action


class EnsembledLinear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 ensemble_size: int) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        for layer in range(self.ensemble_size):
            nn.init.kaiming_uniform_(self.weight[layer], a=sqrt(5))

        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 0
        if fan_in > 0:
            bound = 1 / sqrt(fan_in)

        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x @ self.weight + self.bias
        return out


class EnsembledCritic(nn.Module):
    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_dim: int = 256,
                 num_critics: int = 2,
                 layer_norm: bool = True,
                 edac_init: bool = True) -> None:
        super().__init__()

        # block = nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity()
        self.num_critics = num_critics

        self.critic = nn.Sequential(
            EnsembledLinear(state_dim + action_dim, hidden_dim, num_critics),
            nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity(),
            nn.ReLU(),
            EnsembledLinear(hidden_dim, hidden_dim, num_critics),
            nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity(),
            nn.ReLU(),
            EnsembledLinear(hidden_dim, hidden_dim, num_critics),
            nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity(),
            nn.ReLU(),
            EnsembledLinear(hidden_dim, 1, num_critics)
        )

        if edac_init:
            # init as in the EDAC paper
            for layer in self.critic[::3]:
                nn.init.constant_(layer.bias, 0.1)

            nn.init.uniform_(self.critic[-1].weight, -3e-3, 3e-3)
            nn.init.uniform_(self.critic[-1].bias, -3e-3, 3e-3)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        concat = torch.cat([state, action], dim=-1)
        concat = concat.unsqueeze(0)
        concat = concat.repeat_interleave(self.num_critics, dim=0)
        q_values = self.critic(concat).squeeze(-1)
        return q_values
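A quick shape check for the two modules (a sketch, not part of this commit), using walker2d-like dimensions and the num_critics and hidden_dim defaults from config.py. The actor returns a tanh-squashed action scaled by max_action plus per-sample log-probabilities; the ensembled critic returns one Q-value per ensemble member:

import torch
from modules import Actor, EnsembledCritic

state_dim, action_dim, batch_size = 17, 6, 256
actor = Actor(state_dim, action_dim, hidden_dim=256, edac_init=True)
critic = EnsembledCritic(state_dim, action_dim, hidden_dim=256, num_critics=20)

state = torch.randn(batch_size, state_dim)
action, log_prob = actor(state, need_log_prob=True)  # squashed to [-max_action, max_action]
q_values = critic(state, action)  # stacked along the ensemble dimension

print(action.shape)    # torch.Size([256, 6])
print(log_prob.shape)  # torch.Size([256])
print(q_values.shape)  # torch.Size([20, 256])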
Binary file added paper/algorithm.jpeg
Binary file added paper/eq_1-2.jpeg
Binary file added paper/eq_3-4.jpeg
Binary file added paper/eq_5-6.jpeg