forked from speechbrain/speechbrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathenhancement.py
More file actions
178 lines (146 loc) · 5.34 KB
/
Copy pathenhancement.py
File metadata and controls
178 lines (146 loc) · 5.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
""" Specifies the inference interfaces for speech enhancement modules.
Authors:
* Aku Rouhe 2021
* Peter Plantinga 2021
* Loren Lugosch 2020
* Mirco Ravanelli 2020
* Titouan Parcollet 2021
* Abdel Heba 2021
* Andreas Nautsch 2022, 2023
* Pooneh Mousavi 2023
* Sylvain de Langen 2023
* Adel Moumen 2023
* Pradnya Kandarkar 2023
"""
import torch
import torchaudio
from speechbrain.inference.interfaces import Pretrained
from speechbrain.utils.callchains import lengths_arg_exists
class SpectralMaskEnhancement(Pretrained):
    """A ready-to-use model for speech enhancement.

    The enhancement model predicts a mask that is multiplied with the
    log1p spectral magnitude of the noisy input; the masked magnitude is
    then resynthesized into a waveform using the noisy signal.

    Arguments
    ---------
    See ``Pretrained``.

    Example
    -------
    >>> import torch
    >>> from speechbrain.inference.enhancement import SpectralMaskEnhancement
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> enhancer = SpectralMaskEnhancement.from_hparams(
    ...     source="speechbrain/metricgan-plus-voicebank",
    ...     savedir=tmpdir,
    ... )
    >>> enhanced = enhancer.enhance_file(
    ...     "speechbrain/metricgan-plus-voicebank/example.wav"
    ... )
    """

    HPARAMS_NEEDED = ["compute_stft", "spectral_magnitude", "resynth"]
    MODULES_NEEDED = ["enhance_model"]

    def compute_features(self, wavs):
        """Compute the log spectral magnitude features for masking.

        Arguments
        ---------
        wavs : torch.Tensor
            A batch of waveforms to convert to log spectral mags.

        Returns
        -------
        torch.Tensor
            ``log1p`` of the spectral magnitude of the STFT of ``wavs``.
        """
        feats = self.hparams.compute_stft(wavs)
        feats = self.hparams.spectral_magnitude(feats)
        return torch.log1p(feats)

    def enhance_batch(self, noisy, lengths=None):
        """Enhance a batch of noisy waveforms.

        Arguments
        ---------
        noisy : torch.Tensor
            A batch of waveforms to perform enhancement on.
        lengths : torch.Tensor
            The lengths of the waveforms if the enhancement model handles them.
            When ``None``, the model is called without a ``lengths`` argument.

        Returns
        -------
        torch.Tensor
            A batch of enhanced waveforms of the same shape as input.
        """
        noisy = noisy.to(self.device)
        noisy_features = self.compute_features(noisy)

        # Perform masking-based enhancement, multiplying output with input.
        if lengths is not None:
            mask = self.mods.enhance_model(noisy_features, lengths=lengths)
        else:
            mask = self.mods.enhance_model(noisy_features)
        enhanced = torch.mul(mask, noisy_features)

        # Undo the log1p compression before resynthesizing the waveforms.
        return self.hparams.resynth(torch.expm1(enhanced), noisy)

    def enhance_file(self, filename, output_filename=None, **kwargs):
        """Enhance a wav file.

        Arguments
        ---------
        filename : str
            Location on disk to load file for enhancement.
        output_filename : str
            If provided, writes enhanced data to this file.
        **kwargs : dict
            Arguments forwarded to ``load_audio``.

        Returns
        -------
        torch.Tensor
            The enhanced waveform, with the faked batch dimension removed.
        """
        noisy = self.load_audio(filename, **kwargs)
        noisy = noisy.to(self.device)

        # Fake a batch:
        batch = noisy.unsqueeze(0)
        if lengths_arg_exists(self.enhance_batch):
            # A single item spans the full (relative) length of the batch.
            enhanced = self.enhance_batch(batch, lengths=torch.tensor([1.0]))
        else:
            enhanced = self.enhance_batch(batch)

        if output_filename is not None:
            # torchaudio.save requires the sample rate. The faked batch
            # dimension doubles as a single audio channel, so the default
            # channels-first layout is used (assumes load_audio returned a
            # 1-D mono signal -- TODO confirm for custom audio normalizers).
            # The tensor is detached and moved to CPU before writing.
            torchaudio.save(
                output_filename,
                enhanced.detach().cpu(),
                self.hparams.compute_stft.sample_rate,
            )

        return enhanced.squeeze(0)
class WaveformEnhancement(Pretrained):
    """A ready-to-use model for speech enhancement.

    The enhancement model operates directly on waveforms.

    Arguments
    ---------
    See ``Pretrained``.

    Example
    -------
    >>> from speechbrain.inference.enhancement import WaveformEnhancement
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> enhancer = WaveformEnhancement.from_hparams(
    ...     source="speechbrain/mtl-mimic-voicebank",
    ...     savedir=tmpdir,
    ... )
    >>> enhanced = enhancer.enhance_file(
    ...     "speechbrain/mtl-mimic-voicebank/example.wav"
    ... )
    """

    MODULES_NEEDED = ["enhance_model"]

    def enhance_batch(self, noisy, lengths=None):
        """Enhance a batch of noisy waveforms.

        Arguments
        ---------
        noisy : torch.Tensor
            A batch of waveforms to perform enhancement on.
        lengths : torch.Tensor
            The lengths of the waveforms if the enhancement model handles them.
            Accepted for interface compatibility; this implementation does
            not forward it to the model.

        Returns
        -------
        torch.Tensor
            A batch of enhanced waveforms of the same shape as input.
        """
        noisy = noisy.to(self.device)
        # The model returns a pair; only the first element (the enhanced
        # waveform) is used here.
        enhanced_wav, _ = self.mods.enhance_model(noisy)
        return enhanced_wav

    def enhance_file(self, filename, output_filename=None, **kwargs):
        """Enhance a wav file.

        Arguments
        ---------
        filename : str
            Location on disk to load file for enhancement.
        output_filename : str
            If provided, writes enhanced data to this file.
        **kwargs : dict
            Arguments forwarded to ``load_audio``.

        Returns
        -------
        torch.Tensor
            The enhanced waveform, with the faked batch dimension removed.
        """
        noisy = self.load_audio(filename, **kwargs)

        # Fake a batch:
        batch = noisy.unsqueeze(0)
        enhanced = self.enhance_batch(batch)

        if output_filename is not None:
            # torchaudio.save requires the sample rate. The faked batch
            # dimension doubles as a single audio channel, so the default
            # channels-first layout is used (assumes load_audio returned a
            # 1-D mono signal at the audio normalizer's rate -- TODO confirm).
            # The tensor is detached and moved to CPU before writing.
            torchaudio.save(
                output_filename,
                enhanced.detach().cpu(),
                self.audio_normalizer.sample_rate,
            )

        return enhanced.squeeze(0)

    def forward(self, noisy, lengths=None):
        """Runs enhancement on the noisy input"""
        return self.enhance_batch(noisy, lengths)