-
Notifications
You must be signed in to change notification settings - Fork 0
/
configs.py
67 lines (63 loc) · 1.87 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class vae_config:
    """Default settings grouped under the VAE configuration.

    Every field has a default, so ``vae_config()`` yields a complete
    configuration and individual values can be overridden by keyword.
    Field order is part of the interface (positional construction).
    """

    seed: int = 0

    # Dataset selection.
    env: str = "hopper"  # halfcheetah walker2d
    dataset: str = "medium"  # medium, medium-replay, medium-expert, expert
    version: str = "v2"

    # Model / optimisation settings.
    hidden_dim: int = 750
    beta: float = 0.5
    num_iterations: int = 100000
    batch_size: int = 256
    lr: float = 3e-4
    # Float literal so the default agrees with the ``float`` annotation
    # (previously the int ``0``; numerically identical).
    weight_decay: float = 0.0
    use_scheduler: bool = False
    gamma: float = 0.95

    # Action handling.
    max_action_exists: bool = True
    clip_to_eps: bool = False
    eps: float = 1e-4
    # NOTE(review): ``latent_dim`` is intentionally not a field here; the
    # original note says "action_dim * 2" - presumably derived by the
    # consumer of this config. Confirm before adding it back.

    normalize_states: bool = True
    eval_size: float = 0.0

    # Output locations.
    weights_dir: str = "weights"
    base_dir: str = "spot"
@dataclass
class spot_config:
    """Default settings grouped under the SPOT configuration.

    Every field has a default, so ``spot_config()`` yields a complete
    configuration. ``env_name`` and ``vae_model_path`` are derived from
    ``env``, ``dataset`` and ``version`` in ``__post_init__`` unless a value
    is passed explicitly, so overriding ``env`` (for example) keeps the
    derived values consistent with it. Field order is part of the interface
    (positional construction).
    """

    save_video: bool = False
    buffer_size: int = 1000000

    # Dataset selection.
    env: str = "hopper"  # halfcheetah walker2d
    dataset: str = "medium"  # medium, medium-replay
    version: str = "v0"
    # None -> filled in as f"{env}-{dataset}-{version}" by __post_init__.
    # (An f-string default here would be evaluated once, in the class body,
    # and would ignore per-instance overrides of env/dataset/version.)
    env_name: Optional[str] = None

    seed: int = 0
    # Integer literal so the default matches the ``int`` annotation
    # (previously the float ``5e3``; numerically identical).
    eval_frequency: int = 5000
    max_timesteps: int = 1000000
    save_model: bool = False
    save_final_model: bool = True
    eval_episodes: int = 10
    clip: bool = False
    exploration_noise: float = 0.1
    batch_size: int = 256
    discount_factor: float = 0.99
    tau: float = 0.005
    policy_noise: float = 0.2
    noise_clip: float = 0.5
    policy_frequency: int = 2
    lr: float = 3e-4
    # NOTE(review): None presumably means "fall back to ``lr``" in the
    # consumer of this config - confirm against the training code.
    actor_lr: Optional[float] = None
    actor_hidden_dim: int = 256
    critic_hidden_dim: int = 256
    actor_dropout: float = 0.1
    alpha: float = 0.4
    normalize_env: bool = True
    # None -> filled in as spot/weights/vae_{env}-{dataset}.pt by
    # __post_init__ (same reasoning as ``env_name``).
    vae_model_path: Optional[str] = None
    beta: float = 0.5
    use_importance_sampling: bool = False
    num_samples: int = 1
    lambda_: float = 1.0
    with_q_norm: bool = True
    # Annotated ``bool`` to match its ``False`` default (was ``float``).
    # NOTE(review): assumed to be an on/off flag - confirm with the consumer.
    lambda_cool: bool = False
    lambda_end: float = 0.2

    # Output locations.
    base_dir: str = "spot"
    weights_dir: str = "policy_weights"

    def __post_init__(self):
        """Fill in the derived fields that were not supplied explicitly."""
        if self.env_name is None:
            self.env_name = f"{self.env}-{self.dataset}-{self.version}"
        if self.vae_model_path is None:
            self.vae_model_path = os.path.join(
                "spot", "weights", f"vae_{self.env}-{self.dataset}.pt"
            )