Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
zzmtsvv authored Sep 23, 2023
1 parent 807b6dc commit f04e4de
Show file tree
Hide file tree
Showing 6 changed files with 791 additions and 0 deletions.
227 changes: 227 additions & 0 deletions diffusion_ql/actor_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
# Mostly adapted from the official implementation (https://github.com/Zhendong-Wang/Diffusion-Policies-for-Offline-RL)
# and from the Hugging Face tutorial on DDPM for computer vision.
from typing import Union, Tuple
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def vp_beta_schedule(timesteps: int, dtype=torch.float32):
    """Return the `vp` beta schedule as a 1-D tensor of length `timesteps`.

    For every step ``t = 1 .. timesteps`` (with ``T = timesteps``):

        alpha_t = exp(-b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T ** 2)
        beta_t  = 1 - alpha_t

    using the fixed constants ``b_min = 0.1`` and ``b_max = 10``.

    Args:
        timesteps: number of diffusion steps.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(timesteps,)`` holding ``beta_1 .. beta_T``.
    """
    beta_min, beta_max = 0.1, 10.
    steps = np.arange(1, timesteps + 1)

    # The exponent is evaluated in NumPy first; only the final betas are
    # converted to a tensor of the requested dtype.
    exponent = -beta_min / timesteps - 0.5 * (beta_max - beta_min) * (2 * steps - 1) / timesteps ** 2
    return torch.tensor(1 - np.exp(exponent), dtype=dtype)


def extract(a: torch.Tensor,
            t: torch.Tensor,
            x_shape) -> torch.Tensor:
    """Look up per-sample schedule values and shape them for broadcasting.

    Args:
        a: lookup table indexed along its last dimension by timestep
            (the callers in this file pass 1-D schedule tensors).
        t: integer timestep indices, one per batch element; ``t.shape[0]`` is
            taken as the batch size.
        x_shape: shape of the tensor the result will be multiplied with; only
            its length (number of dimensions) is used.

    Returns:
        Tensor of shape ``(batch_size, 1, ..., 1)`` with ``len(x_shape)``
        dimensions, placed on the device of ``t``.
    """
    batch_size = t.shape[0]

    # Index on the table's own device. Forcing the index to CPU only works
    # while the table itself is a CPU tensor; following `a.device` keeps the
    # lookup valid wherever the schedule lives.
    out = a.gather(-1, t.to(a.device))

    # One leading batch dim plus singleton dims so the result broadcasts
    # against a tensor of shape `x_shape`.
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    The output has ``2 * (embedding_dim // 2)`` columns: the first half are
    sines and the second half cosines of ``time * freq_k``, where
    ``freq_k = exp(-k * log(10000) / (half - 1))`` for ``k = 0 .. half - 1``.
    For an odd ``embedding_dim`` the output is therefore one column narrower
    than ``embedding_dim``.
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, time: torch.Tensor):
        """Embed `time` (shape ``(batch,)``) into shape ``(batch, 2 * half)``."""
        device = time.device
        half = self.embedding_dim // 2

        # With a single frequency (half == 1) the original denominator
        # `half - 1` is zero, which makes the scale infinite and the whole
        # embedding NaN. Clamping the denominator to 1 leaves every
        # half >= 2 case unchanged and yields frequency 1.0 for half == 1.
        log_step = np.log(10000) / max(half - 1, 1)

        frequencies = torch.exp(torch.arange(half, device=device) * (-1) * log_step)
        # Outer product: (batch, 1) * (1, half) -> (batch, half).
        angles = time.unsqueeze(1) * frequencies.unsqueeze(0)

        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class MLP(nn.Module):
    """Feed-forward network over (action, state, timestep).

    The timestep is first turned into a `time_dim`-wide feature vector by
    `time_net` (positional encoding followed by a two-layer Mish MLP). That
    vector is concatenated with the action and the state and pushed through
    `net`, which has three hidden Mish layers of width `hidden_dim` and a
    linear output of width `action_dim`.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_dim: int = 256,
                 time_dim: int = 16) -> None:
        super().__init__()

        # Timestep branch: encoding -> expand to 2 * time_dim -> back to time_dim.
        time_layers = [
            PositionalEncoding(time_dim),
            nn.Linear(time_dim, 2 * time_dim),
            nn.Mish(),
            nn.Linear(2 * time_dim, time_dim),
        ]
        self.time_net = nn.Sequential(*time_layers)

        # Main branch: input width, then three hidden layers of equal width.
        widths = [state_dim + action_dim + time_dim] + [hidden_dim] * 3
        trunk_layers = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            trunk_layers.append(nn.Linear(width_in, width_out))
            trunk_layers.append(nn.Mish())
        trunk_layers.append(nn.Linear(hidden_dim, action_dim))
        self.net = nn.Sequential(*trunk_layers)

    def forward(self,
                action: torch.Tensor,
                state: torch.Tensor,
                time: torch.Tensor) -> torch.Tensor:
        """Return the network output for a batch.

        The inputs are joined along the last dimension in the order
        (action, state, time features) before entering `net`.
        """
        time_features = self.time_net(time)
        joined = torch.cat([action, state, time_features], dim=-1)
        return self.net(joined)


class WeightedL2(nn.Module):
    """Weighted mean-squared-error loss.

    Computes the element-wise squared error between `prediction` and
    `target`, multiplies it by `weights` (a scalar or a tensor that
    broadcasts against the error), and returns the mean over all elements.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self,
                prediction: torch.Tensor,
                target: torch.Tensor,
                weights: Union[torch.Tensor, float] = 1.0) -> torch.Tensor:
        """Return the scalar weighted loss."""
        squared_error = F.mse_loss(prediction, target, reduction="none")
        weighted_error = squared_error * weights
        return weighted_error.mean()


