Commit 3ea5288

support to load a text file as corpus

1 parent 924e851 commit 3ea5288

File tree: 1 file changed, +39 −24 lines


ngram/rnnlm.py

Lines changed: 39 additions & 24 deletions
@@ -1,11 +1,11 @@
 #!/usr/bin/env python
 # encode: utf-8

-# Recurrent Neural Network Langage Model
+# Recurrent Neural Network Language Model
 # This code is available under the MIT License.
 # (c)2014 Nakatani Shuyo / Cybozu Labs Inc.

-import numpy, nltk
+import numpy, nltk, codecs, re
 import optparse

 class RNNLM:
@@ -63,7 +63,7 @@ def dist(self, w):

 class RNNLM_BPTT(RNNLM):
     """RNNLM with BackPropagation Through Time"""
-    def learn(self, docs, alpha=0.1, tau=10):
+    def learn(self, docs, alpha=0.1, tau=3):
         index = numpy.arange(len(docs))
         numpy.random.shuffle(index)
         for i in index:
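
The tau change above lowers the truncation depth of backpropagation through time: the error at each position now flows back through at most 3 earlier hidden states instead of 10. A minimal sketch of that window (Python 2, matching the file; the helper name is illustrative, not from the repo):

    def bptt_window(t, tau):
        # time steps whose hidden state receives gradient from step t
        return range(max(0, t - tau + 1), t + 1)

    print bptt_window(10, 3)   # [8, 9, 10]       -- new default, tau=3
    print bptt_window(10, 10)  # [1, 2, ..., 10]  -- old default, tau=10

Shorter windows make each update cheaper and the gradients less prone to vanishing or exploding, at the cost of ignoring longer-range dependencies.
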
@@ -125,10 +125,14 @@ def perplexity(self, docs):
             N += len(doc)
         return log_like / N

+def CorpusWrapper(corpus):
+    for id in corpus.fileids():
+        yield corpus.words(id)

 def main():
     parser = optparse.OptionParser()
-    parser.add_option("-c", dest="corpus", help="corpus module name under nltk.corpus (e.g. brown, reuters)", default='nps_chat')
+    parser.add_option("-c", dest="corpus", help="corpus module name under nltk.corpus (e.g. brown, reuters)")
+    parser.add_option("-f", dest="filename", help="corpus filename (each line is regarded as a document)")
     parser.add_option("-a", dest="alpha", type="float", help="additive smoothing parameter of bigram", default=0.001)
     parser.add_option("-k", dest="K", type="int", help="size of hidden layer", default=10)
     parser.add_option("-i", dest="I", type="int", help="learning interval", default=10)
@@ -138,50 +142,61 @@ def main():

     numpy.random.seed(opt.seed)

-    m = __import__('nltk.corpus', globals(), locals(), [opt.corpus], -1)
-    corpus = getattr(m, opt.corpus)
-    ids = corpus.fileids()
-    D = len(ids)
-    print "found corpus : %s (D=%d)" % (opt.corpus, D)
+    if opt.corpus:
+        m = __import__('nltk.corpus', globals(), locals(), [opt.corpus], -1)
+        corpus = CorpusWrapper(getattr(m, opt.corpus))
+    elif opt.filename:
+        corpus = []
+        with codecs.open(opt.filename, "rb", "utf-8") as f:
+            for s in f:
+                s = re.sub(r'(["\.,!\?:;])', r' \1 ', s).strip()
+                d = re.split(r'\s+', s)
+                if len(d) > 0: corpus.append(d)
+    else:
+        raise "need -f or -c"

     voca = {"<s>":0, "</s>":1}
     vocalist = ["<s>", "</s>"]
     docs = []
-    for id in corpus.fileids()[:2]:
+    N = 0
+    for words in corpus:
         doc = []
-        for w in corpus.words(id):
+        for w in words:
             w = w.lower()
             if w not in voca:
                 voca[w] = len(vocalist)
                 vocalist.append(w)
             doc.append(voca[w])
         if len(doc) > 0:
-            doc.append(1)
+            N += len(doc)
+            doc.append(1) # </s>
             docs.append(doc)
-    V = len(vocalist)
-    print "vocabulary : %d / %d" % (V, len(corpus.words()))

     D = len(docs)
-
-    print ">> BIGRAM(alpha=%f)" % opt.alpha
-    model = BIGRAM(V, opt.alpha)
-    model.learn(docs)
-    print model.perplexity(docs)
+    V = len(vocalist)
+    print "corpus : %s (D=%d)" % (opt.corpus or opt.filename, D)
+    print "vocabulary : %d / %d" % (V, N)

     print ">> RNNLM(K=%d)" % opt.K
-    model = RNNLM(V, opt.K)
-    print model.perplexity(docs)
-    intervals = [1.0, 1.0, 0.5, 0.5, 0.4, 0.3, 0.2]
+    model = RNNLM_BPTT(V, opt.K)
+    a = 1.0
     for i in xrange(opt.I):
-        a = intervals[i] if i < len(intervals) else 0.1
+        print i, model.perplexity(docs)
         model.learn(docs, a)
-        print model.perplexity(docs)
+        a = a * 0.95 + 0.01
+    print opt.I, model.perplexity(docs)

     if opt.output:
         import cPickle
         with open(opt.output, 'wb') as f:
             cPickle.dump([model, voca, vocalist], f)

+    print ">> BIGRAM(alpha=%f)" % opt.alpha
+    model = BIGRAM(V, opt.alpha)
+    model.learn(docs)
+    print model.perplexity(docs)
+
+
     """
     testids = set(random.sample(ids, int(D * opt.testrate)))
     trainids = [id for id in ids if id not in testids]
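
A note on the new -f tokenizer: it pads the listed punctuation characters with spaces and splits on whitespace, nothing more. For example (same two regexes as in the hunk above; the sample sentence is mine):

    import re

    s = 'He said, "RNNs work well."'
    s = re.sub(r'(["\.,!\?:;])', r' \1 ', s).strip()
    print re.split(r'\s+', s)
    # ['He', 'said', ',', '"', 'RNNs', 'work', 'well', '.', '"']

Two caveats worth knowing: a blank input line still yields [''] (re.split on an empty string returns one empty token), so the len(d) > 0 guard never filters it out; and raise "need -f or -c" is a string exception, which Python 2.6+ rejects with a TypeError, so something like raise SystemExit("need -f or -c") would be more robust.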

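And a note on the new learning-rate schedule: a = a * 0.95 + 0.01 is a geometric decay with fixed point 0.01 / (1 - 0.95) = 0.2, so the rate slides smoothly from 1.0 toward 0.2 instead of stepping through the old hand-tuned intervals table. A minimal check:

    a = 1.0
    for i in range(30):
        a = a * 0.95 + 0.01
    print round(a, 3)  # 0.372 -- already most of the way to the 0.2 limit
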