from urllib.parse import urlparse
import allrank.models.losses as losses
import numpy as np
import os
import torch
from allrank.config import Config
from allrank.data.dataset_loading import load_libsvm_dataset, create_data_loaders
from allrank.models.model import make_model
from allrank.models.model_utils import get_torch_device, CustomDataParallel
from allrank.training.train_utils import fit
from allrank.utils.command_executor import execute_command
from allrank.utils.experiments import dump_experiment_result, assert_expected_metrics
from allrank.utils.file_utils import create_output_dirs, PathsContainer, copy_local_to_gs
from allrank.utils.ltr_logging import init_logger
from allrank.utils.python_utils import dummy_context_mgr
from argparse import ArgumentParser, Namespace
from attr import asdict
from functools import partial
from pprint import pformat
from torch import optim


def parse_args() -> Namespace:
    parser = ArgumentParser("allRank")
    parser.add_argument("--job-dir", help="Base output path for all experiments", required=True)
    parser.add_argument("--run-id", help="Name of this run to be recorded (must be unique within output dir)",
                        required=True)
    parser.add_argument("--config-file-name", required=True, type=str, help="Name of json file with config")
    return parser.parse_args()
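

# Example invocation (the paths and file names below are illustrative, not shipped defaults):
#   python main.py --job-dir /tmp/allrank_out --run-id test_run --config-file-name config.json
# --job-dir may also point at a Google Cloud Storage location (gs://...); in that case the local
# outputs are copied there after training (see the urlparse check at the end of run()).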


def run():
    # reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    np.random.seed(42)

    args = parse_args()

    paths = PathsContainer.from_args(args.job_dir, args.run_id, args.config_file_name)
    create_output_dirs(paths.output_dir)

    logger = init_logger(paths.output_dir)
    logger.info(f"created paths container {paths}")

    # read config
    config = Config.from_json(paths.config_path)
    logger.info("Config:\n {}".format(pformat(vars(config), width=1)))
    output_config_path = os.path.join(paths.output_dir, "used_config.json")
    execute_command("cp {} {}".format(paths.config_path, output_config_path))
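
    # Snapshotting the config next to the run outputs ("used_config.json") makes each run
    # self-describing. Note the copy is issued as a shell `cp`, which assumes a POSIX-like environment.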

    # train_ds, val_ds
    train_ds, val_ds = load_libsvm_dataset(
        input_path=config.data.path,
        slate_length=config.data.slate_length,
        validation_ds_role=config.data.validation_ds_role,
    )

    n_features = train_ds.shape[-1]
    assert n_features == val_ds.shape[-1], "Last dimensions of train_ds and val_ds do not match!"

    # train_dl, val_dl
    train_dl, val_dl = create_data_loaders(
        train_ds, val_ds, num_workers=config.data.num_workers, batch_size=config.data.batch_size)
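
    # load_libsvm_dataset returns slate-wise datasets (one list of documents per query), with
    # config.data.slate_length controlling how many documents each list holds; the loaders then
    # batch whole slates, so every batch element is a single ranked list.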

    # gpu support
    dev = get_torch_device()
    logger.info("Model training will execute on {}".format(dev.type))

    # instantiate model
    model = make_model(n_features=n_features, **asdict(config.model, recurse=False))
    if torch.cuda.device_count() > 1:
        model = CustomDataParallel(model)
        logger.info("Model training will be distributed to {} GPUs.".format(torch.cuda.device_count()))
    model.to(dev)
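
    # With more than one visible CUDA device the model is wrapped in CustomDataParallel
    # (from allrank.models.model_utils) so batches are split across GPUs; otherwise it runs
    # on the single device returned by get_torch_device().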

    # load optimizer, loss and LR scheduler
    optimizer = getattr(optim, config.optimizer.name)(params=model.parameters(), **config.optimizer.args)
    loss_func = partial(getattr(losses, config.loss.name), **config.loss.args)
    if config.lr_scheduler.name:
        scheduler = getattr(optim.lr_scheduler, config.lr_scheduler.name)(optimizer, **config.lr_scheduler.args)
    else:
        scheduler = None
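
    # The optimizer, loss and optional LR scheduler are resolved by name from the config, e.g.
    # optimizer.name == "Adam" with optimizer.args == {"lr": 0.001} (illustrative values, not
    # defaults) yields optim.Adam(model.parameters(), lr=0.001); the loss is the matching function
    # from allrank.models.losses, partially applied with its configured arguments.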

    with torch.autograd.detect_anomaly() if config.detect_anomaly else dummy_context_mgr(): # type: ignore
        # run training
        result = fit(
            model=model,
            loss_func=loss_func,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dl=train_dl,
            valid_dl=val_dl,
            config=config,
            device=dev,
            output_dir=paths.output_dir,
            tensorboard_output_path=paths.tensorboard_output_path,
            **asdict(config.training)
        )
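
    # fit() also receives the training section of the config expanded into keyword arguments via
    # asdict(config.training), and its result is what gets persisted and checked against
    # config.expected_metrics below. torch.autograd.detect_anomaly() wraps training in extra
    # NaN/backward-graph checks for debugging (at a runtime cost) and is enabled only when
    # config.detect_anomaly is set; dummy_context_mgr is a no-op stand-in otherwise.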

    dump_experiment_result(args, config, paths.output_dir, result)

    if urlparse(args.job_dir).scheme == "gs":
        copy_local_to_gs(paths.local_base_output_path, args.job_dir)

    assert_expected_metrics(result, config.expected_metrics)


if __name__ == "__main__":
    run()