Add files via upload
zzmtsvv authored Sep 12, 2023
1 parent 997a6ea commit 9b6ac8d
Showing 4 changed files with 444 additions and 0 deletions.
36 changes: 36 additions & 0 deletions offline_o3f/config.py
@@ -0,0 +1,36 @@
from dataclasses import dataclass
import torch


@dataclass
class o3f_config:
# Experiment
device: str = "cuda" if torch.cuda.is_available() else "cpu"
dataset_name: str = "halfcheetah-medium-v2"
seed: int = 42
max_timesteps: int = int(1e6)

    max_action: float = 1.0

action_dim: int = 6
state_dim: int = 17

buffer_size: int = 1_000_000
actor_lr: float = 3e-4
critic_lr: float = 3e-4
alpha_lr: float = 3e-4

hidden_dim: int = 256
batch_size: int = 256
discount: float = 0.99
tau: float = 0.005

critic_ln: bool = True
num_critics: int = 5
normalize: bool = True
standard_deviation: float = 0.2
num_action_candidates: int = 100

project: str = "offline_O3F"
group: str = dataset_name
name: str = dataset_name + "_" + str(seed)
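
A minimal usage sketch (assuming the module is importable as offline_o3f.config; the hopper override is a hypothetical example). Because group and name are plain dataclass defaults, they are evaluated once at class-definition time and keep the default dataset_name and seed even when those fields are overridden at construction:

from offline_o3f.config import o3f_config

config = o3f_config(dataset_name="hopper-medium-v2", seed=0)
print(config.name)  # still "halfcheetah-medium-v2_42": the default was frozen at class definition
config.name = f"{config.dataset_name}_{config.seed}"  # re-derive after overriding fields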
164 changes: 164 additions & 0 deletions offline_o3f/dataset.py
@@ -0,0 +1,164 @@
import torch
import numpy as np
from typing import List, Tuple
import os


class ReplayBuffer:
def __init__(self,
state_dim: int,
action_dim: int,
buffer_size: int = 1000000) -> None:

self.state_dim = state_dim
self.action_dim = action_dim
self.buffer_size = buffer_size
self.pointer = 0
self.size = 0

device = "cpu"
self.device = device

self.states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
self.actions = torch.zeros((buffer_size, action_dim), dtype=torch.float32, device=device)
self.rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
self.next_states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
self.dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)

# i/o order: state, action, reward, next_state, done

def from_json(self, json_file: str):
import json

if not json_file.endswith('.json'):
json_file = json_file + '.json'

json_file = os.path.join("json_datasets", json_file)
output = dict()

with open(json_file) as f:
dataset = json.load(f)

for k, v in dataset.items():
v = np.array(v)
if k != "terminals":
v = v.astype(np.float32)

output[k] = v

self.from_d4rl(output)

def get_moments(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
state_mean, state_std = self.states.mean(dim=0), self.states.std(dim=0)
action_mean, action_std = self.actions.mean(dim=0), self.actions.std(dim=0)

return (state_mean, state_std), (action_mean, action_std)

@staticmethod
def to_tensor(data: np.ndarray, device=None) -> torch.Tensor:
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

return torch.tensor(data, dtype=torch.float32, device=device)

def sample(self, batch_size: int):
indexes = np.random.randint(0, self.size, size=batch_size)

return (
self.states[indexes],
self.actions[indexes],
self.rewards[indexes],
self.next_states[indexes],
self.dones[indexes]
)

def from_d4rl(self, dataset):
if self.size:
print("Warning: loading data into non-empty buffer")
n_transitions = dataset["observations"].shape[0]

if n_transitions < self.buffer_size:
self.states[:n_transitions] = self.to_tensor(dataset["observations"][-n_transitions:], self.device)
self.actions[:n_transitions] = self.to_tensor(dataset["actions"][-n_transitions:], self.device)
self.next_states[:n_transitions] = self.to_tensor(dataset["next_observations"][-n_transitions:], self.device)
self.rewards[:n_transitions] = self.to_tensor(dataset["rewards"][-n_transitions:].reshape(-1, 1), self.device)
self.dones[:n_transitions] = self.to_tensor(dataset["terminals"][-n_transitions:].reshape(-1, 1), self.device)

else:
self.buffer_size = n_transitions

self.states = self.to_tensor(dataset["observations"][-n_transitions:], self.device)
self.actions = self.to_tensor(dataset["actions"][-n_transitions:])
self.next_states = self.to_tensor(dataset["next_observations"][-n_transitions:], self.device)
self.rewards = self.to_tensor(dataset["rewards"][-n_transitions:].reshape(-1, 1), self.device)
self.dones = self.to_tensor(dataset["terminals"][-n_transitions:].reshape(-1, 1), self.device)

self.size = n_transitions
self.pointer = n_transitions % self.buffer_size

def from_d4rl_finetune(self, dataset):
raise NotImplementedError()

def normalize_states(self, eps=1e-3):
mean = self.states.mean(0, keepdim=True)
std = self.states.std(0, keepdim=True) + eps
self.states = (self.states - mean) / std
self.next_states = (self.next_states - mean) / std
return mean, std

def clip(self, eps=1e-5):
        self.actions = torch.clip(self.actions, -1 + eps, 1 - eps)

def add_transition(self,
state: torch.Tensor,
action: torch.Tensor,
reward: torch.Tensor,
next_state: torch.Tensor,
done: torch.Tensor):
if not isinstance(state, torch.Tensor):
state = self.to_tensor(state, self.device)
action = self.to_tensor(action, self.device)
reward = self.to_tensor(reward, self.device)
next_state = self.to_tensor(next_state, self.device)
done = self.to_tensor(done, self.device)


self.states[self.pointer] = state
self.actions[self.pointer] = action
self.rewards[self.pointer] = reward
self.next_states[self.pointer] = next_state
self.dones[self.pointer] = done

self.pointer = (self.pointer + 1) % self.buffer_size
self.size = min(self.size + 1, self.buffer_size)

def add_batch(self,
states: List[torch.Tensor],
actions: List[torch.Tensor],
rewards: List[torch.Tensor],
next_states: List[torch.Tensor],
dones: List[torch.Tensor]):
for state, action, reward, next_state, done in zip(states, actions, rewards, next_states, dones):
self.add_transition(state, action, reward, next_state, done)

@staticmethod
def dataset_stats(dataset):
episode_returns = []
returns = 0
episode_length = 0

        for reward, done in zip(dataset["rewards"], dataset["terminals"]):
            returns += reward
            episode_length += 1
            # an episode ends either on a terminal flag or on the 1000-step timeout
            if done or episode_length == 1000:
                episode_returns.append(returns)
                returns = 0
                episode_length = 0

episode_returns = np.array(episode_returns)
return episode_returns.mean(), episode_returns.std()
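
A loading sketch for the buffer above (assumes the d4rl package and a MuJoCo-backed gym are installed; dimensions follow config.py). d4rl.qlearning_dataset returns exactly the observations/actions/next_observations/rewards/terminals keys that from_d4rl expects:

import gym
import d4rl  # registers the halfcheetah-medium-v2 dataset

env = gym.make("halfcheetah-medium-v2")
d4rl_dataset = d4rl.qlearning_dataset(env)

buffer = ReplayBuffer(state_dim=17, action_dim=6, buffer_size=1_000_000)
buffer.from_d4rl(d4rl_dataset)
state_mean, state_std = buffer.normalize_states()

states, actions, rewards, next_states, dones = buffer.sample(batch_size=256)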
194 changes: 194 additions & 0 deletions offline_o3f/modules.py
@@ -0,0 +1,194 @@
from math import sqrt
from typing import Optional, Tuple
import torch
from torch import nn
from torch.distributions import Normal
import numpy as np


class Actor(nn.Module):
def __init__(self,
state_dim: int,
action_dim: int,
hidden_dim: int = 256,
edac_init: bool = False,
max_action: float = 1.0) -> None:
super().__init__()
self.action_dim = action_dim
self.max_action = max_action

self.trunk = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)

self.mu = nn.Linear(hidden_dim, action_dim)
self.log_std = nn.Linear(hidden_dim, action_dim)

if edac_init:
# init as in the EDAC paper
for layer in self.trunk[::2]:
nn.init.constant_(layer.bias, 0.1)

nn.init.uniform_(self.mu.weight, -1e-3, 1e-3)
nn.init.uniform_(self.mu.bias, -1e-3, 1e-3)
nn.init.uniform_(self.log_std.weight, -1e-3, 1e-3)
nn.init.uniform_(self.log_std.bias, -1e-3, 1e-3)

def forward(self,
state: torch.Tensor,
deterministic: bool = False,
need_log_prob: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden = self.trunk(state)
mu, log_std = self.mu(hidden), self.log_std(hidden)

log_std = torch.clip(log_std, -20, 2) # log_std = torch.clip(log_std, -5, 2) EDAC clipping
policy_distribution = Normal(mu, torch.exp(log_std))

if deterministic:
action = mu
else:
action = policy_distribution.rsample()

tanh_action, log_prob = torch.tanh(action), None
if need_log_prob:
log_prob = policy_distribution.log_prob(action).sum(-1)
log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum(-1)
# shape [batch_size,]

return tanh_action * self.max_action, log_prob

@torch.no_grad()
def act(self, state: np.ndarray, device: str) -> np.ndarray:
deterministic = not self.training
state = torch.tensor(state, device=device, dtype=torch.float32)
action = self(state, deterministic=deterministic)[0].cpu().numpy()
return action
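
The log-prob branch in Actor.forward applies the change-of-variables correction for a tanh-squashed Gaussian; a small sanity check against torch's built-in TanhTransform (illustrative only, exact up to the 1e-6 stability term):

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

base = Normal(torch.zeros(6), torch.ones(6))
u = base.rsample()
a = torch.tanh(u)

manual = base.log_prob(u).sum(-1) - torch.log(1 - a.pow(2) + 1e-6).sum(-1)
reference = TransformedDistribution(base, TanhTransform()).log_prob(a).sum(-1)
# manual and reference should agree up to small numerical error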


class Critic(nn.Module):
def __init__(self,
state_dim: int,
action_dim: int,
hidden_dim: int = 256,
layer_norm: bool = True,
edac_init: bool = True) -> None:
super().__init__()

#block = nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity()

self.critic = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity(),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity(),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity(),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)

if edac_init:
# init as in the EDAC paper
for layer in self.critic[::3]:
nn.init.constant_(layer.bias, 0.1)

nn.init.uniform_(self.critic[-1].weight, -3e-3, 3e-3)
nn.init.uniform_(self.critic[-1].bias, -3e-3, 3e-3)

def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
concat = torch.cat([state, action], dim=-1)
q_values = self.critic(concat).squeeze(-1) # shape: [batch_size,]
return q_values


class EnsembledLinear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
ensemble_size: int) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.ensemble_size = ensemble_size

self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))

self.reset_parameters()

def reset_parameters(self):
for layer in range(self.ensemble_size):
nn.init.kaiming_uniform_(self.weight[layer], a=sqrt(5))

fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
bound = 0
if fan_in > 0:
bound = 1 / sqrt(fan_in)

nn.init.uniform_(self.bias, -bound, bound)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = x @ self.weight + self.bias
return out


class EnsembledCritic(nn.Module):
def __init__(self,
state_dim: int,
action_dim: int,
hidden_dim: int = 256,
num_critics: int = 2,
layer_norm: bool = True,
edac_init: bool = True) -> None:
super().__init__()

#block = nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity()
self.num_critics = num_critics

self.critic = nn.Sequential(
EnsembledLinear(state_dim + action_dim, hidden_dim, num_critics),
nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity(),
nn.ReLU(),
EnsembledLinear(hidden_dim, hidden_dim, num_critics),
nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity(),
nn.ReLU(),
EnsembledLinear(hidden_dim, hidden_dim, num_critics),
nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity(),
nn.ReLU(),
EnsembledLinear(hidden_dim, 1, num_critics)
)

if edac_init:
# init as in the EDAC paper
for layer in self.critic[::3]:
nn.init.constant_(layer.bias, 0.1)

nn.init.uniform_(self.critic[-1].weight, -3e-3, 3e-3)
nn.init.uniform_(self.critic[-1].bias, -3e-3, 3e-3)

def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
concat = torch.cat([state, action], dim=-1)
concat = concat.unsqueeze(0)
concat = concat.repeat_interleave(self.num_critics, dim=0)
q_values = self.critic(concat).squeeze(-1)
return q_values


if __name__ == "__main__":
critic = EnsembledCritic(17, 6)

state_repeat = torch.rand(32, 17)
action_repeat = torch.rand(32, 6)

meow = critic(state_repeat, action_repeat).min(0).values.view(32, -1)

print(meow.max(1).values.shape)
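
The num_action_candidates and standard_deviation fields in config.py suggest O3F-style action selection; below is a rough sketch of how the Actor and EnsembledCritic defined above could be combined for it. The candidate-search code itself is not among the files shown in this commit, so this is only an illustration:

import torch

actor = Actor(state_dim=17, action_dim=6)
critic = EnsembledCritic(state_dim=17, action_dim=6, num_critics=5)

state = torch.rand(1, 17)
with torch.no_grad():
    mean_action, _ = actor(state, deterministic=True)  # shape [1, 6]
    # perturb the policy action with Gaussian noise, keep it inside the action bounds
    candidates = (mean_action + 0.2 * torch.randn(100, 6)).clamp(-1.0, 1.0)
    # conservative value: minimum over the critic ensemble, shape [100]
    q_values = critic(state.repeat(100, 1), candidates).min(0).values
    best_action = candidates[q_values.argmax()]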

