tft_dataset.py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from data_formatters.base import InputTypes, DataTypes
from data_formatters.utils import get_single_col_by_input_type


class TFTDataset(Dataset):
    """Sliding-window dataset for the Temporal Fusion Transformer.

    Each item is one window of `total_time_steps` rows taken from a single
    entity (ID group) of the input DataFrame.
    """

    def __init__(self, data, column_definition, params) -> None:
        self.column_definition = column_definition
        self.params = params
        self.data_map = self.convert_data(data)
    def convert_data(self, data):
        """Converts the raw DataFrame into per-column arrays plus a window index.

        Rather than materialising every window, only (start, end) index pairs
        into the concatenated arrays are stored; the actual slicing happens
        lazily in __getitem__.
        """
        id_col = self._get_single_col_by_type(InputTypes.ID)
        time_col = self._get_single_col_by_type(InputTypes.TIME)
        target_col = self._get_single_col_by_type(InputTypes.TARGET)
        input_cols = [
            tup[0]
            for tup in self.column_definition
            if tup[2] not in {InputTypes.ID, InputTypes.TIME}
        ]

        def convert_to_array(input_data, lags):
            """Returns the column block as an ndarray (float64 downcast to
            float32), or None if the group is shorter than one full window."""
            time_steps = len(input_data)
            x = input_data.to_numpy()
            if x.dtype == np.float64:
                x = x.astype(np.float32)
            if time_steps >= lags:
                return x
            return None

        def mapper(data_len, lags, offset):
            """Returns (start, end) pairs for every window of length `lags`
            within a group of `data_len` rows, shifted by `offset` so they
            index into the concatenated arrays."""
            start = np.arange(data_len - lags + 1)
            end = start + lags
            return np.stack([start, end], axis=1) + offset

        col_mappings = {
            'identifier': [id_col],
            'time': [time_col],
            'outputs': [target_col],
            'inputs': input_cols,
        }

        data_map = {'index': []}
        offset = 0
        for _, sliced in data.groupby(id_col):
            # Skip entities that are too short to form a single window.
            if len(sliced) < self.params['total_time_steps']:
                continue

            index = mapper(len(sliced), self.params['total_time_steps'], offset)
            # The last window ends at offset + len(sliced); that becomes the
            # offset for the next entity's windows in the concatenated arrays.
            offset = index[-1][1]
            data_map['index'].append(index)

            for k, cols in col_mappings.items():
                arr = convert_to_array(sliced[cols], self.params['total_time_steps'])
                data_map.setdefault(k, []).append(arr)

        # Combine the per-entity arrays into single contiguous arrays.
        for k in data_map:
            data_map[k] = np.concatenate(data_map[k], axis=0)
        return data_map

    def _get_single_col_by_type(self, input_type):
        """Returns name of single column for input type."""
        return get_single_col_by_input_type(input_type, self.column_definition)

    def __len__(self):
        return len(self.data_map['index'])

    def __getitem__(self, idx):
        """Slices one window out of the concatenated arrays.

        Inputs span the full window; outputs keep only the decoder steps
        (everything after `num_encoder_steps`). The default DataLoader
        collate function converts the numpy arrays to tensors.
        """
        start, end = self.data_map['index'][idx]
        inputs = self.data_map['inputs'][start:end]
        outputs = self.data_map['outputs'][start:end][self.params['num_encoder_steps']:, :]
        return {'inputs': inputs, 'outputs': outputs}


if __name__ == '__main__':
    # Smoke test. The constructor needs the raw DataFrame plus the column
    # definition and experiment params, not just a path to the CSV.
    data = pd.read_csv('output/hourly_electricity.csv')
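    # A minimal wiring sketch, assuming the repository ships a formatter (here
    # a hypothetical ElectricityFormatter in data_formatters.electricity) with
    # the get_column_definition() / get_experiment_params() / split_data()
    # interface of the reference TFT codebase; adjust the names to whatever
    # formatter this project actually provides.
    from torch.utils.data import DataLoader
    from data_formatters.electricity import ElectricityFormatter  # assumed import path

    formatter = ElectricityFormatter()
    column_definition = formatter.get_column_definition()
    params = formatter.get_experiment_params()  # expected to supply 'total_time_steps' and 'num_encoder_steps'

    # split_data is also expected to fit and apply the scalers/encoders, so
    # the frame handed to TFTDataset is fully numeric.
    train, valid, test = formatter.split_data(data)

    train_set = TFTDataset(train, column_definition, params)
    loader = DataLoader(train_set, batch_size=64, shuffle=True)
    batch = next(iter(loader))
    print(batch['inputs'].shape, batch['outputs'].shape)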