my_datasets.py
import math

import pandas as pd
import torch
import torchaudio
from torch import distributions
from torch.nn.utils.rnn import pad_sequence

from utils import TextTransform

class TrainDataset(torch.utils.data.Dataset):
    """Custom competition dataset."""

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the tab-separated csv file with annotations
                ('path' and 'sentence' columns).
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.answers = pd.read_csv(csv_file, sep='\t')
        self.transform = transform

    def __len__(self):
        return len(self.answers)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        utt_name = 'cv-corpus-5.1-2020-06-22/en/clips/' + self.answers.loc[idx, 'path']
        utt = torchaudio.load(utt_name)[0].squeeze()
        if len(utt.shape) != 1:
            # multi-channel audio: keep a single channel
            utt = utt[1]
        answer = self.answers.loc[idx, 'sentence']
        if self.transform:
            utt = self.transform(utt)
        sample = {'utt': utt, 'answer': answer}
        return sample

class TestDataset(torch.utils.data.Dataset):
    """Custom test dataset."""

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the tab-separated csv file with utterance paths.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.names = pd.read_csv(csv_file, sep='\t')
        self.transform = transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        utt_name = 'cv-corpus-5.1-2020-06-22/en/clips/' + self.names.loc[idx, 'path']
        utt = torchaudio.load(utt_name)[0].squeeze()
        if self.transform:
            utt = self.transform(utt)
        sample = {'utt': utt}
        return sample

# win_len=1024, hop_len=256
# Number of MelSpectrogram frames for a waveform of x samples, computed ahead of
# the transform itself (with center padding: n_frames = n_samples // hop_len + 1).
def mel_len(x):
    return int(x // 256) + 1

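# Hedged sanity check (not part of the original file): with torchaudio's default
# center padding, a MelSpectrogram with hop_length=256 produces
# n_samples // 256 + 1 frames, which is exactly what mel_len computes. The
# n_fft/win_length values below are assumptions taken from the note above.
def _check_mel_len(n_samples=16000):
    mel = torchaudio.transforms.MelSpectrogram(n_fft=1024, win_length=1024, hop_length=256)
    wav = torch.randn(n_samples)
    assert mel(wav).shape[-1] == mel_len(n_samples)
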
# Waveform augmentation for training: randomly apply one of three transforms --
# identity, additive Gaussian noise, or volume scaling.
def transform_tr(wav):
    aug_num = torch.randint(low=0, high=3, size=(1,)).item()
    augs = [
        lambda x: x,
        lambda x: (x + distributions.Normal(0, 0.01).sample(x.size())).clamp_(-1, 1),
        lambda x: torchaudio.transforms.Vol(.1)(x),
    ]
    return augs[aug_num](wav)

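# Hedged aside (illustration, not in the original file): torchaudio.transforms.Vol
# defaults to gain_type='amplitude', so Vol(.1) scales the waveform by 0.1 and
# clamps the result to [-1, 1].
def _demo_vol_scaling():
    wav = torch.randn(100).clamp_(-1, 1)
    assert torch.allclose(torchaudio.transforms.Vol(.1)(wav), wav * 0.1)
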
# collate_fn: pad variable-length waveforms and label sequences into batch
# tensors and record their true lengths (needed e.g. for CTC loss).
def preprocess_data(data):
    text_transform = TextTransform()
    wavs = []
    input_lens = []
    labels = []
    label_lens = []
    for el in data:
        wavs.append(el['utt'])
        input_lens.append(math.ceil(mel_len(el['utt'].shape[0]) / 2))  # because of stride 2
        label = torch.Tensor(text_transform.text_to_int(el['answer']))
        labels.append(label)
        label_lens.append(len(label))
    wavs = pad_sequence(wavs, batch_first=True)
    labels = pad_sequence(labels, batch_first=True)
    return wavs, input_lens, labels, label_lens
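
# Minimal usage sketch (assumptions: 'train.tsv' is a placeholder path and the
# Common Voice clips are available locally): wire TrainDataset and the
# preprocess_data collate_fn into a DataLoader.
if __name__ == '__main__':
    train_set = TrainDataset('train.tsv', transform=transform_tr)
    loader = torch.utils.data.DataLoader(
        train_set, batch_size=8, shuffle=True, collate_fn=preprocess_data)
    wavs, input_lens, labels, label_lens = next(iter(loader))
    print(wavs.shape, labels.shape)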