-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
59 lines (39 loc) · 1.66 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import fire
from tqdm import tqdm
import core.config as conf
from utils.dataiter import Dataiter
from models.model.XGBoost import XGBoost
from models.model.DNN import DNN
# from models.model.DeepFM import DeepFM
from models.model.FFNN import FFNN
from models.model.FFNN_ALL import FFNN_ALL
from models.model.FFNN_ALL_DEFAULT import FFNN_ALL_DEFAULT
from models.model.Ensemble_FFNN_ALL import Ensemble_FFNN_ALL
from models.network import Network
class Train(object):
    """Command-line training entry point, exposed through python-fire.

    The constructor loads the training data and builds the network selected
    by ``conf.net_structure``; :meth:`train` then runs training on it.
    """

    def __init__(self, target='all'):
        """Load the training data and build the configured model.

        Args:
            target: Key into ``conf.target_to_idx`` naming the prediction
                target. An unknown key raises ``KeyError``.

        Raises:
            SystemExit: If ``conf.net_structure`` names no known network
                (non-zero exit status, message on stderr).
        """
        target_id = conf.target_to_idx[target]
        # Alternatives noted by the author: dataset_path, small_dataset_path.
        self.df = Dataiter(conf.dataset_path, target_id, train=True)
        model = self._build_model(conf.net_structure, self.df, target_id)
        self.model = Network(model)

    @staticmethod
    def _build_model(net_structure, df, target_id):
        """Instantiate the model class selected by *net_structure*.

        Args:
            net_structure: Network name, e.g. ``'ffnn_all'`` or ``'xgboost'``.
            df: The ``Dataiter`` instance handed to the model constructor.
            target_id: Index of the prediction target.

        Returns:
            The model instance, to be wrapped in ``Network`` by the caller.

        Raises:
            SystemExit: For an unrecognised *net_structure*; the message is
                printed to stderr and the process exits with status 1.
        """
        if net_structure == 'ensemble_ffnn_all':
            return Ensemble_FFNN_ALL(df, target_id)
        if net_structure == 'xgboost':
            return XGBoost(df, target_id)
        if net_structure == 'deepfm':
            # No module-level import of DeepFM is active in this file, so it
            # is imported on demand; the other networks then do not depend on
            # the DeepFM module being importable.
            from models.model.DeepFM import DeepFM
            return DeepFM(df, target_id)
        if net_structure == 'dnn':
            return DNN(df, target_id)
        if net_structure == 'ffnn':
            return FFNN(df, target_id)
        if net_structure == 'ffnn_all':
            return FFNN_ALL(df, target_id)
        if net_structure == 'ffnn_all_default':
            return FFNN_ALL_DEFAULT(df, target_id)
        # SystemExit with a string reports it on stderr and exits with status
        # 1, so scripts/CI see the misconfiguration as a failure.
        raise SystemExit(
            'Unidentified Network {!r}... exit'.format(net_structure))

    def train(self):
        """Run training on the model built in ``__init__``."""
        self.model.train()
def _main():
    """Run the python-fire command-line interface over :class:`Train`.

    Fire turns the constructor's arguments (such as ``target``) into
    command-line flags and the public methods (such as ``train``) into
    commands.
    """
    fire.Fire(component=Train)


if __name__ == "__main__":
    _main()