-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
47 lines (36 loc) · 1.04 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from dataclasses import dataclass
import torch
@dataclass
class doge_config:
    """Hyperparameters and run metadata for a DOGE experiment.

    All fields have defaults, so ``doge_config()`` yields a ready-to-use
    configuration; any field can be overridden by keyword.

    ``group`` and ``name`` are derived from ``dataset_name`` and ``seed``.
    Their class-level defaults are computed from the default dataset and
    seed, and ``__post_init__`` recomputes them per instance so they track
    overridden values instead of staying pinned to the defaults.
    """

    # Experiment
    # Evaluated once at import time: "cuda" when torch reports CUDA as
    # available, otherwise "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_name: str = "halfcheetah-medium-v2"
    seed: int = 42
    max_timesteps: int = int(1e6)
    max_action: float = 1.0
    # NOTE(review): 6 / 17 presumably match the default dataset's
    # action/observation sizes - confirm when changing dataset_name.
    action_dim: int = 6
    state_dim: int = 17
    buffer_size: int = 1_000_000

    # Optimizer learning rates
    actor_lr: float = 3e-4
    critic_lr: float = 3e-4
    distance_lr: float = 1e-3

    # Network / update settings
    hidden_dim: int = 256
    batch_size: int = 256
    discount: float = 0.99
    tau: float = 0.005
    policy_noise: float = 0.2
    noise_clip: float = 0.5
    policy_freq: int = 2

    # Lambda and distance-function settings
    initial_lambda: float = 6.0
    lambda_max: float = 100.0
    lambda_min: float = 1.0
    lambda_threshold: float = 0.0
    num_negative_samples: int = 20
    alpha: float = 17.5  # 7.5 (alternative value noted by the original author)
    distance_steps: int = int(1e5)
    lambda_lr: float = 3e-4
    critic_ln: bool = True
    normalize: bool = True

    # Run labels
    # NOTE(review): project/group/name look like experiment-tracker labels
    # - confirm against the training script.
    project: str = "DOGE"
    group: str = dataset_name
    name: str = dataset_name + "_" + str(seed)

    def __post_init__(self):
        """Re-derive ``group`` and ``name`` from this instance's values.

        A field that still equals its class-level default is treated as
        "not overridden" and is recomputed from ``self.dataset_name`` and
        ``self.seed``. A value that differs from the default is an explicit
        override and is kept as-is.
        """
        cls = type(self)
        if self.group == cls.group:
            self.group = self.dataset_name
        if self.name == cls.name:
            self.name = f"{self.dataset_name}_{self.seed}"