# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
The Trainer will train a list of tasks and return a list of model recorders.
There are two steps in each Trainer: ``train`` (make the model recorders) and ``end_train`` (modify the model recorders).
This split enables a concept called ``DelayTrainer``, which can be used in online simulation for parallel training.
In ``DelayTrainer``, the first step only saves some necessary info to the model recorders, and the second step, which is finished at the end, can do concurrent and time-consuming operations such as model fitting.
``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest one, and ``TrainerRM`` is based on TaskManager to help manage the tasks' lifecycle automatically.
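
For example, a minimal usage sketch (assuming ``qlib.init()`` has already been called, ``task`` is a valid task config dict, and the experiment name below is illustrative):

.. code-block:: python

    from qlib.model.trainer import TrainerR

    trainer = TrainerR(experiment_name="my_exp")
    recorders = trainer.train([task])         # fit the models and create the recorders
    recorders = trainer.end_train(recorders)  # tag the recorders as finished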
"""
import socket
from typing import Callable, List
from tqdm.auto import tqdm
from qlib.data.dataset import Dataset
from qlib.model.base import Model
from qlib.utils import flatten_dict, init_instance_by_config, auto_filter_kwargs, fill_placeholder
from qlib.workflow import R
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.data.dataset.weight import Reweighter
def _log_task_info(task_config: dict):
R.log_params(**flatten_dict(task_config))
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.set_tags(**{"hostname": socket.gethostname()})
def _exe_task(task_config: dict):
rec = R.get_recorder()
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
reweighter: Reweighter = task_config.get("reweighter", None)
# model training
auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)
R.save_objects(**{"params.pkl": model})
# this dataset is saved for online inference. So the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
    # fill placeholders
    placeholder_value = {"<MODEL>": model, "<DATASET>": dataset}
    task_config = fill_placeholder(task_config, placeholder_value)
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
    if isinstance(records, dict):  # handle the case where only a single record dict is given
records = [records]
for record in records:
        # Some recorders require the parameters `model` and `dataset`.
        # Try to pass them to the initialization function automatically
        # to make defining the task easier.
r = init_instance_by_config(
record,
recorder=rec,
default_module="qlib.workflow.record_temp",
try_kwargs={"model": model, "dataset": dataset},
)
r.generate()
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
Begin task training to start a recorder and save the task config.
Args:
task_config (dict): the config of a task
experiment_name (str): the name of experiment
recorder_name (str): the given name will be the recorder name. None for using rid.
Returns:
Recorder: the model recorder
"""
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
_log_task_info(task_config)
return R.get_recorder()
def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
"""
Finish task training with real model fitting and saving.
Args:
        rec (Recorder): the recorder to be resumed
experiment_name (str): the name of experiment
Returns:
Recorder: the model recorder
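
    A sketch of the delayed two-step flow (assuming ``qlib.init()`` has been called; ``task_config`` and the experiment name are illustrative):

    .. code-block:: python

        rec = begin_task_train(task_config, "online_exp")  # step 1: only save the task info
        rec = end_task_train(rec, "online_exp")            # step 2: real model fitting and record generation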
"""
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
task_config = R.load_object("task")
_exe_task(task_config)
return rec
def task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
    Task-based training, which will be divided into two steps.
Parameters
----------
task_config : dict
The config of a task.
experiment_name: str
The name of experiment
recorder_name: str
The name of recorder
Returns
----------
Recorder: The instance of the recorder
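
    For example, a one-step sketch (``task_config`` and the experiment name are illustrative):

    .. code-block:: python

        rec = task_train(task_config, experiment_name="my_exp")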
"""
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
_log_task_info(task_config)
_exe_task(task_config)
return R.get_recorder()
class Trainer:
"""
The trainer can train a list of models.
    There are Trainer and DelayTrainer, which can be distinguished by when they finish the real training.
"""
def __init__(self):
self.delay = False
def train(self, tasks: list, *args, **kwargs) -> list:
"""
Given a list of task definitions, begin training, and return the models.
For Trainer, it finishes real training in this method.
For DelayTrainer, it only does some preparation in this method.
Args:
tasks: a list of tasks
Returns:
list: a list of models
"""
raise NotImplementedError(f"Please implement the `train` method.")
def end_train(self, models: list, *args, **kwargs) -> list:
"""
        Given a list of models, finish some work at the end of training if needed.
        The models may be Recorders, text files, databases, and so on.
For Trainer, it does some finishing touches in this method.
For DelayTrainer, it finishes real training in this method.
Args:
models: a list of models
Returns:
list: a list of models
"""
# do nothing if you finished all work in `train` method
return models
def is_delay(self) -> bool:
"""
        Whether the Trainer will delay finishing `end_train`.
Returns:
bool: if DelayTrainer
"""
return self.delay
def __call__(self, *args, **kwargs) -> list:
return self.end_train(self.train(*args, **kwargs))
def has_worker(self) -> bool:
"""
        Some trainers have a backend worker to support parallel training.
        This method can tell whether the worker is enabled.
Returns
-------
bool:
if the worker is enabled
"""
return False
def worker(self):
"""
        Start the worker.
Raises
------
NotImplementedError:
If the worker is not supported
"""
raise NotImplementedError(f"Please implement the `worker` method")
class TrainerR(Trainer):
"""
Trainer based on (R)ecorder.
It will train a list of tasks and return a list of model recorders in a linear way.
Assumption: models were defined by `task` and the results will be saved to `Recorder`.
"""
    # These tags will help you distinguish whether the Recorder has finished training
STATUS_KEY = "train_status"
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
"""
Init TrainerR.
Args:
experiment_name (str, optional): the default name of experiment.
train_func (Callable, optional): default training method. Defaults to `task_train`.
"""
super().__init__()
self.experiment_name = experiment_name
self.train_func = train_func
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
        Given a list of `task`s, return a list of trained Recorders. The order is guaranteed.
Args:
tasks (list): a list of definitions based on `task` dict
train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. None for the default training method.
            experiment_name (str): the experiment name; None for using the default name.
kwargs: the params for train_func.
Returns:
List[Recorder]: a list of Recorders
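
        For example, a sketch of a custom ``train_func`` (a thin wrapper around ``task_train``; the function name is illustrative):

        .. code-block:: python

            def my_train_func(task_config: dict, experiment_name: str, **kwargs) -> Recorder:
                # extra logging, retries, etc. could go here
                return task_train(task_config, experiment_name, **kwargs)

            recs = TrainerR(experiment_name="my_exp", train_func=my_train_func).train(tasks)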
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
recs = []
for task in tqdm(tasks, desc="train tasks"):
rec = train_func(task, experiment_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
recs.append(rec)
return recs
def end_train(self, models: list, **kwargs) -> List[Recorder]:
"""
Set STATUS_END tag to the recorders.
Args:
models (list): a list of trained recorders.
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(models, Recorder):
models = [models]
for rec in models:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return models
class DelayTrainerR(TrainerR):
"""
    A delayed implementation based on TrainerR, which means the `train` method may only do some preparation and the `end_train` method does the real model fitting.
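
    For example, a sketch of the delayed flow (assuming ``qlib.init()`` has been called and ``tasks`` is a list of task config dicts):

    .. code-block:: python

        trainer = DelayTrainerR(experiment_name="my_exp")
        recs = trainer.train(tasks)     # only saves the task info to the recorders
        recs = trainer.end_train(recs)  # the real model fitting happens here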
"""
def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train):
"""
        Init DelayTrainerR.
Args:
experiment_name (str): the default name of experiment.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
"""
super().__init__(experiment_name, train_func)
self.end_train_func = end_train_func
self.delay = True
def end_train(self, models, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
        Given a list of Recorders, return a list of trained Recorders.
        This method will finish the real data loading and model fitting.
Args:
            models (list): a list of Recorders to which the tasks have been saved
            end_train_func (Callable, optional): the end_train method, which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
            experiment_name (str): the experiment name; None for using the default name.
kwargs: the params for end_train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(models, Recorder):
models = [models]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
for rec in models:
if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END:
continue
end_train_func(rec, experiment_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return models
class TrainerRM(Trainer):
"""
Trainer based on (R)ecorder and Task(M)anager.
It can train a list of tasks and return a list of model recorders in a multiprocessing way.
    Assumption: `task` will be saved to TaskManager, and `task` will be fetched from TaskManager and trained.
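
    For example, a sketch (assuming the MongoDB used by TaskManager has been configured, e.g. through ``qlib.init``; the names are illustrative):

    .. code-block:: python

        trainer = TrainerRM(experiment_name="my_exp", task_pool="my_pool")
        recs = trainer.train(tasks)  # push the tasks to the pool and train them

        # optionally, another process or machine can consume the same pool:
        # TrainerRM(experiment_name="my_exp", task_pool="my_pool").worker()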
"""
    # These tags will help you distinguish whether the Recorder has finished training
STATUS_KEY = "train_status"
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
# This tag is the _id in TaskManager to distinguish tasks.
TM_ID = "_id in TaskManager"
def __init__(
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
):
"""
        Init TrainerRM.
Args:
experiment_name (str): the default name of experiment.
            task_pool (str): task pool name in TaskManager. None for using the same name as experiment_name.
            train_func (Callable, optional): default training method. Defaults to `task_train`.
            skip_run_task (bool):
                If True, `run_task` will only be called inside the worker; this process will not run the tasks
                itself. Otherwise, `run_task` is also called here.
"""
super().__init__()
self.experiment_name = experiment_name
self.task_pool = task_pool
self.train_func = train_func
self.skip_run_task = skip_run_task
def train(
self,
tasks: list,
train_func: Callable = None,
experiment_name: str = None,
before_status: str = TaskManager.STATUS_WAITING,
after_status: str = TaskManager.STATUS_DONE,
**kwargs,
) -> List[Recorder]:
"""
        Given a list of `task`s, return a list of trained Recorders. The order is guaranteed.
        This method defaults to a single process, but TaskManager offers a great way to parallelize training.
        Users can customize their train_func to use multiple processes or even multiple machines.
Args:
tasks (list): a list of definitions based on `task` dict
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
            experiment_name (str): the experiment name; None for using the default name.
before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
kwargs: the params for train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
query = {"_id": {"$in": _id_list}}
if not self.skip_run_task:
run_task(
train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
if not self.is_delay():
tm.wait(query=query)
recs = []
for _id in _id_list:
rec = tm.re_query(_id)["res"]
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
rec.set_tags(**{self.TM_ID: _id})
recs.append(rec)
return recs
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
"""
Set STATUS_END tag to the recorders.
Args:
recs (list): a list of trained recorders.
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(recs, Recorder):
recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
def worker(
self,
train_func: Callable = None,
experiment_name: str = None,
):
"""
        The multiprocessing method for `train`. It can share the same task_pool with `train` and can run in another process or on other machines.
Args:
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
            experiment_name (str): the experiment name; None for using the default name.
"""
if train_func is None:
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
def has_worker(self) -> bool:
return True
class DelayTrainerRM(TrainerRM):
"""
    A delayed implementation based on TrainerRM, which means the `train` method may only do some preparation and the `end_train` method does the real model fitting.
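
    For example, a sketch of splitting preparation and fitting across machines (the names are illustrative):

    .. code-block:: python

        # on the CPU VM: push the tasks and wait for the fitting to be finished elsewhere
        trainer = DelayTrainerRM(experiment_name="my_exp", task_pool="my_pool", skip_run_task=True)
        recs = trainer.train(tasks)
        recs = trainer.end_train(recs)

        # on a GPU VM: fetch the prepared tasks and do the real model fitting
        # DelayTrainerRM(experiment_name="my_exp", task_pool="my_pool", skip_run_task=True).worker()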
"""
def __init__(
self,
experiment_name: str = None,
task_pool: str = None,
train_func=begin_task_train,
end_train_func=end_task_train,
skip_run_task: bool = False,
):
"""
Init DelayTrainerRM.
Args:
experiment_name (str): the default name of experiment.
            task_pool (str): task pool name in TaskManager. None for using the same name as experiment_name.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
            skip_run_task (bool):
                If True, `run_task` will only be called inside the worker; this process will not run the tasks
                itself. Otherwise, `run_task` is also called here.
                E.g. starting the trainer on a CPU VM and then waiting for the tasks to be finished on GPU VMs.
"""
super().__init__(experiment_name, task_pool, train_func)
self.end_train_func = end_train_func
self.delay = True
self.skip_run_task = skip_run_task
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
        Same as `train` of TrainerRM, except that after_status will be STATUS_PART_DONE.
Args:
            tasks (list): a list of definitions based on the `task` dict
            train_func (Callable): the train method, which needs at least `task`s and `experiment_name`. Defaults to None for using self.train_func.
            experiment_name (str): the experiment name; None for using the default name.
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(tasks, dict):
tasks = [tasks]
if len(tasks) == 0:
return []
_skip_run_task = self.skip_run_task
self.skip_run_task = False # The task preparation can't be skipped
res = super().train(
tasks,
train_func=train_func,
experiment_name=experiment_name,
after_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
self.skip_run_task = _skip_run_task
return res
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
        Given a list of Recorders, return a list of trained Recorders.
        This method will finish the real data loading and model fitting.
Args:
            recs (list): a list of Recorders to which the tasks have been saved.
            end_train_func (Callable, optional): the end_train method, which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
            experiment_name (str): the experiment name; None for using the default name.
kwargs: the params for end_train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(recs, Recorder):
recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
_id_list = []
for rec in recs:
_id_list.append(rec.list_tags()[self.TM_ID])
query = {"_id": {"$in": _id_list}}
if not self.skip_run_task:
run_task(
end_train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
TaskManager(task_pool=task_pool).wait(query=query)
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
def worker(self, end_train_func=None, experiment_name: str = None):
"""
        The multiprocessing method for `end_train`. It can share the same task_pool with `end_train` and can run in another process or on other machines.
Args:
            end_train_func (Callable, optional): the end_train method, which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
            experiment_name (str): the experiment name; None for using the default name.
"""
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
run_task(
end_train_func,
task_pool=task_pool,
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
)
def has_worker(self) -> bool:
return True