class DDPM(nn.Module):
    """Diffusion model over actions, conditioned on states.

    The model wraps a `trunk` network that is trained (see `p_loss`) to
    predict the noise that `q_sample` mixed into an action. Sampling starts
    from Gaussian noise and applies `num_timesteps` reverse steps.

    All schedule tensors built in `__init__` are stored as plain attributes
    (not parameters or buffers), so they are not part of the state dict and
    do not follow `module.to(...)`. Per-timestep values are looked up through
    `extract`, which returns them on the device of the timestep tensor.

    Args:
        state_dim: size of the state vector (stored, not otherwise used here).
        action_dim: size of the action vector; sets the sample shape.
        trunk: network called as ``trunk(action, state, timestep)``.
        max_action: actions are clamped to ``[-max_action, max_action]``.
        num_timesteps: number of diffusion steps.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 trunk: MLP,
                 max_action: float = 1.0,
                 num_timesteps: int = 100,) -> None:
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.trunk = trunk
        self.max_action = max_action
        self.num_timesteps = num_timesteps

        betas = vp_beta_schedule(num_timesteps)
        self.betas = betas
        self.alphas = 1.0 - betas

        # Cumulative products of alpha up to step t, and the same shifted by
        # one step with a leading 1 for the step before the first.
        self.alphas_cumprod: torch.Tensor = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.ones(1), self.alphas_cumprod[:-1]])

        # Coefficients used by `forward_from_noise`.
        self.sqrt_inverted_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_inverted_minus_one_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1.)

        # Coefficients used by `q_sample`.
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # Fixed: the original computed log(1 / alphas_cumprod), which does not
        # match the attribute's name. Not read anywhere inside this class.
        self.log_one_minus_alphas_cumprod = torch.log(1. - self.alphas_cumprod)

        self.posterior_variance = betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

        # The variance is 0 at the first step (alphas_cumprod_prev == 1), so
        # it is clamped before taking the log.
        self.posterior_log_variance = torch.log(self.posterior_variance.clamp(min=1e-20))

        # Posterior-mean weights on the clean action (mean1) and on the noisy
        # action (mean2). torch.sqrt replaces the original np.sqrt so that
        # tensor math stays inside torch.
        self.posterior_mean1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.posterior_mean2 = (1. - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1. - self.alphas_cumprod)

        self.loss_fn = WeightedL2()

    def forward_from_noise(self,
                           a_t: torch.Tensor,
                           timestep: torch.Tensor,
                           noise: torch.Tensor) -> torch.Tensor:
        """Recover the clean action from a noisy action and its noise.

        Algebraic inverse of `q_sample`:
        ``sqrt(1 / acp_t) * a_t - sqrt(1 / acp_t - 1) * noise``.
        """
        return extract(self.sqrt_inverted_alphas_cumprod, timestep, a_t.shape) * a_t - \
            extract(self.sqrt_inverted_minus_one_alphas_cumprod, timestep, a_t.shape) * noise

    def q_posterior(self,
                    a_start: torch.Tensor,
                    a_t: torch.Tensor,
                    timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(mean, variance, log_variance)`` of the posterior step.

        The mean is ``posterior_mean1[t] * a_start + posterior_mean2[t] * a_t``;
        variance and log-variance are direct per-timestep lookups.
        """
        mean = extract(self.posterior_mean1, timestep, a_t.shape) * a_start + \
            extract(self.posterior_mean2, timestep, a_t.shape) * a_t

        variance = extract(self.posterior_variance, timestep, a_t.shape)
        log_variance = extract(self.posterior_log_variance, timestep, a_t.shape)
        return mean, variance, log_variance

    def p(self,
          action: torch.Tensor,
          state: torch.Tensor,
          timestep: torch.Tensor):
        """Return posterior statistics for one reverse step.

        The trunk output is used as the noise estimate, the clean action is
        reconstructed from it and clamped to ``[-max_action, max_action]``,
        and the result is passed to `q_posterior`.
        """
        predicted_noise = self.trunk(action, state, timestep)
        reconstructed = self.forward_from_noise(action, timestep, predicted_noise)
        reconstructed = reconstructed.clamp(-self.max_action, self.max_action)

        return self.q_posterior(reconstructed, action, timestep)

    def p_sample(self,
                 action: torch.Tensor,
                 state: torch.Tensor,
                 timestep: torch.Tensor) -> torch.Tensor:
        """Draw the action for the previous timestep from the current one."""
        batch_size = action.shape[0]

        model_mean, _, model_log_variance = self.p(action, state, timestep)
        noise = torch.randn_like(action)

        # No noise is added for samples whose timestep is 0 (the final step).
        mask = (1. - (timestep == 0).float()).reshape(batch_size, *((1,) * (len(action.shape) - 1)))

        # exp(log_var / 2) is the standard deviation.
        return model_mean + mask * (model_log_variance / 2).exp() * noise

    def p_sample_loop(self, state: torch.Tensor) -> torch.Tensor:
        """Run the full reverse chain, starting from standard normal noise.

        Returns an unclamped tensor of shape ``(state.shape[0], action_dim)``.
        """
        device = state.device
        batch_size = state.shape[0]

        action = torch.randn((batch_size, self.action_dim), device=device)

        # Steps run from num_timesteps - 1 down to 0.
        for i in reversed(range(self.num_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            action = self.p_sample(action, state, timesteps)

        return action

    def sample(self, state: torch.Tensor) -> torch.Tensor:
        """Sample actions for `state`, clamped to ``[-max_action, max_action]``."""
        action = self.p_sample_loop(state)
        return action.clamp(-self.max_action, self.max_action)

    def q_sample(self,
                 a_start: torch.Tensor,
                 timestep: torch.Tensor,
                 noise: torch.Tensor):
        """Noise a clean action to the given timestep.

        Returns ``sqrt(acp_t) * a_start + sqrt(1 - acp_t) * noise``.
        """
        return extract(self.sqrt_alphas_cumprod, timestep, a_start.shape) * a_start + \
            extract(self.sqrt_one_minus_alphas_cumprod, timestep, a_start.shape) * noise

    def p_loss(self,
               a_start: torch.Tensor,
               state: torch.Tensor,
               timestep: torch.Tensor,
               weights: Union[torch.Tensor, float] = 1.0) -> torch.Tensor:
        """Weighted MSE between the trunk output and freshly sampled noise.

        Fresh noise is drawn, mixed into `a_start` at `timestep`, and the
        trunk's output on the noisy action is compared against that noise.
        """
        noise = torch.randn_like(a_start)

        a_noisy = self.q_sample(a_start, timestep, noise)
        reconstructed = self.trunk(a_noisy, state, timestep)

        # Internal sanity check: the trunk must return one value per noise entry.
        assert noise.shape == reconstructed.shape

        loss = self.loss_fn(reconstructed, noise, weights)
        return loss

    def loss(self,
             action: torch.Tensor,
             state: torch.Tensor,
             weights: Union[torch.Tensor, float] = 1.0) -> torch.Tensor:
        """Training loss with one uniformly random timestep per sample.

        Timesteps are drawn from ``[0, num_timesteps)`` on the device of
        `action` and forwarded to `p_loss`.
        """
        batch_size = action.shape[0]

        timesteps = torch.randint(0, self.num_timesteps, (batch_size,), device=action.device).long()
        return self.p_loss(action, state, timesteps, weights)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Alias for `sample`."""
        return self.sample(state)


if __name__ == "__main__":
    # Smoke test: build a 5-step diffusion actor and print the shape of the
    # actions it samples for a random batch of states.
    state_dim, action_dim, batch_size = 17, 6, 32

    trunk = MLP(state_dim, action_dim)
    policy = DDPM(state_dim, action_dim, trunk, num_timesteps=5)

    states = torch.rand(batch_size, state_dim)
    sampled_actions = policy(states)

    print(sampled_actions.shape)
32 changes: 32 additions & 0 deletions diffusion_ql/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
from dataclasses import dataclass


@dataclass
class dql_config:
    """Default hyperparameters and run metadata.

    NOTE: `group` and `name` are computed once, while the class body is
    evaluated, from the *default* `dataset_name` and `seed`. Passing a
    different `dataset_name` or `seed` to the constructor does not change
    them; pass `group` / `name` explicitly when that is needed.
    """

    # --- run setup ---
    # "cuda" when torch reports an available CUDA device at import time.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_name: str = "halfcheetah-medium-v2"
    seed: int = 42

    # --- environment dimensions ---
    state_dim: int = 17
    action_dim: int = 6

    # --- update scheduling ---
    actor_update_freq: int = 5
    steps_not_updating_actor_target: int = 1000

    # --- optimisation ---
    learning_rate: float = 3e-4
    batch_size: int = 256
    buffer_size: int = 1_000_000
    discount: float = 0.99
    hidden_dim: int = 256
    max_action: float = 1.0
    max_timesteps: int = 1_000_000
    tau: float = 5e-3  # presumably the soft target-update rate -- confirm against the trainer

    # --- diffusion / loss ---
    T: int = 5  # presumably the number of diffusion timesteps -- confirm against the trainer
    eta: float = 1.0
    grad_norm: float = 9.0  # alternative value noted by the author: 1.0

    # --- experiment-tracking labels ---
    project: str = "DiffusionQL"
    group: str = dataset_name
    name: str = f"{dataset_name}_{seed}"
80 changes: 80 additions & 0 deletions diffusion_ql/critic_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from math import sqrt
import torch
from torch import nn


class EnsembledLinear(nn.Module):
    """A stack of `ensemble_size` independent linear layers applied in one matmul.

    Parameters:
        weight: ``[ensemble_size, in_features, out_features]``
        bias:   ``[ensemble_size, 1, out_features]``

    Note the weight layout is (in, out) per member, the transpose of
    `nn.Linear`'s (out, in).
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 ensemble_size: int) -> None:
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise weights and biases, scaled by the true fan-in.

        Weights are drawn from a Kaiming-normal distribution
        (leaky-ReLU gain with negative slope sqrt(5)); biases are uniform in
        ``[-1 / sqrt(fan_in), 1 / sqrt(fan_in)]``.

        The previous version passed each (in, out) weight slice to torch's
        fan-in helper, which reads ``size(1)`` as the fan-in. With this
        layout that is `out_features`, so both the weight std and the bias
        bound were scaled by the wrong dimension whenever
        ``in_features != out_features``. The fan-in is now taken from
        `in_features` directly.
        """
        fan_in = self.in_features

        # Same formula kaiming_normal_ applies: std = gain / sqrt(fan).
        gain = nn.init.calculate_gain("leaky_relu", sqrt(5))
        std = gain / sqrt(fan_in) if fan_in > 0 else 0
        nn.init.normal_(self.weight, 0.0, std)

        bound = 1 / sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        x: [ensemble_size, batch_size, input_size]
        weight: [ensemble_size, input_size, out_size]
        bias: [ensemble_size, 1, out_size] (broadcast over the batch dimension)

        returns: [ensemble_size, batch_size, out_size]
        '''
        return x @ self.weight + self.bias


class EnsembledCritic(nn.Module):
    """Ensemble of `num_critics` Q-networks evaluated in parallel.

    Each member has three hidden `EnsembledLinear` layers of width
    `hidden_dim`, each followed by an optional `LayerNorm` and a Mish
    activation, and a final `EnsembledLinear` producing one value.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        hidden_dim: width of the hidden layers.
        num_critics: number of ensemble members.
        layer_norm: insert `nn.LayerNorm(hidden_dim)` after each hidden
            linear layer (an `nn.Identity` placeholder is used otherwise, so
            module indices stay the same).
        edac_init: set every linear bias to 0.1, then draw the output layer's
            weight and bias uniformly from ``[-3e-3, 3e-3]``.
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 hidden_dim: int = 256,
                 num_critics: int = 2,
                 layer_norm: bool = False,
                 edac_init: bool = True) -> None:
        super().__init__()

        self.num_critics = num_critics

        # Three (linear, norm-or-identity, Mish) groups, then the output layer.
        modules = []
        width_in = state_dim + action_dim
        for _ in range(3):
            modules.append(EnsembledLinear(width_in, hidden_dim, num_critics))
            modules.append(nn.LayerNorm(hidden_dim) if layer_norm else nn.Identity())
            modules.append(nn.Mish())
            width_in = hidden_dim
        modules.append(EnsembledLinear(hidden_dim, 1, num_critics))
        self.critic = nn.Sequential(*modules)

        if edac_init:
            # init as in the EDAC paper
            for module in self.critic:
                if isinstance(module, EnsembledLinear):
                    nn.init.constant_(module.bias, 0.1)

            output_layer = self.critic[-1]
            nn.init.uniform_(output_layer.weight, -3e-3, 3e-3)
            nn.init.uniform_(output_layer.bias, -3e-3, 3e-3)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the Q-values of every ensemble member.

        `state` and `action` are concatenated on the last dimension, a
        leading ensemble dimension is added and repeated `num_critics` times,
        and the trailing size-1 output dimension is squeezed away. For 2-D
        inputs the result is presumably ``[num_critics, batch_size]`` --
        confirm against the caller.
        """
        joined = torch.cat([state, action], dim=-1)
        per_member = joined.unsqueeze(0).repeat_interleave(self.num_critics, dim=0)
        return self.critic(per_member).squeeze(-1)
Loading

0 comments on commit f04e4de

Please sign in to comment.