Skip to content

Commit

Permalink
a little bit of details
Browse files Browse the repository at this point in the history
  • Loading branch information
zzmtsvv committed Jan 28, 2024
1 parent dad5d1c commit 9dcc539
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
17 changes: 15 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
<!-- https://wandb.ai/zzmtsvv/sac_rnd/runs/d03hrwpr?workspace=user-zzmtsvv -->
<!-- https://wandb.ai/zzmtsvv/sac_drnd/runs/d03hrwpr?workspace=user-zzmtsvv -->

# Anti-Exploration with Distributional Random Network Distillation on PyTorch

This repository contains possible (not ideal one actually) PyTorch implementation of offline [SAC DRND](https://arxiv.org/abs/2401.09750) with the [wandb](https://wandb.ai/zzmtsvv/sac_drnd?workspace=user-zzmtsvv) integration. Actually, It is just a slightly modified [my realization](https://github.com/zzmtsvv/sac_rnd) of [SAC RND](https://arxiv.org/abs/2301.13616).

if you want to train the model, setup `drnd_config` in `config.py`, initialize `SACDRNDTrainer` in `trainer.py` and run its `train` method:

```python3
from trainer import SACDRNDTrainer

trainer = SACDRNDTrainer()
trainer.train()
```
if you find any bugs and mistakes in the code, please contact me :)

# Anti-Exploration with Distributional Random Network Distillation on PyTorch
2 changes: 1 addition & 1 deletion drnd_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def drnd_bonus(self,
target_mean_squared = target_mean.square()

b1 = self.loss_fn(predictor_out, target_mean).sum(dim=-1)
b2 = ((predictor_out.square() - target_mean_squared) / (B2 - target_mean_squared)).sqrt().sum(dim=-1)
b2 = ((predictor_out.square() - target_mean_squared) / (B2 - target_mean_squared + self.eps)).abs().sqrt().sum(dim=-1)

return self.alpha * b1 + (1 - self.alpha) * b2

Expand Down
6 changes: 5 additions & 1 deletion trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def __init__(self,
seed_everything(cfg.train_seed)

def train_drnd(self) -> DRND:
(self.state_mean, self.state_std), (self.action_mean, self.action_std) = self.buffer.get_moments()
(state_mean, state_std), (action_mean, action_std) = self.buffer.get_moments()
self.state_mean = state_mean.to(self.device)
self.state_std = state_std.to(self.device)
self.action_mean = action_mean.to(self.device)
self.action_std = action_std.to(self.device)

drnd = DRND(self.state_dim,
self.action_dim,
Expand Down

0 comments on commit 9dcc539

Please sign in to comment.