topic_modeling_dataset_generator.py
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 26 21:15:28 2022
@author: arman hossain

Usage:
    !python -m spacy download en_core_web_sm
    !python topic_modeling_dataset_generator.py
"""
import pickle
import re
import warnings

import pandas as pd
import gensim
import gensim.corpora as corpora
from gensim.utils import simple_preprocess
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
import spacy  # spaCy is used for lemmatization

stop_words = stopwords.words('english')
stop_words.extend(['from', 'subject', 're', 'edu', 'use'])
# ------------------------------------
# Module-level state shared by the helpers below; populated by save_bigram(),
# get_corpus() and load().
bigram_mod = []
id2word = []
corpus = []
# Parser and NER are disabled for speed; only POS tags and lemmas are needed.
nlp = spacy.load('en_core_web_sm', disable=['parser', 'ner'])
def sent_to_words(sentences):
    """Lowercase and tokenize each document; deacc=True also strips accent marks."""
    for sentence in sentences:
        yield gensim.utils.simple_preprocess(str(sentence), deacc=True)
def save_bigram(data_words, save=False):
    """Fit a bigram model on the tokenized corpus; a higher threshold yields fewer phrases."""
    global bigram_mod
    bigram = gensim.models.Phrases(data_words, min_count=5, threshold=100)
    # Phraser freezes the fitted model into a smaller, faster apply-only form.
    bigram_mod = gensim.models.phrases.Phraser(bigram)
    if save:
        with open('./epss_models/bigram_mod.pkl', 'wb') as fp:
            pickle.dump(bigram_mod, fp)
def remove_stopwords(texts):
    return [[word for word in simple_preprocess(str(doc)) if word not in stop_words]
            for doc in texts]

def make_bigrams(texts):
    return [bigram_mod[doc] for doc in texts]
def lemmatization(texts, allowed_postags=['NOUN', 'ADJ', 'VERB', 'ADV']):
    """Keep only lemmas of the allowed parts of speech (https://spacy.io/api/annotation)."""
    texts_out = []
    for sent in texts:
        doc = nlp(" ".join(sent))
        texts_out.append([token.lemma_ for token in doc if token.pos_ in allowed_postags])
    return texts_out
def preprocess(data, save=False, test=False):
    """Clean a list of raw documents and return lemmatized token lists."""
    # Remove e-mail addresses.
    data = [re.sub(r'\S*@\S*\s?', '', sent) for sent in data]
    # Collapse newlines and runs of whitespace.
    data = [re.sub(r'\s+', ' ', sent) for sent in data]
    # Remove distracting single quotes.
    data = [re.sub(r"\'", "", sent) for sent in data]
    data_words = list(sent_to_words(data))  # e.g. [['receipt', 'of', 'malformed', 'packet', 'on', 'mx', ...], ...]
    # Fit a new bigram model only when training; at test time reuse the loaded one.
    if not test:
        save_bigram(data_words, save)
    data_words_nostops = remove_stopwords(data_words)
    data_words_bigrams = make_bigrams(data_words_nostops)
    # Lemmatize, keeping only nouns, adjectives, verbs and adverbs.
    data_lemmatized = lemmatization(data_words_bigrams, allowed_postags=['NOUN', 'ADJ', 'VERB', 'ADV'])
    if save:
        with open('./epss_models/lamma.pkl', 'wb') as fp:
            pickle.dump(data_lemmatized, fp)
    return data_lemmatized
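# Illustrative call (the output tokens here are hypothetical, shown for shape only):
# preprocess(["Receipt of malformed packet on MX Series devices ..."])
# -> [['receipt', 'malformed', 'packet', 'mx', 'series', 'device']]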
def get_corpus(data, save=False):
    """Build the gensim dictionary and bag-of-words corpus from raw documents."""
    global id2word
    data_lemmatized = preprocess(data, save)
    id2word = corpora.Dictionary(data_lemmatized)
    corpus = [id2word.doc2bow(text) for text in data_lemmatized]
    if save:
        with open('./epss_models/id2word.pkl', 'wb') as fp:
            pickle.dump(id2word, fp)
        with open('./epss_models/corpus.pkl', 'wb') as fp:
            pickle.dump(corpus, fp)
    return corpus
def get_corpus_test(data):
    """Convert unseen documents to bag-of-words with the already-fitted dictionary."""
    data_lemmatized = preprocess(data, save=False, test=True)
    return [id2word.doc2bow(text) for text in data_lemmatized]
def load():
    """Reload the fitted bigram model, dictionary, corpus and trained LDA model from disk."""
    global bigram_mod, id2word, corpus
    with open('./epss_models/bigram_mod.pkl', 'rb') as f:
        bigram_mod = pickle.load(f)  # bigram model fitted by save_bigram()
    with open('./epss_models/id2word.pkl', 'rb') as f:
        id2word = pickle.load(f)
    with open('./epss_models/corpus.pkl', 'rb') as f:
        corpus = pickle.load(f)
    lda_train = gensim.models.ldamulticore.LdaMulticore.load('lda_train.model')
    return lda_train
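# Note: in a fresh process, call load() before get_corpus_test(), since the
# latter depends on the module-level bigram_mod and id2word fitted at train time.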
def get_topics_dataset(lda_train, corpus, num_topics=20):
    """Turn per-document (topic_id, probability) lists into a dense table.

    lda_train[corpus] yields, for each document, a sparse list of
    (topic_id, probability) tuples; topics not listed get probability 0.
    """
    rows = []
    for doc_topics in lda_train[corpus]:
        row = [0.0] * num_topics
        for topic_id, prob in doc_topics:
            row[topic_id] = prob
        rows.append(row)
    return pd.DataFrame(rows, columns=["Topic " + str(i) for i in range(1, num_topics + 1)])
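# Illustrative example (hypothetical numbers): a document whose LDA output is
# [(2, 0.85), (7, 0.15)] becomes a row with 0.85 under "Topic 3", 0.15 under
# "Topic 8", and 0.0 elsewhere (column "Topic i" holds topic id i - 1).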
def read_dataset(path):
    """Load the NVD CSV and drop rows whose description is marked ** REJECT **."""
    data = pd.read_csv(path)
    keep = [item.find("** REJECT **") < 0 for item in data.description.values.tolist()]
    return data[keep]
def train(data, save=False):
    """Train a 20-topic LdaMulticore model; return it with the topics dataset."""
    corpus = get_corpus(data, save)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        lda_train = gensim.models.ldamulticore.LdaMulticore(
            corpus=corpus,
            num_topics=20,
            id2word=id2word,
            chunksize=100,
            workers=3,  # number of processing cores - 1 (was 7)
            passes=50,
            eval_every=1,
            per_word_topics=False)
    if save:
        lda_train.save('lda_train.model')
    return lda_train, get_topics_dataset(lda_train, corpus)
if __name__ == '__main__':
    data = read_dataset('./data/2016_22m_nvd.csv')
    # Train from scratch and persist the model and its artifacts.
    lda_train, topics_dataset = train(data.description.values.tolist(), save=True)
    # To reuse a previously trained model instead of retraining:
    # lda_train = load()
    # print(lda_train.print_topics(20, num_words=5))
    # corpus = get_corpus_test(['The Microsoft (1) VBScript 5.7 and 5.8 and (2) JScript 5.7 and 5.8 engines, as used in Internet Explorer 8 through 11 and other products, allow remote attackers to execute arbitrary code via a crafted web site, aka Scripting Engine Memory Corruption Vulnerability.'])
    # cdf2 = get_topics_dataset(lda_train, corpus)
    topics_dataset.to_csv("./data/epss_topics_16_22.csv", index=False)