-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
88 lines (67 loc) · 2.13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
from typing import Optional, Tuple
import gym
import numpy as np
import torch
from imageio import mimsave
import random
import json
def make_dir(dir_path):
    """Create ``dir_path`` (including any missing parents) if it does not exist.

    Args:
        dir_path: Path of the directory to create.

    Returns:
        ``dir_path`` unchanged, so the call can be used inline.

    Raises:
        OSError: if the directory cannot be created (e.g. permission denied),
            or ``dir_path`` already exists but is not a directory. The
            previous implementation swallowed every ``OSError``, which meant
            a missing parent directory silently produced a path that did not
            exist.
    """
    # exist_ok=True keeps "already exists" harmless without hiding real
    # failures.
    os.makedirs(dir_path, exist_ok=True)
    return dir_path
def parse_json_dataset(filename: str,
                       data_dir: str = "json_datasets") -> Tuple[int, int, float]:
    """Read the state/action dimensions of a JSON dataset.

    The file is looked up as ``<data_dir>/<filename>``. A ``.json`` suffix is
    appended to ``filename`` if it is missing.

    Args:
        filename: Dataset file name, with or without the ``.json`` suffix.
        data_dir: Directory that holds the JSON datasets. Defaults to
            ``"json_datasets"``, resolved relative to the current working
            directory (the previously hard-coded location).

    Returns:
        ``(state_dim, action_dim, max_action)``. The two dimensions are the
        size of the second axis of the ``"observations"`` and ``"actions"``
        arrays. ``max_action`` is always ``1.0``.

    Raises:
        FileNotFoundError: if the dataset file does not exist.
        KeyError: if ``"observations"`` or ``"actions"`` is missing.
        IndexError: if either array has fewer than two dimensions.
    """
    # NOTE(review): max_action is a fixed constant, not read from the file —
    # presumably actions are normalised to [-1, 1]; confirm with the producer.
    max_action = 1.0
    if not filename.endswith(".json"):
        filename = filename + ".json"
    path = os.path.join(data_dir, filename)
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        obj = json.load(f)
    states = np.asarray(obj["observations"])
    actions = np.asarray(obj["actions"])
    return states.shape[1], actions.shape[1], max_action
class DummyScheduler:
    """Scheduler placeholder whose ``step()`` does nothing.

    Presumably used where a real learning-rate scheduler is optional, so that
    calling code can always call ``step()`` — TODO confirm with callers.
    """

    def __init__(self) -> None:
        """Create the placeholder; it holds no state."""

    def step(self):
        """Do nothing; exists only to match the scheduler interface."""
        return None
def seed_everything(seed: int,
                    env: Optional["gym.Env"] = None,
                    use_deterministic_algos: bool = False):
    """Seed the Python, NumPy, PyTorch and (optionally) environment RNGs.

    Args:
        seed: Seed applied to Python's ``random``, NumPy's global RNG, PyTorch
            (CPU and every visible CUDA device) and, if given, ``env``.
        env: Optional environment. Both ``env.seed`` and
            ``env.action_space.seed`` are called. NOTE(review): ``env.seed``
            is the legacy gym API and is absent from newer gym releases —
            confirm the installed gym version provides it.
        use_deterministic_algos: Passed straight to
            ``torch.use_deterministic_algorithms``. Note that ``False``
            actively turns deterministic mode off.
    """
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)
    # Only affects subprocesses started after this point: the running
    # interpreter's str/bytes hash seed is fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    if use_deterministic_algos:
        # With CUDA >= 10.2, deterministic mode makes affected cuBLAS ops raise
        # RuntimeError unless this workspace config is set. It must be set
        # before cuBLAS initialises; an existing user setting is respected.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.use_deterministic_algorithms(use_deterministic_algos)
    random.seed(seed)
class VideoRecorder:
    """Buffer rendered environment frames and write them out with imageio.

    ``init()`` clears the frame buffer and decides whether recording is on.
    ``record(env)`` appends one rendered frame. ``save(name)`` writes all
    buffered frames via ``imageio.mimsave``. While recording is off (no
    ``dir_name``, ``init(enabled=False)``, or ``init()`` not yet called),
    ``record`` and ``save`` do nothing.

    Args:
        dir_name: Output directory used by ``save``; ``None`` disables
            recording.
        height: Render height passed to ``env.render``.
        width: Render width passed to ``env.render``.
        camera_id: Stored, but currently NOT forwarded to ``env.render``.
        fps: Frame rate passed to ``mimsave``.
    """

    def __init__(self, dir_name, height=512, width=512, camera_id=0, fps=60):
        self.dir_name = dir_name
        self.height = height
        self.width = width
        self.camera_id = camera_id
        self.fps = fps
        self.frames = []
        # Recording is off until init() is called. Previously this attribute
        # only existed after init(), so record()/save() raised AttributeError
        # if called first.
        self.enabled = False

    def init(self, enabled=True):
        """Drop buffered frames and (re)set whether recording is enabled.

        Recording is enabled only if ``enabled`` is truthy and a
        ``dir_name`` was configured.
        """
        self.frames = []
        self.enabled = self.dir_name is not None and enabled

    def record(self, env: "gym.Env"):
        """Render ``env`` in ``rgb_array`` mode and buffer the frame, if enabled."""
        if self.enabled:
            # NOTE(review): camera_id is deliberately not passed (it was
            # commented out here); presumably the target env's render() does
            # not accept it — confirm before forwarding self.camera_id.
            frame = env.render(
                mode='rgb_array',
                height=self.height,
                width=self.width,
            )
            self.frames.append(frame)

    def save(self, file_name):
        """Write the buffered frames to ``dir_name/file_name``, if enabled."""
        if self.enabled:
            path = os.path.join(self.dir_name, file_name)
            mimsave(path, self.frames, fps=self.fps)
if __name__ == "__main__":
    # Manual smoke check: print (state_dim, action_dim, max_action) for one
    # dataset file looked up in json_datasets/ under the working directory.
    dataset_info = parse_json_dataset("halfcheetah-medium-replay-v0.json")
    print(dataset_info)