Skip to content

Commit

Permalink
Update td7.py
Browse files Browse the repository at this point in the history
  • Loading branch information
zzmtsvv authored Sep 15, 2023
1 parent 816f760 commit 37c1e40
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions td7/td7.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from math import pow
from typing import Tuple, Dict, Any
import torch
from torch.nn import functional as F
Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(self,
self.running_min_q = float("inf")

self.lambda_coef = cfg.lambda_coef
self.alpha = cfg.alpha

self.total_iterations = 0

Expand Down Expand Up @@ -147,12 +149,33 @@ def critic_loss(self,
td = (target_q - current_q).abs()
critic_loss = self.huber_loss(td)

priority = td.max(0).values.detach().squeeze(-1).clamp_min(self.min_priority).pow(self.cfg.alpha)
priority = td.max(0).values.detach().squeeze(-1).clamp_min(self.min_priority).pow(self.alpha)

return critic_loss, current_q.mean(), priority, self.running_max_q, self.running_min_q

def huber_loss(self, x: torch.Tensor) -> torch.Tensor:
return torch.where(x < self.min_priority, x.pow(2) / 2, x).sum(dim=0).mean()

def pal(self, td: torch.Tensor) -> torch.Tensor:
'''
Prioritized Approximation Loss
is used with uniform sampling from the replay buffer
'''
if self.min_priority == 1.0:
loss = torch.where(
td.abs() < 1.0,
td.pow(2) / 2,
td.abs().pow(1.0 + self.alpha) / (1.0 + self.alpha)
).mean()
else:
loss = torch.where(
td.abs() < self.min_priority,
pow(self.min_priority, self.alpha) * td.pow(2) / 2,
self.min_priority * td.abs().pow(1.0 + self.alpha) / (1.0 + self.alpha)
).mean()

lambda_coef = td.abs().clamp_min(self.min_priority).pow(self.alpha).mean().detach()
return loss / lambda_coef

def update_target_models(self):
self.actor_target.load_state_dict(self.actor.state_dict())
Expand Down Expand Up @@ -196,4 +219,3 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
self.min_target = state_dict["min_target"]
self.running_max_q = state_dict["running_max_q"]
self.running_min_q = state_dict["running_min_q"]

0 comments on commit 37c1e40

Please sign in to comment.