
Commit 1d93cbe: first commit
zzmtsvv committed Apr 9, 2023 (0 parents)
Showing 29 changed files with 2,759 additions and 0 deletions.
29 changes: 29 additions & 0 deletions .gitignore
@@ -0,0 +1,29 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

vae_weights
weights
weights
redq_bc_weights
policy_weights
spot_weights
runs
video

# Jupyter Notebook
.ipynb_checkpoints

workshop.ipynb

# pyenv
.python-version

# vscode
.vscode

# mac
.DS_Store

json_datasets
1 change: 1 addition & 0 deletions README.md
@@ -0,0 +1 @@
TODO: write about the methods, compare them, and insert the results
Empty file added adaptive_bc/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions adaptive_bc/configs.py
@@ -0,0 +1,46 @@
from dataclasses import dataclass


max_target_returns = {
"halfcheetah-medium-replay-v0": 15.743,
"halfcheetah-medium-v0": 15.743,
"hopper-medium-replay-v0": 6.918,
"hopper-medium-v0": 6.918,
"walker2d-medium-replay-v0": 10.271,
"walker2d-medium-v0": 10.271
}


@dataclass
class train_config:
policy: str = "REDQ_BC"
env: str = "hopper-medium-replay-v0" # alternatives: halfcheetah-medium-replay-v0, walker2d-medium-replay-v0
seed: int = 42
eval_frequency: int = 5000
max_timesteps: int = 250000
pretrain_timesteps: int = 1000000
num_updates: int = 10
save_model: bool = True
load_policy_path: str = ""
episode_length: int = 1000
exploration_noise: float = 0.1 # standard deviation of the Gaussian exploration noise added to actions
batch_size: int = 256
discount_factor: float = 0.99
tau: float = 0.005 # see algo.jpeg in 'paper' folder
policy_noise: float = 0.2
noise_clip: float = 0.5
policy_frequency: int = 2
alpha: float = 0.4
alpha_finetune: float = 0.4
sample_method: str = "random" # either "random" or "best"
sample_ratio: float = 0.05 # see algo.jpeg in 'paper' folder (ratio to keep offline data in replay buffer)
minimize_over_q: bool = False # if False, use randomized ensemble subsets; if True, take the min over Q values (see eq3.PNG in 'paper' folder)
Kp: float = 0.00003 # see eq2.PNG in 'paper' folder
Kd: float = 0.0001 # see eq2.PNG in 'paper' folder
normalize_returns: bool = True # if true, divide returns by the target return defined in the 'max_target_returns' dict


if __name__ == "__main__":
print({k: v for k, v in train_config.__dict__.items() if not k.startswith("__")})
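
A minimal sketch of how this config might be consumed (not part of the commit). normalized_return and pid_alpha_update are hypothetical helpers built only from the fields and comments above; the actual alpha update rule is the one in eq2.PNG of the 'paper' folder and may differ.

from adaptive_bc.configs import train_config, max_target_returns

cfg = train_config()

def normalized_return(episode_return: float, env: str = cfg.env) -> float:
    # divide by the per-environment target return, as the normalize_returns flag suggests
    if cfg.normalize_returns:
        return episode_return / max_target_returns[env]
    return episode_return

def pid_alpha_update(alpha: float, error: float, prev_error: float) -> float:
    # hypothetical proportional-derivative update of the BC weight using the Kp/Kd gains
    return alpha + cfg.Kp * error + cfg.Kd * (error - prev_error)

print(normalized_return(3000.0, "hopper-medium-replay-v0"))
print(pid_alpha_update(cfg.alpha, error=0.5, prev_error=0.3))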

219 changes: 219 additions & 0 deletions adaptive_bc/dataset.py
@@ -0,0 +1,219 @@
import torch
import numpy as np
from typing import List, Tuple
import os


class ReplayBuffer:
data_size_threshold = 50000
distill_methods = ["random", "best"]

def __init__(self,
state_dim: int,
action_dim: int,
buffer_size: int = 1000000) -> None:
self.state_dim = state_dim
self.action_dim = action_dim

self.buffer_size = buffer_size
self.pointer = 0
self.size = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device

self.states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
self.actions = torch.zeros((buffer_size, action_dim), dtype=torch.float32, device=device)
self.rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
self.next_states = torch.zeros((buffer_size, state_dim), dtype=torch.float32, device=device)
self.dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)

# i/o order: state, action, reward, next_state, done

@staticmethod
def to_tensor(data: np.ndarray, device=None) -> torch.Tensor:
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

return torch.tensor(data, dtype=torch.float32, device=device)

def from_json(self, json_file):
import json
json_file = os.path.join("json_datasets", json_file)
output = dict()

with open(json_file) as f:
dataset = json.load(f)

for k, v in dataset.items():
v = np.array(v)
if k != "terminals":
v = v.astype(np.float32)

output[k] = v

self.from_d4rl(output)

def sample(self, batch_size: int):
indexes = np.random.randint(0, self.size, size=batch_size)

# the storage tensors already live on self.device, so plain indexing is enough here
return (
self.states[indexes],
self.actions[indexes],
self.rewards[indexes],
self.next_states[indexes],
self.dones[indexes]
)

def from_d4rl(self, dataset):
if self.size:
print("Warning: loading data into non-empty buffer")
n_transitions = dataset["observations"].shape[0]

if n_transitions < self.buffer_size:
self.states[:n_transitions] = self.to_tensor(dataset["observations"][-n_transitions:], self.device)
self.actions[:n_transitions] = self.to_tensor(dataset["actions"][-n_transitions:], self.device)
self.next_states[:n_transitions] = self.to_tensor(dataset["next_observations"][-n_transitions:], self.device)
self.rewards[:n_transitions] = self.to_tensor(dataset["rewards"][-n_transitions:].reshape(-1, 1), self.device)
self.dones[:n_transitions] = self.to_tensor(dataset["terminals"][-n_transitions:].reshape(-1, 1), self.device)

else:
self.buffer_size = n_transitions

self.states = self.to_tensor(dataset["observations"][-n_transitions:], self.device)
self.actions = self.to_tensor(dataset["actions"][-n_transitions:], self.device)
self.next_states = self.to_tensor(dataset["next_observations"][-n_transitions:], self.device)
self.rewards = self.to_tensor(dataset["rewards"][-n_transitions:].reshape(-1, 1), self.device)
self.dones = self.to_tensor(dataset["terminals"][-n_transitions:].reshape(-1, 1), self.device)

self.size = n_transitions
self.pointer = n_transitions % self.buffer_size

def normalize_states(self, eps=1e-3):
mean = self.states.mean(0, keepdim=True)
std = self.states.std(0, keepdim=True) + eps
self.states = (self.states - mean) / std
self.next_states = (self.next_states - mean) / std
return mean, std

def get_all(self):
return (
self.states[:self.size].to(self.device),
self.actions[:self.size].to(self.device),
self.rewards[:self.size].to(self.device),
self.next_states[:self.size].to(self.device),
self.dones[:self.size].to(self.device)
)

def add_transition(self,
state: torch.Tensor,
action: torch.Tensor,
reward: torch.Tensor,
next_state: torch.Tensor,
done: torch.Tensor):
if not isinstance(state, torch.Tensor):
state = self.to_tensor(state)
action = self.to_tensor(action)
reward = self.to_tensor(reward)
next_state = self.to_tensor(next_state)
done = self.to_tensor(done)


self.states[self.pointer] = state
self.actions[self.pointer] = action
self.rewards[self.pointer] = reward
self.next_states[self.pointer] = next_state
self.dones[self.pointer] = done

self.pointer = (self.pointer + 1) % self.buffer_size
self.size = min(self.size + 1, self.buffer_size)

def add_batch(self,
states: List[torch.Tensor],
actions: List[torch.Tensor],
rewards: List[torch.Tensor],
next_states: List[torch.Tensor],
dones: List[torch.Tensor]):
for state, action, reward, next_state, done in zip(states, actions, rewards, next_states, dones):
self.add_transition(state, action, reward, next_state, done)

def distill(self,
dataset,
env_name,
sample_method,
ratio=0.05):
data_size = max(int(ratio * dataset["observations"].shape[0]), self.data_size_threshold)
assert sample_method in self.distill_methods, "Unknown sample method"

if sample_method == "random":
indexes = np.random.randint(0, dataset["observations"].shape[0], size=data_size)
elif sample_method == "best":
# keep the last `data_size` transitions of the dataset
full_data_size = dataset["observations"].shape[0]
indexes = np.arange(full_data_size - data_size, full_data_size)

if data_size < self.buffer_size:
self.states[:data_size] = self.to_tensor(dataset["observations"][indexes], self.device)
self.actions[:data_size] = self.to_tensor(dataset["actions"][indexes], self.device)
self.rewards[:data_size] = self.to_tensor(dataset["rewards"][indexes].reshape(-1, 1), self.device)
self.next_states[:data_size] = self.to_tensor(dataset["next_observations"][indexes], self.device)
self.dones[:data_size] = self.to_tensor(dataset["terminals"][indexes].reshape(-1, 1), self.device)
else:
self.buffer_size = data_size
self.states = self.to_tensor(dataset["observations"][indexes], self.device)
self.actions = self.to_tensor(dataset["actions"][indexes], self.device)
self.rewards = self.to_tensor(dataset["rewards"][indexes].reshape(-1, 1), self.device)
self.next_states = self.to_tensor(dataset["next_observations"][indexes], self.device)
self.dones = self.to_tensor(dataset["terminals"][indexes].reshape(-1, 1), self.device)

self.size = data_size
self.pointer = data_size % self.buffer_size

@staticmethod
def dataset_stats(dataset):
episode_returns = []
returns = 0
episode_length = 0

for reward, done in zip(dataset["rewards"], dataset["terminals"]):
returns += reward
episode_length += 1
if done or episode_length == 1000:
episode_returns.append(returns)
returns = 0
episode_length = 0

episode_returns = np.array(episode_returns)
return episode_returns.mean(), episode_returns.std()


def train_val_split(replay_buffer: ReplayBuffer, val_size: float) -> Tuple[ReplayBuffer, ReplayBuffer]:
data_size = replay_buffer.size
val_size = int(data_size * val_size)

permutation = torch.randperm(data_size)

train_rb = ReplayBuffer(replay_buffer.state_dim, replay_buffer.action_dim)
val_rb = ReplayBuffer(replay_buffer.state_dim, replay_buffer.action_dim)

train_rb.add_batch(
replay_buffer.states[permutation[val_size:]],
replay_buffer.actions[permutation[val_size:]],
replay_buffer.rewards[permutation[val_size:]],
replay_buffer.next_states[permutation[val_size:]],
replay_buffer.dones[permutation[val_size:]]
)

val_rb.add_batch(
replay_buffer.states[permutation[:val_size]],
replay_buffer.actions[permutation[:val_size]],
replay_buffer.rewards[permutation[:val_size]],
replay_buffer.next_states[permutation[:val_size]],
replay_buffer.dones[permutation[:val_size]]
)

return train_rb, val_rb
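
A short usage sketch (not part of the commit): filling the buffer from a D4RL dataset, normalizing states, splitting into train/validation buffers, and sampling a batch. The gym/d4rl calls are assumptions about the intended setup; the repository itself only provides the buffer.

import gym
import d4rl  # noqa: F401, registers the offline environments

from adaptive_bc.dataset import ReplayBuffer, train_val_split

env = gym.make("hopper-medium-replay-v0")
dataset = d4rl.qlearning_dataset(env)  # dict with observations, actions, next_observations, rewards, terminals

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

buffer = ReplayBuffer(state_dim, action_dim)
buffer.from_d4rl(dataset)
mean, std = buffer.normalize_states()

train_buffer, val_buffer = train_val_split(buffer, val_size=0.1)
states, actions, rewards, next_states, dones = train_buffer.sample(batch_size=256)
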
86 changes: 86 additions & 0 deletions adaptive_bc/modules.py
@@ -0,0 +1,86 @@
import torch
from torch import nn
import numpy as np


class Actor(nn.Module):
def __init__(self,
state_dim: int,
action_dim: int,
max_action: float,
hidden_dim: int = 256) -> None:
super().__init__()

self.max_action = max_action
self.actor = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Tanh()
)

def forward(self, state: torch.Tensor) -> torch.Tensor:
return self.max_action * self.actor(state)

@torch.no_grad()
def act(self, state, device: str = "cpu") -> np.ndarray:
state = state.reshape(1, -1)

if not isinstance(state, torch.Tensor):
state = torch.tensor(state, device=device, dtype=torch.float32)

return self(state).cpu().data.numpy().flatten()


class EnsembleLinear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
ensemble_size: int) -> None:
super().__init__()

self.ensemble_size = ensemble_size
scale_factor = 2 * in_features ** 0.5
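# used below as std = 1 / scale_factor, i.e. a truncated normal with std 1 / (2 * sqrt(in_features))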

self.weight = nn.Parameter(torch.zeros(ensemble_size, in_features, out_features))
self.bias = nn.Parameter(torch.zeros(ensemble_size, 1, out_features))

nn.init.trunc_normal_(self.weight, std=1 / scale_factor)

def forward(self, x: torch.Tensor):
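# x: (batch, in_features) for the first layer, (ensemble_size, batch, in_features) for subsequent layers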

if len(x.shape) == 2:
#print(x.shape, self.weight.shape)
x = torch.einsum('ij,bjk->bik', x, self.weight)
else:
x = torch.einsum('bij,bjk->bik', x, self.weight)

x = x + self.bias
return x


class EnsembledCritic(nn.Module):
def __init__(self,
state_dim: int,
action_dim: int,
hidden_dim: int = 256,
num_critics: int = 10) -> None:
super().__init__()

self.critics = nn.Sequential(
EnsembleLinear(state_dim + action_dim, hidden_dim, ensemble_size=num_critics),
nn.ReLU(),
EnsembleLinear(hidden_dim, hidden_dim, ensemble_size=num_critics),
nn.ReLU(),
EnsembleLinear(hidden_dim, 1, ensemble_size=num_critics)
)

def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
# shape: (num_critics, batch, 1)
concat = torch.cat([state, action], 1)
#print(f"concat shape {concat.shape}")

#print(self.critics(concat).shape)
return self.critics(concat)
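
A shape-check sketch (not part of the commit): pushing a batch through Actor and EnsembledCritic and reducing over the ensemble dimension, either as a min over all heads or over a random pair of heads, echoing the minimize_over_q flag in configs.py. The reduction itself is an assumption about how the ensemble is meant to be used.

import torch
from adaptive_bc.modules import Actor, EnsembledCritic

state_dim, action_dim, max_action = 11, 3, 1.0  # hopper-sized dimensions
actor = Actor(state_dim, action_dim, max_action)
critic = EnsembledCritic(state_dim, action_dim, num_critics=10)

states = torch.randn(256, state_dim)
actions = actor(states)             # (256, 3), squashed to [-max_action, max_action]

q_values = critic(states, actions)  # (num_critics, 256, 1)
assert q_values.shape == (10, 256, 1)

q_min = q_values.min(dim=0).values  # pessimistic estimate over the full ensemble

idx = torch.randperm(10)[:2]        # REDQ-style: min over a random pair of heads
q_redq = q_values[idx].min(dim=0).values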