-
Notifications
You must be signed in to change notification settings - Fork 0
/
hmm_train.py
33 lines (29 loc) · 1.2 KB
/
hmm_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from utils.data_loader import DataLoader
from utils.hmm_model import HMMBasic
import pickle
def _count_tokens(sequences):
    """Return the total number of items across all inner sequences.

    Counts by iteration rather than materialising a flattened list, so no
    throwaway list is built just to measure its length.
    """
    return sum(1 for sequence in sequences for _ in sequence)


def main():
    """Load the corpus, train and evaluate the HMM tagger, then pickle it.

    Steps, in order:
      1. Print the banner.
      2. Load data through ``DataLoader('./data/')`` and report the number
         of sentences and words in the train and test sets.
      3. Train ``HMMBasic`` on the training set and print the value returned
         by ``eval`` on the test set.
      4. Serialise the trained model to ``model/hmm_pos.hmm``.
    """
    # Local import keeps this block self-contained; the file-level import
    # section does not bring in ``os``.
    import os

    model_path = 'model/hmm_pos.hmm'

    print('*********************************')
    print('*** Name: Truong Nhat Hoang ***')
    print('*** Student ID: 51703092 ***')
    print('*** HMM POS Tagger ***')
    print('*** NLP - 504045 ***')
    print('*********************************')
    print()

    print('[Loading data] ...')
    dloader = DataLoader('./data/')
    dloader.load()

    # NOTE(review): the "Words" figure is the number of entries in the
    # 'label' column, presumably one tag per word -- confirm in DataLoader.
    train_word = _count_tokens(dloader.data_train['label'])
    test_word = _count_tokens(dloader.data_test['label'])

    print("------[TRAIN SET]------")
    print(f"\t{len(dloader.data_train)}: Sentences")
    print(f"\t{train_word}: Words")
    print("------[TEST SET]------")
    print(f"\t{len(dloader.data_test)}: Sentences")
    print(f"\t{test_word}: Words")

    print('[TRANING] ...')
    hmm_basic = HMMBasic()
    hmm_basic.train(dloader.data_train)
    score = hmm_basic.eval(dloader.data_test)
    print(f'Accurancy: {score}', end='\n\n')

    print('[SAVING MODEL TO FILE] >>> model/hmm_pos.hmm')
    # Create the output directory first so a missing 'model/' folder does not
    # discard the freshly trained model with a FileNotFoundError.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    # The context manager guarantees the file is flushed and closed even if
    # pickling raises part-way through.
    with open(model_path, mode='wb') as model_file:
        pickle.dump(hmm_basic, model_file)


if __name__ == '__main__':
    main()