train_secundary.py
#!/usr/bin/env python3
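"""Train a secondary uncertainty model on a primary model's errors.

Loads the primary model's cached training/evaluation/test predictions,
z-scores y and y_hat with the stored label transformation, and fits a
secondary model that predicts the primary model's absolute error from
the primary prediction ("guess") and its per-sample embedding.
"""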
import argparse
from copy import deepcopy
from pathlib import Path

import numpy as np
from python_tools import caching
from python_tools.generic import namespace_as_string
from python_tools.ml.data_loader import DataLoader
from python_tools.ml.pytorch_tools import dict_to_batched_data

from train import train

def get_data(training: Path, transformation: Path):
    """Load cached predictions and build DataLoaders for the secondary model."""
    # read the primary model's cached training/evaluation/test predictions
    results = {
        key: caching.read_pickle(
            training.parent / training.name.replace("training", key)
        )[0]
        for key in ("training", "evaluation", "test")
    }

    # meta data: the input is the primary prediction ("guess") plus its embedding
    x_names = ["guess"] + [
        f"embedding_{i}" for i in range(results["training"]["meta_embedding"].shape[1])
    ]
    meta_data = {"y_names": np.array(["loss"]), "x_names": np.array(x_names)}

    # apply transformation: z-score y and y_hat with the stored label statistics
    transformation = caching.read_pickle(transformation)[0][0]["y"]
    for dataset in results:
        for key in ("y", "y_hat"):
            results[dataset][key] = (
                results[dataset][key] - transformation["mean"]
            ) / transformation["std"]

    # generate data: the secondary target is the primary model's absolute error
    datasets = {}
    for key, data in results.items():
        dataset = {
            "x": np.concatenate([data["y_hat"], data["meta_embedding"]], axis=1),
            "y": np.abs(data["y_hat"] - data["y"]),
            "meta_id": data["meta_id"],
            "meta_frame": data["meta_frame"],
            "meta_y_hat": data["y_hat"],
            "meta_y": data["y"],
        }
        datasets[key] = DataLoader(
            dict_to_batched_data(dataset), properties=deepcopy(meta_data)
        )
    # single partition, keyed by 0
    return {0: datasets}
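
# Each "*_predictions.pickle" loaded above is assumed to hold a dict with at
# least the following keys, all aligned along axis 0:
#   y              -- ground-truth labels, shape (n, 1)
#   y_hat          -- primary-model predictions, shape (n, 1)
#   meta_embedding -- per-sample embeddings, shape (n, d)
#   meta_id        -- sample identifiers
#   meta_frame     -- frame indices
# The secondary model thus learns to predict |y_hat - y|, the primary model's
# absolute error, from the prediction and its embedding.
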
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", choices=["disfa", "bp4d_plus", "mnist", "mnisti"])
    parser.add_argument("--uncertainty", choices=["umlp", "dwar"], default="umlp")
    parser.add_argument("--method", choices=[""], default="")
    parser.add_argument("--au", type=int, default=6)
    parser.add_argument("--workers", type=int, default=2)
    args = parser.parse_args()

    # find the cached predictions of the primary (dropout) model
    primary_folder = Path(f"method=dropout_dataset={args.dataset}_au={args.au}")
    path = next(primary_folder.glob("*_training_predictions.pickle"))

    # get data
    data = get_data(path, primary_folder / "partition_0.pickle")

    # train the secondary model; the output folder name encodes all
    # arguments except "workers"
    folder = Path(namespace_as_string(args, exclude=("workers",)))
    train(data, folder, args)
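
# Usage (a minimal sketch; the exact file names are assumptions): once the
# primary dropout model has been trained so that a folder such as
# "method=dropout_dataset=disfa_au=6" contains the cached
# "*_predictions.pickle" files and "partition_0.pickle", the secondary model
# can presumably be trained with, e.g.:
#
#     python3 train_secundary.py --dataset disfa --au 6 --uncertainty umlp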