forked from speechbrain/speechbrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcore.py
More file actions
1367 lines (1202 loc) · 50.2 KB
/
core.py
File metadata and controls
1367 lines (1202 loc) · 50.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Core SpeechBrain code for running experiments.
Authors
* Peter Plantinga 2020
* Abdel Heba 2020
* Mirco Ravanelli 2020
* Aku Rouhe 2021
* Andreas Nautsch 2022
"""
import os
import sys
import yaml
import time
import torch
import shutil
import logging
import inspect
import pathlib
import argparse
import tempfile
import warnings
from contextlib import contextmanager
import speechbrain as sb
from datetime import date
from enum import Enum, auto
from tqdm.contrib import tqdm
from types import SimpleNamespace
from torch.nn import SyncBatchNorm
from torch.utils.data import DataLoader
from torch.nn import DataParallel as DP
from torch.utils.data import IterableDataset
from torch.utils.data import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from hyperpyyaml import resolve_references
from speechbrain.utils.distributed import run_on_main
from speechbrain.dataio.dataloader import LoopedLoader
from speechbrain.dataio.dataloader import SaveableDataLoader
from speechbrain.dataio.sampler import DistributedSamplerWrapper
from speechbrain.dataio.sampler import ReproducibleRandomSampler
# Module-level logger. Logging itself is configured through
# ``sb.utils.logger.setup_logging`` inside ``create_experiment_directory``.
logger = logging.getLogger(__name__)

# Default logging configuration: "log-config.yaml" next to this source file.
DEFAULT_LOG_CONFIG = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "log-config.yaml"
)

# Switch off TorchScript's profiling executor and profiling mode, process-wide,
# as a side effect of importing this module.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)

# NOTE(review): name flag presumably used to tag intra-epoch checkpoints
# (its users are outside this excerpt) — confirm.
INTRA_EPOCH_CKPT_FLAG = "brain_intra_epoch_ckpt"

# Recommended minimum Python version; Brain.__init__ logs a warning when the
# running interpreter does not satisfy it.
PYTHON_VERSION_MAJOR = 3
PYTHON_VERSION_MINOR = 7
def create_experiment_directory(
    experiment_directory,
    hyperparams_to_save=None,
    overrides=None,
    log_config=DEFAULT_LOG_CONFIG,
    save_env_desc=True,
):
    """Create the output folder and relevant experimental files.

    Only the main process writes anything; every process then waits at a
    DDP barrier (also reached if an exception is raised).

    Arguments
    ---------
    experiment_directory : str
        The place where the experiment directory should be created.
    hyperparams_to_save : str
        A filename of a yaml file representing the parameters for this
        experiment. If passed, references are resolved, and the result is
        written to a file in the experiment directory called "hyperparams.yaml".
    overrides : dict, None
        A mapping of replacements made in the yaml file, to save in yaml.
        ``None`` (the default) means no overrides.
    log_config : str
        A yaml filename containing configuration options for the logger.
    save_env_desc : bool
        If True, an environment state description is saved to the experiment
        directory, in a file called env.log in the experiment directory.
    """
    # Avoid a shared mutable default argument.
    if overrides is None:
        overrides = {}
    try:
        # all writing command must be done with the main_process
        if sb.utils.distributed.if_main_process():
            # exist_ok avoids the check-then-create race of isdir()+makedirs()
            os.makedirs(experiment_directory, exist_ok=True)

            # Write the parameters file
            if hyperparams_to_save is not None:
                hyperparams_filename = os.path.join(
                    experiment_directory, "hyperparams.yaml"
                )
                with open(hyperparams_to_save) as f:
                    resolved_yaml = resolve_references(f, overrides)
                with open(hyperparams_filename, "w") as w:
                    print("# Generated %s from:" % date.today(), file=w)
                    print("# %s" % os.path.abspath(hyperparams_to_save), file=w)
                    print("# yamllint disable", file=w)
                    shutil.copyfileobj(resolved_yaml, w)

            # Copy executing file to output directory. This must stay in this
            # function so that f_back is the caller's frame. Modules without
            # a __file__ (e.g. an interactive __main__) are skipped.
            module = inspect.getmodule(inspect.currentframe().f_back)
            calling_path = getattr(module, "__file__", None)
            if calling_path is not None:
                callingfile = os.path.realpath(calling_path)
                shutil.copy(callingfile, experiment_directory)

            # Log exceptions to output automatically
            log_file = os.path.join(experiment_directory, "log.txt")
            logger_overrides = {
                "handlers": {"file_handler": {"filename": log_file}}
            }
            sb.utils.logger.setup_logging(log_config, logger_overrides)
            sys.excepthook = _logging_excepthook

            # Log beginning of experiment!
            logger.info("Beginning experiment!")
            logger.info(f"Experiment folder: {experiment_directory}")

            # Save system description:
            if save_env_desc:
                description_str = sb.utils.logger.get_environment_description()
                with open(
                    os.path.join(experiment_directory, "env.log"), "w"
                ) as fo:
                    fo.write(description_str)
    finally:
        # wait for main_process if ddp is used
        sb.utils.distributed.ddp_barrier()
def _logging_excepthook(exc_type, exc_value, exc_traceback):
    """Send an uncaught exception, with its traceback, to the module logger.

    ``create_experiment_directory`` installs this as ``sys.excepthook`` right
    after configuring the experiment's log file.
    """
    exc_info = (exc_type, exc_value, exc_traceback)
    logger.error("Exception:", exc_info=exc_info)
def parse_arguments(arg_list=None):
    """Parse command-line arguments to the experiment.

    Arguments
    ---------
    arg_list : list, None
        A list of arguments to parse. If not given, this is read from
        `sys.argv[1:]`

    Returns
    -------
    param_file : str
        The location of the parameters file.
    run_opts : dict
        Run options, such as distributed, device, etc. Options that were not
        passed (i.e. parsed as ``None``) are omitted.
    overrides : str
        YAML-formatted overrides (unknown ``--key value`` arguments) to pass
        to ``load_hyperpyyaml``.

    Raises
    ------
    ValueError
        If ``--data_parallel_backend`` is given but no GPU is visible.

    Example
    -------
    >>> argv = ['hyperparams.yaml', '--device', 'cuda:1', '--seed', '10']
    >>> filename, run_opts, overrides = parse_arguments(argv)
    >>> filename
    'hyperparams.yaml'
    >>> run_opts["device"]
    'cuda:1'
    >>> overrides
    'seed: 10'
    """
    if arg_list is None:
        arg_list = sys.argv[1:]
    parser = argparse.ArgumentParser(description="Run a SpeechBrain experiment")
    parser.add_argument(
        "param_file",
        type=str,
        help="A yaml-formatted file using the extended YAML syntax "
        "defined by SpeechBrain.",
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="Run the experiment with only a few batches for all "
        "datasets, to ensure code runs without crashing.",
    )
    parser.add_argument(
        "--debug_batches",
        type=int,
        default=2,
        help="Number of batches to run in debug mode.",
    )
    parser.add_argument(
        "--debug_epochs",
        type=int,
        default=2,
        help="Number of epochs to run in debug mode. "
        "If a non-positive number is passed, all epochs are run.",
    )
    parser.add_argument(
        "--log_config",
        type=str,
        help="A file storing the configuration options for logging",
    )
    # if use_env = False in torch.distributed.launch then local_rank arg is given
    parser.add_argument("--local_rank", type=int, help="Rank on local machine")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="The device to run the experiment on (e.g. 'cuda:0')",
    )
    parser.add_argument(
        "--data_parallel_backend",
        default=False,
        action="store_true",
        help="This flag enables training with data_parallel.",
    )
    parser.add_argument(
        "--distributed_launch",
        default=False,
        action="store_true",
        help="This flag enables training with DDP. Assumes script run with "
        "`torch.distributed.launch`",
    )
    parser.add_argument(
        "--distributed_backend",
        type=str,
        default="nccl",
        help="One of {nccl, gloo, mpi}",
    )
    parser.add_argument(
        "--find_unused_parameters",
        default=False,
        action="store_true",
        help="This flag enables unused parameters detection",
    )
    parser.add_argument(
        "--jit_module_keys",
        type=str,
        nargs="*",
        help="A list of keys in the 'modules' dict to jitify",
    )
    parser.add_argument(
        "--auto_mix_prec",
        default=None,
        action="store_true",
        help="This flag enables training with automatic mixed-precision.",
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        help="Gradient norm will be clipped to this value, "
        "enter negative value to disable.",
    )
    parser.add_argument(
        "--nonfinite_patience",
        type=int,
        help="Max number of batches per epoch to skip if loss is nonfinite.",
    )
    parser.add_argument(
        "--noprogressbar",
        default=None,
        action="store_true",
        help="This flag disables the data loop progressbars.",
    )
    parser.add_argument(
        "--ckpt_interval_minutes",
        type=float,
        help="Amount of time between saving intra-epoch checkpoints "
        "in minutes. If non-positive, intra-epoch checkpoints are not saved.",
    )
    parser.add_argument(
        "--grad_accumulation_factor",
        type=int,
        help="Number of batches to accumulate gradients before optimizer step",
    )
    parser.add_argument(
        "--optimizer_step_limit",
        type=int,
        help="Number of optimizer steps to run. If not passed, all epochs are run.",
    )

    # Accept extra args to override yaml
    run_opts, overrides = parser.parse_known_args(arg_list)

    # Ignore items that are "None", they were not passed
    run_opts = {k: v for k, v in vars(run_opts).items() if v is not None}

    param_file = run_opts["param_file"]
    del run_opts["param_file"]

    overrides = _convert_to_yaml(overrides)

    # Checking that DataParallel use the right number of GPU
    if run_opts["data_parallel_backend"]:
        if torch.cuda.device_count() == 0:
            raise ValueError("You must have at least 1 GPU.")

    # For DDP, the device must match the local_rank used by
    # torch.distributed.launch: prefer the --local_rank argument, otherwise
    # fall back to a non-empty LOCAL_RANK environment variable.
    local_rank = run_opts.get("local_rank")
    if local_rank is None:
        env_local_rank = os.environ.get("LOCAL_RANK", "")
        if env_local_rank != "":
            local_rank = int(env_local_rank)

    # Force the device to the local rank. Rebuild the whole string instead of
    # swapping the last character, so that a bare "cuda" (previously turned
    # into "cud0") and multi-digit indices such as "cuda:10" are handled.
    if local_rank is not None and "cuda" in run_opts["device"]:
        run_opts["device"] = "cuda:" + str(local_rank)

    return param_file, run_opts, overrides
def _convert_to_yaml(overrides):
"""Convert args to yaml for overrides"""
yaml_string = ""
# Handle '--arg=val' type args
joined_args = "=".join(overrides)
split_args = joined_args.split("=")
for arg in split_args:
if arg.startswith("--"):
yaml_string += "\n" + arg[len("--") :] + ":"
else:
yaml_string += " " + arg
return yaml_string.strip()
class Stage(Enum):
    """Phase of an experiment: training, validation or testing."""

    # Explicit values equal to what ``auto()`` assigns (starting at 1).
    TRAIN = 1
    VALID = 2
    TEST = 3
@sb.utils.checkpoints.register_checkpoint_hooks
class Brain:
"""Brain class abstracts away the details of data loops.
The primary purpose of the `Brain` class is the implementation of
the ``fit()`` method, which iterates epochs and datasets for the
purpose of "fitting" a set of modules to a set of data.
In order to use the ``fit()`` method, one should sub-class the ``Brain``
class and override any methods for which the default behavior does not
match the use case. For a simple use case (e.g., training a single model
with a single dataset) the only methods that need to be overridden are:
* ``compute_forward()``
* ``compute_objectives()``
The example below illustrates how overriding these two methods is done.
For more complicated use cases, such as multiple modules that need to
be updated, the following methods can be overridden:
* ``fit_batch()``
* ``evaluate_batch()``
Arguments
---------
modules : dict of str:torch.nn.Module pairs
These modules are passed to the optimizer by default if they have
trainable parameters, and will have ``train()``/``eval()`` called on them.
opt_class : torch.optim class
A torch optimizer constructor that has takes only the list of
parameters (e.g. a lambda or partial function definition). By default,
this will be passed all modules in ``modules`` at the
beginning of the ``fit()`` method. This behavior can be changed
by overriding the ``configure_optimizers()`` method.
hparams : dict
Each key:value pair should consist of a string key and a hyperparameter
that is used within the overridden methods. These will
be accessible via an ``hparams`` attribute, using "dot" notation:
e.g., self.hparams.model(x).
run_opts : dict
A set of options to change the runtime environment, including
debug (bool)
If ``True``, this will only iterate a few batches for all
datasets, to ensure code runs without crashing.
debug_batches (int)
Number of batches to run in debug mode, Default ``2``.
debug_epochs (int)
Number of epochs to run in debug mode, Default ``2``.
If a non-positive number is passed, all epochs are run.
jit_module_keys (list of str)
List of keys in ``modules`` that should be jit compiled.
distributed_backend (str)
One of ``nccl``, ``gloo``, ``mpi``.
device (str)
The location for performing computations.
auto_mix_prec (bool)
If ``True``, automatic mixed-precision is used.
Activate it only with cuda.
max_grad_norm (float)
Default implementation of ``fit_batch()`` uses
``clip_grad_norm_`` with this value. Default: ``5``.
nonfinite_patience (int)
Number of times to ignore non-finite losses before stopping.
Default: ``3``.
noprogressbar (bool)
Whether to turn off progressbar when training. Default: ``False``.
ckpt_interval_minutes (float)
Amount of time between saving intra-epoch checkpoints,
in minutes, default: ``15.0``. If non-positive, these are not saved.
Typically in a script this comes from ``speechbrain.parse_args``, which
has different defaults than Brain. If an option is not defined here
(keep in mind that parse_args will inject some options by default),
then the option is also searched for in hparams (by key).
checkpointer : speechbrain.Checkpointer
By default, this will be used to load checkpoints, and will have the
optimizer added to continue training if interrupted.
profiler : torch.profiler.profile
Context manager for profiling and benchmarking of training/inference steps.
Default: ``None`` (skip profiling).
Example
-------
>>> from torch.optim import SGD
>>> class SimpleBrain(Brain):
... def compute_forward(self, batch, stage):
... return self.modules.model(batch[0])
... def compute_objectives(self, predictions, batch, stage):
... return torch.nn.functional.l1_loss(predictions, batch[0])
>>> model = torch.nn.Linear(in_features=10, out_features=10)
>>> brain = SimpleBrain({"model": model}, opt_class=lambda x: SGD(x, 0.1))
>>> brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],))
"""
def __init__(  # noqa: C901
    self,
    modules=None,
    opt_class=None,
    hparams=None,
    run_opts=None,
    checkpointer=None,
    profiler=None,
):
    """Resolve run options and prepare modules, device and checkpointing.

    Every option listed in ``run_opt_defaults`` becomes an attribute of
    ``self``. Its value comes from ``run_opts`` if present there, otherwise
    from ``hparams``, otherwise from the default. See the class docstring
    for the meaning of each argument.
    """
    self.opt_class = opt_class
    self.checkpointer = checkpointer
    self.profiler = profiler

    # Arguments passed via the run opts dictionary
    run_opt_defaults = {
        "debug": False,
        "debug_batches": 2,
        "debug_epochs": 2,
        "device": "cpu",
        "data_parallel_backend": False,
        "distributed_launch": False,
        "distributed_backend": "nccl",
        "find_unused_parameters": False,
        "jit_module_keys": None,
        "auto_mix_prec": False,
        "max_grad_norm": 5.0,
        "nonfinite_patience": 3,
        "noprogressbar": False,
        "ckpt_interval_minutes": 0,
        "grad_accumulation_factor": 1,
        "optimizer_step_limit": None,
    }

    for arg, default in run_opt_defaults.items():
        if run_opts is not None and arg in run_opts:
            if hparams is not None and arg in hparams:
                logger.info(
                    "Info: "
                    + arg
                    + " arg overridden by command line input to: "
                    + str(run_opts[arg])
                )
            setattr(self, arg, run_opts[arg])
        else:
            # If any arg from run_opt_defaults exist in hparams and
            # not in command line args "run_opts"
            if hparams is not None and arg in hparams:
                logger.info(
                    "Info: " + arg + " arg from hparam file is used"
                )
                setattr(self, arg, hparams[arg])
            else:
                setattr(self, arg, default)

    # Check Python version
    if not (
        sys.version_info.major == PYTHON_VERSION_MAJOR
        and sys.version_info.minor >= PYTHON_VERSION_MINOR
    ):
        logger.warning(
            "Detected Python "
            + str(sys.version_info.major)
            + "."
            + str(sys.version_info.minor)
            + ". We suggest using SpeechBrain with Python >="
            + str(PYTHON_VERSION_MAJOR)
            + "."
            + str(PYTHON_VERSION_MINOR)
        )

    # The two parallel backends are mutually exclusive. Both flags are
    # argparse store_true flags, so they take no "=True" value.
    if self.data_parallel_backend and self.distributed_launch:
        sys.exit(
            "To use data_parallel backend, start your script with:\n\t"
            "python experiment.py hyperparams.yaml "
            "--data_parallel_backend\n"
            "To use DDP backend, start your script with:\n\t"
            "python -m torch.distributed.launch [args]\n"
            "experiment.py hyperparams.yaml --distributed_launch "
            "--distributed_backend=nccl"
        )

    # Switch to the right context. torch.device parses the index, so
    # multi-digit indices such as "cuda:10" work; a bare "cuda" maps to 0.
    if "cuda" in self.device:
        device_index = torch.device(self.device).index
        torch.cuda.set_device(0 if device_index is None else device_index)

    # Put modules on the right device, accessible with dot notation
    self.modules = torch.nn.ModuleDict(modules).to(self.device)

    # Make hyperparams available with dot notation too
    if hparams is not None:
        self.hparams = SimpleNamespace(**hparams)

    # Checkpointer should point at a temporary directory in debug mode
    if (
        self.debug
        and self.checkpointer is not None
        and hasattr(self.checkpointer, "checkpoints_dir")
    ):
        tempdir = tempfile.TemporaryDirectory()
        logger.info(
            "Since debug mode is active, switching checkpointer "
            f"output to temporary directory: {tempdir.name}"
        )
        self.checkpointer.checkpoints_dir = pathlib.Path(tempdir.name)

        # Keep reference to tempdir as long as checkpointer exists
        self.checkpointer.tempdir = tempdir

    # Sampler should be handled by `make_dataloader`
    # or if you provide a DataLoader directly, you can set
    # this.train_sampler = your_sampler
    # to have your_sampler.set_epoch() called on each epoch.
    self.train_sampler = None

    # Automatic mixed precision init
    if self.auto_mix_prec:
        self.scaler = torch.cuda.amp.GradScaler()
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("scaler", self.scaler)

    # List parameter count for the user
    total_params = sum(
        p.numel() for p in self.modules.parameters() if p.requires_grad
    )
    if total_params > 0:
        clsname = self.__class__.__name__
        fmt_num = sb.utils.logger.format_order_of_magnitude(total_params)
        logger.info(f"{fmt_num} trainable parameters in {clsname}")

    if self.distributed_launch:
        self.rank = int(os.environ["RANK"])
        if not torch.distributed.is_initialized():
            if self.rank > 0:
                sys.exit(
                    " ================ WARNING ===============\n"
                    "Please add sb.utils.distributed.ddp_init_group() "
                    "into your exp.py\n"
                    "To use DDP backend, start your script with:\n\t"
                    "python -m torch.distributed.launch [args]\n\t"
                    "experiment.py hyperparams.yaml "
                    "--distributed_launch --distributed_backend=nccl"
                )
            else:
                logger.warning(
                    "To use DDP, please add "
                    "sb.utils.distributed.ddp_init_group() into your exp.py"
                )
                logger.info(
                    "Only the main process is alive, "
                    "all other subprocess were killed."
                )

    # Prepare iterating variables
    self.avg_train_loss = 0.0
    self.step = 0
    self.optimizer_step = 0
    # Also reset at the start of every training epoch; initialised here so
    # check_gradients() works when fit_batch() is driven by a custom loop.
    self.nonfinite_count = 0

    # Add this class to the checkpointer for intra-epoch checkpoints
    if self.checkpointer is not None:
        self.checkpointer.add_recoverable("brain", self)
def compute_forward(self, batch, stage):
    """Run the forward computation; every sub-class must override this.

    Arguments
    ---------
    batch : torch.Tensor or tensors
        One element yielded by the dataloader, including the inputs.
    stage : Stage
        Which phase is running: Stage.TRAIN, Stage.VALID or Stage.TEST.

    Returns
    -------
    torch.Tensor or Tensors
        The processed outputs, handed unchanged to ``compute_objectives()``.

    Raises
    ------
    NotImplementedError
        Always; the base class provides no forward pass.
    """
    raise NotImplementedError()
def compute_objectives(self, predictions, batch, stage):
    """Compute the loss; every sub-class must override this.

    Arguments
    ---------
    predictions : torch.Tensor or Tensors
        What ``compute_forward()`` returned for this batch.
    batch : torch.Tensor or tensors
        The same dataloader element, including the comparison targets.
    stage : Stage
        Which phase is running: Stage.TRAIN, Stage.VALID or Stage.TEST.

    Returns
    -------
    loss : torch.Tensor
        The computed loss.

    Raises
    ------
    NotImplementedError
        Always; the base class provides no objective.
    """
    raise NotImplementedError()
def on_stage_start(self, stage, epoch=None):
    """Hook run when a stage begins; the default does nothing.

    Override it to set up per-stage state (e.g. metric trackers).

    Arguments
    ---------
    stage : Stage
        Which phase is starting: Stage.TRAIN, Stage.VALID or Stage.TEST.
    epoch : int
        The current epoch count.
    """
def on_stage_end(self, stage, stage_loss, epoch=None):
    """Hook run when a stage finishes; the default does nothing.

    Override it to compute stage statistics, save checkpoints, and so on.

    Arguments
    ---------
    stage : Stage
        Which phase just finished: Stage.TRAIN, Stage.VALID or Stage.TEST.
    stage_loss : float
        Mean loss over the completed stage.
    epoch : int
        The current epoch count.
    """
def make_dataloader(
    self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs
):
    """Build a DataLoader for ``dataset``.

    ``fit()`` and ``evaluate()`` call this when they receive plain Datasets.
    It may also be called from outside the Brain subclass, with the result
    passed to ``fit()`` in place of the dataset.

    For Stage.TRAIN, ``loader_kwargs`` first goes through
    ``_train_loader_specifics``, which handles ``shuffle``/``drop_last`` and,
    under DDP, distributed sampling (not for IterableDatasets).

    NOTE
    ----
    Key DataLoader settings such as batch_size, num_workers and pin_memory
    travel through ``**loader_kwargs``.

    NOTE
    ----
    ``evaluate()`` passes ckpt_prefix=None by default so the test DataLoader
    is not registered with the checkpointer. If you add a recoverable after
    checkpoints have already been saved (e.g. at test time, after training
    checkpoints exist) and still want reasonable recovery, consider
    ``allow_partial_load=True``.

    Arguments
    ---------
    dataset : Dataset
        The data to wrap. For a DynamicItemDataset, PaddedBatch is the
        default collate_fn unless loader_kwargs specifies one.
    stage : Stage
        Which phase the loader is for: Stage.TRAIN, Stage.VALID, Stage.TEST.
    ckpt_prefix : str, None
        Prefix of the checkpoint key for a saveable loader; the stage name is
        appended to it. ``None`` disables saving the DataLoader.
    **loader_kwargs : dict
        Extra keyword arguments for the DataLoader
        (e.g. batch_size, num_workers, pin_memory).

    Returns
    -------
    DataLoader
        The loader built by ``sb.dataio.dataloader.make_dataloader``.
    """
    if stage == sb.Stage.TRAIN:
        loader_kwargs = self._train_loader_specifics(dataset, loader_kwargs)

    loader = sb.dataio.dataloader.make_dataloader(dataset, **loader_kwargs)

    # Only loaders that can save/restore their position are checkpointed.
    can_checkpoint = (
        self.checkpointer is not None
        and ckpt_prefix is not None
        and isinstance(loader, (SaveableDataLoader, LoopedLoader))
    )
    if can_checkpoint:
        self.checkpointer.add_recoverable(ckpt_prefix + stage.name, loader)
    return loader
def _train_loader_specifics(self, dataset, loader_kwargs):
    """Adjust DataLoader keyword arguments for the TRAIN stage.

    Outside DDP, ``shuffle=True`` is replaced by a ``ReproducibleRandomSampler``
    that is also stored as ``self.train_sampler``. Under DDP
    (``self.distributed_launch``) with a map-style dataset, the sampler or
    batch sampler is wrapped, or a ``DistributedSampler`` is created, and
    stored as ``self.train_sampler``. IterableDatasets under DDP are left
    untouched, with a warning.

    Arguments
    ---------
    dataset : Dataset
        The training dataset.
    loader_kwargs : dict
        Keyword arguments destined for the DataLoader; modified in place.

    Returns
    -------
    dict
        The (modified) ``loader_kwargs``.

    Raises
    ------
    ValueError
        If both ``shuffle=True`` and a ``sampler`` are given outside DDP.
    """
    sampler = loader_kwargs.get("sampler", None)
    # Shuffling should really only matter for the train stage. Shuffling
    # will also lead to more padding in batches if the order was otherwise
    # sorted by length.
    shuffle = loader_kwargs.get("shuffle", False)
    if shuffle and not self.distributed_launch:
        if sampler is not None:
            raise ValueError(
                "Cannot specify both shuffle=True "
                "and a sampler in loader_kwargs"
            )
        sampler = ReproducibleRandomSampler(dataset)
        self.train_sampler = sampler
        loader_kwargs["sampler"] = self.train_sampler
        # Delete the shuffle flag, since you cannot specify both a sampler and
        # shuffling:
        del loader_kwargs["shuffle"]

    # Possibly make a DistributedSampler or a wrapper for some other sampler
    if self.distributed_launch and not isinstance(dataset, IterableDataset):
        drop_last = loader_kwargs.get("drop_last", False)
        # num_replicas arg is equal to world_size
        # and retrieved automatically within
        # DistributedSampler obj.
        if sampler is not None:
            self.train_sampler = DistributedSamplerWrapper(
                sampler,
                rank=self.rank,
                drop_last=drop_last,
                shuffle=shuffle,
            )
            # with DistributedSamplerWrapper, one must disable shuffling for dataloader
            loader_kwargs["shuffle"] = False
            loader_kwargs["sampler"] = self.train_sampler
        elif loader_kwargs.get("batch_sampler") is None:
            # no sampler and batch-sampler
            # NOTE(review): shuffles regardless of the caller's ``shuffle``.
            self.train_sampler = DistributedSampler(
                dataset, rank=self.rank, shuffle=True, drop_last=drop_last
            )
            # with DistributedSampler, one must disable shuffling for dataloader
            loader_kwargs["shuffle"] = False
            loader_kwargs["sampler"] = self.train_sampler
        else:  # batch_sampler was specified
            self.train_sampler = DistributedSamplerWrapper(
                loader_kwargs.get("batch_sampler", None),
                rank=self.rank,
                shuffle=True,
            )
            loader_kwargs["batch_sampler"] = self.train_sampler
    elif self.distributed_launch and isinstance(dataset, IterableDataset):
        logger.warning(
            "Cannot automatically solve distributed sampling "
            "for IterableDataset."
        )
    return loader_kwargs
def on_fit_start(self):
    """Prepare for training at the beginning of ``fit()``.

    By default this compiles the jit modules, wraps modules for the parallel
    backend, creates the optimizer(s) and, if a checkpointer is set, resumes
    from the most recent checkpoint.
    """
    # Order matters: jit modules cannot be pickled, so compile only after all
    # processes exist; parallel wrapping comes after jit; optimizers must be
    # built on the final, fully configured parameters.
    self._compile_jit()
    self._wrap_distributed()
    self.init_optimizers()

    if self.checkpointer is None:
        return
    # Resume an interrupted run when a checkpoint is available.
    self.checkpointer.recover_if_possible(device=torch.device(self.device))
def init_optimizers(self):
    """Create the optimizer once parameters are final (after DDP/jit).

    Called from ``on_fit_start()``. The default builds one optimizer over all
    module parameters by calling ``self.opt_class`` with them, which is why
    ``opt_class`` must accept only a parameter list (e.g. a lambda or
    partial). When a checkpointer exists, the optimizer is registered with it.
    Override this method to use several optimizers.
    """
    if self.opt_class is None:
        return
    self.optimizer = self.opt_class(self.modules.parameters())
    if self.checkpointer is not None:
        self.checkpointer.add_recoverable("optimizer", self.optimizer)
def on_evaluate_start(self, max_key=None, min_key=None):
    """Hook run at the beginning of ``evaluate()``.

    By default it restores the best checkpoint according to the stored
    metrics, if a checkpointer is configured.

    Arguments
    ---------
    max_key : str
        Metric key for which a higher value is better; forwarded to
        ``self.checkpointer.recover_if_possible()``.
    min_key : str
        Metric key for which a lower value is better; forwarded to
        ``self.checkpointer.recover_if_possible()``.
    """
    if self.checkpointer is None:
        return
    self.checkpointer.recover_if_possible(
        max_key=max_key,
        min_key=min_key,
        device=torch.device(self.device),
    )
def fit_batch(self, batch):
    """Fit one batch, override to do multiple updates.

    The default implementation depends on a few methods being defined
    with a particular behavior:

    * ``compute_forward()``
    * ``compute_objectives()``

    Also depends on having optimizers passed at initialization.

    Gradients are accumulated over ``grad_accumulation_factor`` batches. The
    optimizer only steps (and gradients are only cleared) on batches where
    ``self.step`` is a multiple of that factor.

    Arguments
    ---------
    batch : list of torch.Tensors
        Batch of data to use for training. Default implementation assumes
        this batch has two elements: inputs and targets.

    Returns
    -------
    detached loss
    """
    should_step = self.step % self.grad_accumulation_factor == 0

    # Managing automatic mixed precision
    if self.auto_mix_prec:
        with torch.cuda.amp.autocast():
            outputs = self.compute_forward(batch, Stage.TRAIN)
            loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
        with self.no_sync(not should_step):
            self.scaler.scale(
                loss / self.grad_accumulation_factor
            ).backward()
        if should_step:
            self.scaler.unscale_(self.optimizer)
            if self.check_gradients(loss):
                self.scaler.step(self.optimizer)
            # update() must follow every unscale_(), even when the step is
            # skipped; otherwise the next unscale_() raises a RuntimeError.
            self.scaler.update()
            # Clear gradients only after a step (not at the start of every
            # batch) so they accumulate across grad_accumulation_factor.
            self.optimizer.zero_grad()
            self.optimizer_step += 1
    else:
        outputs = self.compute_forward(batch, Stage.TRAIN)
        loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
        with self.no_sync(not should_step):
            (loss / self.grad_accumulation_factor).backward()
        if should_step:
            if self.check_gradients(loss):
                self.optimizer.step()
            self.optimizer.zero_grad()
            self.optimizer_step += 1

    self.on_fit_batch_end(batch, outputs, loss, should_step)
    return loss.detach().cpu()
def on_fit_batch_end(self, batch, outputs, loss, should_step):
    """Hook run after each ``fit_batch()``; the default does nothing.

    Override it to compute or log per-batch metrics.

    Arguments
    ---------
    batch : list of torch.Tensors
        The training batch just processed (by default: inputs and targets).
    outputs : list or dictionary of torch.Tensors
        What ``compute_forward()`` returned.
    loss : torch.Tensor
        What ``compute_objectives()`` returned.
    should_step : boolean
        Whether this batch reached an optimizer-step boundary.
    """
def check_gradients(self, loss):
    """Check if gradients are finite and not too large.

    Automatically clips large gradients.

    A non-finite ``loss`` increments ``self.nonfinite_count`` (reset at the
    start of each training epoch) and skips the step. Once the count exceeds
    ``self.nonfinite_patience``, a ValueError is raised.

    Arguments
    ---------
    loss : tensor
        The loss tensor after ``backward()`` has been called but
        before the optimizers ``step()``.

    Returns
    -------
    bool
        Whether or not the optimizer step should be carried out.

    Raises
    ------
    ValueError
        When the loss is non-finite and the patience is exhausted.
    """
    if not torch.isfinite(loss):
        # getattr keeps this usable when fit_batch() is driven outside fit(),
        # where the per-epoch reset of the counter never ran.
        self.nonfinite_count = getattr(self, "nonfinite_count", 0) + 1

        # Print helpful debug info
        logger.warning(f"Loss is {loss}.")
        for p in self.modules.parameters():
            if not torch.isfinite(p).all():
                logger.warning("Parameter is not finite: " + str(p))

        # Check if patience is exhausted
        if self.nonfinite_count > self.nonfinite_patience:
            raise ValueError(
                "Loss is not finite and patience is exhausted. "
                "To debug, wrap `fit()` with "
                "autograd's `detect_anomaly()`, e.g.\n\nwith "
                "torch.autograd.detect_anomaly():\n\tbrain.fit(...)"
            )
        else:
            logger.warning("Patience not yet exhausted, ignoring this batch.")
            return False

    # Clip gradient norm (a non-positive max_grad_norm disables clipping)
    if self.max_grad_norm > 0.0:
        torch.nn.utils.clip_grad_norm_(
            self.modules.parameters(), self.max_grad_norm
        )

    return True
def evaluate_batch(self, batch, stage):
    """Compute the loss of one batch without training on it.

    Override this for an evaluation procedure that differs from training.
    The default relies on two methods behaving as documented:

    * ``compute_forward()``
    * ``compute_objectives()``

    Arguments
    ---------
    batch : list of torch.Tensors
        The evaluation batch (by default: inputs and targets).
    stage : Stage
        Which phase is running: Stage.VALID or Stage.TEST.

    Returns
    -------
    detached loss
        The loss, detached from the graph and moved to the CPU.
    """
    predictions = self.compute_forward(batch, stage=stage)
    batch_loss = self.compute_objectives(predictions, batch, stage=stage)
    return batch_loss.detach().cpu()
def _fit_train(self, train_set, epoch, enable):
# Training stage
self.on_stage_start(Stage.TRAIN, epoch)
self.modules.train()
# Reset nonfinite count to 0 each epoch
self.nonfinite_count = 0
if self.train_sampler is not None and hasattr(
self.train_sampler, "set_epoch"
):
self.train_sampler.set_epoch(epoch)
# Time since last intra-epoch checkpoint
last_ckpt_time = time.time()
with tqdm(