-
Notifications
You must be signed in to change notification settings - Fork 228
/
Copy pathembedder.py
33 lines (27 loc) · 1.19 KB
/
embedder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
class LinearNorm(nn.Module):
    """Project LSTM hidden states to the embedding dimension.

    Despite the name, no normalisation is performed here: this is a single
    affine layer (``nn.Linear``) mapping ``hp.embedder.lstm_hidden`` input
    features to ``hp.embedder.emb_dim`` output features.
    """

    def __init__(self, hp):
        super().__init__()
        in_features = hp.embedder.lstm_hidden
        out_features = hp.embedder.emb_dim
        # The state_dict keys ("linear_layer.weight" / "linear_layer.bias")
        # are derived from this attribute name; keep it stable so existing
        # checkpoints still load.
        self.linear_layer = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the affine projection along the last dimension of ``x``."""
        projected = self.linear_layer(x)
        return projected
class SpeechEmbedder(nn.Module):
    """LSTM speaker embedder producing one vector per utterance.

    The mel spectrogram is cut into overlapping windows along time. Each
    window is run through a stacked LSTM, and the last LSTM output of each
    window is projected (``LinearNorm``) and L2-normalised. The per-window
    embeddings are then averaged.

    Hyper-parameters are read from ``hp.embedder``: ``num_mels``,
    ``lstm_hidden``, ``lstm_layers``, ``window``, ``stride``, and
    ``emb_dim`` (the latter via ``LinearNorm``).
    """

    def __init__(self, hp):
        super().__init__()
        self.lstm = nn.LSTM(hp.embedder.num_mels,
                            hp.embedder.lstm_hidden,
                            num_layers=hp.embedder.lstm_layers,
                            batch_first=True)
        self.proj = LinearNorm(hp)
        self.hp = hp

    def forward(self, mel):
        """Embed a single, unbatched utterance.

        Args:
            mel: tensor of shape ``(num_mels, T)``. ``T`` must be at least
                ``hp.embedder.window``; otherwise ``Tensor.unfold`` raises.

        Returns:
            Tensor of shape ``(emb_dim,)``: the mean of the per-window
            unit-norm embeddings. The mean itself is generally not unit-norm.
        """
        window = self.hp.embedder.window
        stride = self.hp.embedder.stride
        # Sliding windows along the time axis: (num_mels, T', window)
        mels = mel.unfold(1, window, stride)
        # Reorder to the LSTM's batch_first layout: (T', window, num_mels)
        mels = mels.permute(1, 2, 0)
        x, _ = self.lstm(mels)  # (T', window, lstm_hidden)
        # Keep only the final LSTM output of each window: (T', lstm_hidden)
        x = x[:, -1, :]
        x = self.proj(x)  # (T', emb_dim)
        # L2-normalise each window embedding. normalize() clamps the norm to a
        # small eps, so an all-zero row yields zeros instead of NaN (0 / 0);
        # for any non-degenerate row the result equals x / ||x||.
        x = nn.functional.normalize(x, p=2.0, dim=1)
        # Average pooling over the windows: (emb_dim,)
        return x.mean(dim=0)