-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
60 lines (47 loc) · 1.95 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import h5py
import numpy as np
import random
class Dataset:
    """Batched, in-memory access to a length-bucketed HDF5 dataset.

    The HDF5 file at ``params.dataset_path`` is read eagerly:

    * ``embeddings`` -- a lookup table. Every yielded batch is
      ``self.embeddings[index_batch]``, so the split arrays hold row indices
      into it (presumably token ids -- confirm against the data pipeline).
    * ``train`` / ``validation`` / ``test`` -- HDF5 groups of 2-D arrays.
      Each kept array is stored under its column count (``shape[1]``), giving
      a ``{sequence_length: array}`` dict per split.  NOTE: if two arrays in
      one group share a width, the later one silently replaces the earlier.

    ``params`` must provide ``dataset_path``, ``max_sequence_len`` and
    ``batch_size``.
    """

    def __init__(self, params):
        """Load every split into memory and schedule the first training epoch.

        Arrays wider than ``params.max_sequence_len`` are dropped from all
        splits; training arrays with fewer than ``params.batch_size`` rows are
        also dropped because they could never fill a batch.
        """
        self.params = params
        # The context manager closes the file even if one of the reads fails.
        with h5py.File(params.dataset_path, 'r') as dataset:
            self.embeddings = dataset['embeddings'][:]
            self.train = self._load_split(
                dataset['train'], params.max_sequence_len, params.batch_size)
            self.validation = self._load_split(
                dataset['validation'], params.max_sequence_len)
            self.test = self._load_split(
                dataset['test'], params.max_sequence_len)
        self.reset()

    @staticmethod
    def _load_split(group, max_sequence_len, min_rows=0):
        """Return ``{width: array}`` for each 2-D array in ``group`` that fits.

        An array is kept when its width (``shape[1]``) is at most
        ``max_sequence_len`` and it has at least ``min_rows`` rows; kept arrays
        are copied into memory with ``[:]``.
        """
        split = {}
        for d in group.values():
            if d.shape[1] > max_sequence_len or d.shape[0] < min_rows:
                continue
            split[d.shape[1]] = d[:]
        return split

    def reset(self):
        """Reshuffle the training data and rebuild the epoch schedule.

        Each training bucket gets a fresh random row order and contributes
        ``rows // batch_size`` batches; leftover rows are skipped this epoch.
        ``dataset_iterations`` holds one bucket key per batch, shuffled so that
        batches from different buckets are interleaved.
        """
        batch_size = self.params.batch_size
        self.indices_per_dataset = {}
        self.current_ids = {}
        self.dataset_iterations = []
        for k, d in self.train.items():
            ids = np.arange(d.shape[0])
            np.random.shuffle(ids)
            self.indices_per_dataset[k] = ids
            self.dataset_iterations += [k] * (d.shape[0] // batch_size)
            self.current_ids[k] = 0
        random.shuffle(self.dataset_iterations)

    def train_epoch(self):
        """Yield one epoch of training batches, then reshuffle for the next.

        Each batch gathers ``batch_size`` rows from a single bucket and looks
        them up in ``self.embeddings``, giving shape
        ``(batch_size, seq_len) + embeddings.shape[1:]``.

        ``reset()`` runs only after the schedule is exhausted; if the generator
        is abandoned early, the next call resumes with the remaining batches.
        """
        batch_size = self.params.batch_size
        while self.dataset_iterations:
            k = self.dataset_iterations.pop()
            start = self.current_ids[k]
            rows = self.train[k][self.indices_per_dataset[k][start:start + batch_size]]
            yield self.embeddings[rows]
            self.current_ids[k] += batch_size
        self.reset()

    def test_epoch(self, val_or_test):
        """Yield evaluation batches from the named split, in stored order.

        Args:
            val_or_test: split name -- ``'validation'`` or ``'test'``
                (``'train'`` is also accepted).

        Raises:
            ValueError: if ``val_or_test`` is not a known split name. Because
                this is a generator, the error surfaces on first iteration.

        Only full batches are yielded; trailing rows of each bucket that do
        not fill a batch are skipped.
        """
        # A whitelisted getattr replaces the former eval(), which would have
        # executed arbitrary code passed in as the split name.
        if val_or_test not in ('train', 'validation', 'test'):
            raise ValueError(
                f'unknown split {val_or_test!r}; expected "validation" or "test"')
        batch_size = self.params.batch_size
        for d in getattr(self, val_or_test).values():
            usable = d.shape[0] - d.shape[0] % batch_size
            for i in range(0, usable, batch_size):
                yield self.embeddings[d[i:i + batch_size]]