-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
31 lines (26 loc) · 812 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import math
import torch
import random
import numpy as np
from numpy.random import randint
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Seed value passed to ``random.seed``, ``np.random.seed``,
            ``torch.cuda.manual_seed``, ``torch.cuda.manual_seed_all`` and
            ``torch.manual_seed``.
    """
    # Host-side generators (independent of each other).
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: current CUDA device, every CUDA device, then CPU.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)

    # Prefer reproducible cuDNN kernels over autotuned (possibly
    # nondeterministic) ones.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def next_batch(X1, X2, batch_size):
    """Yield paired mini-batches from ``X1`` and ``X2`` in row order.

    Rows are sliced along the first axis. The last batch is shorter when
    ``X1.shape[0]`` is not a multiple of ``batch_size``. An input with zero
    rows yields nothing.

    Args:
        X1: Array-like with a ``shape`` attribute that supports
            ``X1[start:end, ...]`` slicing. Its first dimension is the
            number of samples.
        X2: Array-like paired row-for-row with ``X1``. It must have the
            same number of rows.
        batch_size: Positive number of rows per batch.

    Yields:
        Tuples ``(batch_x1, batch_x2, batch_no)``, where ``batch_no`` is
        the 1-based index of the batch.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``X1`` and
            ``X2`` have different numbers of rows. Because this is a
            generator, the error surfaces on the first iteration.
    """
    if batch_size <= 0:
        # Zero used to raise ZeroDivisionError; negatives silently yielded
        # nothing.
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    tot = X1.shape[0]
    if X2.shape[0] != tot:
        # Mismatched views would silently drop or empty paired samples.
        raise ValueError(
            f"X1 and X2 must have the same number of rows, "
            f"got {tot} and {X2.shape[0]}"
        )
    for batch_no, start_idx in enumerate(range(0, tot, batch_size), start=1):
        end_idx = min(tot, start_idx + batch_size)
        batch_x1 = X1[start_idx:end_idx, ...]
        batch_x2 = X2[start_idx:end_idx, ...]
        yield (batch_x1, batch_x2, batch_no)
def normalize(x):
    """Min-max scale ``x`` into the range [0, 1].

    The minimum and maximum are taken over all elements of ``x``, with no
    per-axis reduction. Every element is therefore shifted and scaled by
    the same two numbers.

    Args:
        x: Numeric array accepted by ``np.min``/``np.max`` and by the
            ``-`` and ``/`` operators.

    Returns:
        ``(x - min(x)) / (max(x) - min(x))``. When all elements are equal,
        the zero range is replaced by 1, so the result is all zeros instead
        of NaN.
    """
    x_min = np.min(x)
    x_range = np.max(x) - x_min
    if x_range == 0:
        # Constant input: avoid 0/0 -> NaN (and NumPy's RuntimeWarning).
        x_range = 1
    return (x - x_min) / x_range