forked from speechbrain/speechbrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_pretrainer.py
More file actions
22 lines (19 loc) · 843 Bytes
/
Copy pathtest_pretrainer.py
File metadata and controls
22 lines (19 loc) · 843 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def test_pretrainer(tmpdir, device):
    """Check that ``Pretrainer`` loads saved parameters into a fresh model.

    The state dict of one ``Linear`` layer is written to
    ``tmpdir/original/model.ckpt``. A second, independently initialised
    ``Linear`` layer is registered with a ``Pretrainer`` under the key
    ``"model"``; after collecting from the saved directory and loading, both
    layers must hold identical ``weight`` and ``bias`` tensors.

    Parameters
    ----------
    tmpdir
        Directory-like object supporting ``/`` and ``mkdir()``
        (presumably the pytest ``tmpdir`` fixture -- confirm in conftest).
    device
        Device both models are moved to via ``.to(device)``.
    """
    import torch
    from torch.nn import Linear

    from speechbrain.utils.parameter_transfer import Pretrainer

    # Save a model in tmpdir/original/model.ckpt.
    # NOTE(review): the file name presumably has to match the loadable key
    # "model" used below -- confirm against Pretrainer's defaults.
    first_model = Linear(32, 32).to(device)
    pretrained_dir = tmpdir / "original"
    pretrained_dir.mkdir()
    with open(pretrained_dir / "model.ckpt", "wb") as fo:
        torch.save(first_model.state_dict(), fo)

    # Make a new model and Pretrainer.
    pretrained_model = Linear(32, 32).to(device)

    # Sanity check: the two layers are initialised independently, so their
    # parameters must differ before loading; otherwise the final assertions
    # would pass trivially.
    assert not torch.equal(pretrained_model.weight, first_model.weight)
    assert not torch.equal(pretrained_model.bias, first_model.bias)

    pt = Pretrainer(
        collect_in=tmpdir / "reused", loadables={"model": pretrained_model}
    )
    pt.collect_files(default_source=pretrained_dir)
    pt.load_collected()

    # Every parameter stored in the checkpoint must have been transferred,
    # not only the weight matrix.
    assert torch.equal(pretrained_model.weight, first_model.weight)
    assert torch.equal(pretrained_model.bias, first_model.bias)