-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbatch300wikipedia.py
executable file
·78 lines (69 loc) · 2.85 KB
/
batch300wikipedia.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# batch300wikipedia.py: Demonstrates the use of batchVB for LDA to
# analyze a bunch of random Wikipedia articles.
#
# Copyright (C) 2010 Matthew D. Hoffman
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import cPickle, string, numpy, getopt, sys, random, time, re, pprint
import batchldavb
import wikirandom
def main():
    """
    Download and analyze a stream of random Wikipedia articles using
    batch VB for LDA.

    Optional positional command-line arguments:
        argv[1]  number of iterations (batches) to run; defaults to
                 int(D / batchsize), i.e. roughly D documents in total
        argv[2]  window size L passed to batchldavb.batchLDA (default 30)
        argv[3]  number of articles downloaded per iteration (default 8)

    Every 10th iteration, and on the final iteration, lambda and gamma are
    written with numpy.savetxt to lambda-<iteration>.dat and
    gamma-<iteration>.dat in the current directory.
    """
    # The number of documents to analyze each iteration
    batchsize = 8
    # The total number of documents in Wikipedia
    D = 3.3e6
    # The number of topics
    K = 100
    # The size of window
    L = 30
    # How many iterations (batches of `batchsize` documents) to run.
    # The default is computed from the default batchsize, which is the one
    # in effect whenever argv[1] (and hence argv[3]) is absent.
    try:
        if len(sys.argv) < 2:
            documentstoanalyze = int(D / batchsize)
        else:
            documentstoanalyze = int(sys.argv[1])
        if len(sys.argv) >= 3:
            L = int(sys.argv[2])
        if len(sys.argv) >= 4:
            batchsize = int(sys.argv[3])
    except ValueError:
        # Report bad input as a usage error rather than a raw traceback.
        sys.exit('usage: %s [iterations [window_size [batchsize]]] '
                 '(integers)' % sys.argv[0])

    # Our vocabulary, one entry per line (newlines kept, as before).
    # `with` guarantees the file is closed once it has been read.
    with open('./dictnostops.txt') as vocabfile:
        vocab = vocabfile.readlines()

    # Initialize the algorithm.
    # NOTE(review): the arguments after D are presumably alpha=0.5 and
    # eta=0.5. The meaning of 1e-2 and -1 is defined by
    # batchldavb.batchLDA -- confirm there. A previous comment claimed
    # rho = 10^-3, which does not match the 1e-2 passed here.
    olda = batchldavb.batchLDA(vocab, K, D, 0.5, 0.5, 1e-2, -1, L)

    # Run for the requested number of iterations. (Feel free to interrupt
    # *much* sooner than this.)
    for iteration in range(documentstoanalyze):
        # Download some articles
        (docset, articlenames) = \
            wikirandom.get_random_wikipedia_articles(batchsize)
        # Give them to batch LDA
        (gamma, bound) = olda.update_lambda_docs(docset)
        # Compute a per-word perplexity estimate from the bound on this
        # batch. `bound` is rescaled from the corpus size D to this batch.
        # NOTE(review): this batch was just used for the update, so the
        # estimate is "held-out" only if update_lambda_docs computes
        # `bound` before updating lambda -- confirm in batchldavb.
        (wordids, wordcts) = batchldavb.parse_doc_list(docset, olda._vocab)
        nwords = sum(map(sum, wordcts))
        if nwords > 0:
            perwordbound = bound * len(docset) / (D * nwords)
            print('%d: rho_t = %f, held-out perplexity estimate = %f' %
                  (iteration, olda._rhot, numpy.exp(-perwordbound)))
        else:
            # An empty batch would otherwise raise ZeroDivisionError.
            print('%d: rho_t = %f, batch contains no vocabulary words; '
                  'perplexity estimate skipped' % (iteration, olda._rhot))
        # Save lambda, the parameters to the variational distributions
        # over topics, and gamma, the parameters to the variational
        # distributions over topic weights for the articles analyzed in
        # this iteration. Save every 10th iteration, and also the final
        # one so the results of the last batches are not lost.
        if iteration % 10 == 0 or iteration == documentstoanalyze - 1:
            numpy.savetxt('lambda-%d.dat' % iteration, olda._lambda)
            numpy.savetxt('gamma-%d.dat' % iteration, gamma)
# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()