11#!/usr/bin/env python
22# coding: utf-8
33
4- # Recurrent Neural Network Langage Model
4+ # Recurrent Neural Network Language Model
55# This code is available under the MIT License.
66# (c)2014 Nakatani Shuyo / Cybozu Labs Inc.
77
8- import numpy , nltk
8+ import numpy , nltk , codecs , re
99import optparse
1010
1111class RNNLM :
@@ -63,7 +63,7 @@ def dist(self, w):
6363
6464class RNNLM_BPTT (RNNLM ):
6565 """RNNLM with BackPropagation Through Time"""
66- def learn (self , docs , alpha = 0.1 , tau = 10 ):
66+ def learn (self , docs , alpha = 0.1 , tau = 3 ):
6767 index = numpy .arange (len (docs ))
6868 numpy .random .shuffle (index )
6969 for i in index :
@@ -125,10 +125,14 @@ def perplexity(self, docs):
125125 N += len (doc )
126126 return log_like / N
127127
def CorpusWrapper(corpus):
    """Yield the word sequence of each document in an NLTK-style corpus.

    Iterates over ``corpus.fileids()`` lazily, yielding the word list of
    one file at a time so large corpora are never materialized at once.

    :param corpus: object exposing ``fileids()`` and ``words(fileid)``
                   (duck-typed; any nltk.corpus reader qualifies)
    :return: generator of word sequences, one per file id
    """
    # 'fileid' rather than 'id' — avoids shadowing the builtin id().
    for fileid in corpus.fileids():
        yield corpus.words(fileid)
128131
129132def main ():
130133 parser = optparse .OptionParser ()
131- parser .add_option ("-c" , dest = "corpus" , help = "corpus module name under nltk.corpus (e.g. brown, reuters)" , default = 'nps_chat' )
134+ parser .add_option ("-c" , dest = "corpus" , help = "corpus module name under nltk.corpus (e.g. brown, reuters)" )
135+ parser .add_option ("-f" , dest = "filename" , help = "corpus filename name (each line is regarded as a document)" )
132136 parser .add_option ("-a" , dest = "alpha" , type = "float" , help = "additive smoothing parameter of bigram" , default = 0.001 )
133137 parser .add_option ("-k" , dest = "K" , type = "int" , help = "size of hidden layer" , default = 10 )
134138 parser .add_option ("-i" , dest = "I" , type = "int" , help = "learning interval" , default = 10 )
@@ -138,50 +142,61 @@ def main():
138142
139143 numpy .random .seed (opt .seed )
140144
141- m = __import__ ('nltk.corpus' , globals (), locals (), [opt .corpus ], - 1 )
142- corpus = getattr (m , opt .corpus )
143- ids = corpus .fileids ()
144- D = len (ids )
145- print "found corpus : %s (D=%d)" % (opt .corpus , D )
145+ if opt .corpus :
146+ m = __import__ ('nltk.corpus' , globals (), locals (), [opt .corpus ], - 1 )
147+ corpus = CorpusWrapper (getattr (m , opt .corpus ))
148+ elif opt .filename :
149+ corpus = []
150+ with codecs .open (opt .filename , "rb" , "utf-8" ) as f :
151+ for s in f :
152+ s = re .sub (r'(["\.,!\?:;])' , r' \1 ' , s ).strip ()
153+ d = re .split (r'\s+' , s )
154+ if len (d ) > 0 : corpus .append (d )
155+ else :
156+ raise "need -f or -c"
146157
147158 voca = {"<s>" :0 , "</s>" :1 }
148159 vocalist = ["<s>" , "</s>" ]
149160 docs = []
150- for id in corpus .fileids ()[:2 ]:
161+ N = 0
162+ for words in corpus :
151163 doc = []
152- for w in corpus . words ( id ) :
164+ for w in words :
153165 w = w .lower ()
154166 if w not in voca :
155167 voca [w ] = len (vocalist )
156168 vocalist .append (w )
157169 doc .append (voca [w ])
158170 if len (doc ) > 0 :
159- doc .append (1 )
171+ N += len (doc )
172+ doc .append (1 ) # </s>
160173 docs .append (doc )
161- V = len (vocalist )
162- print "vocabulary : %d / %d" % (V , len (corpus .words ()))
163174
164175 D = len (docs )
165-
166- print ">> BIGRAM(alpha=%f)" % opt .alpha
167- model = BIGRAM (V , opt .alpha )
168- model .learn (docs )
169- print model .perplexity (docs )
176+ V = len (vocalist )
177+ print "corpus : %s (D=%d)" % (opt .corpus or opt .filename , D )
178+ print "vocabulary : %d / %d" % (V , N )
170179
171180 print ">> RNNLM(K=%d)" % opt .K
172- model = RNNLM (V , opt .K )
173- print model .perplexity (docs )
174- intervals = [1.0 , 1.0 , 0.5 , 0.5 , 0.4 , 0.3 , 0.2 ]
181+ model = RNNLM_BPTT (V , opt .K )
182+ a = 1.0
175183 for i in xrange (opt .I ):
176- a = intervals [ i ] if i < len ( intervals ) else 0.1
184+ print i , model . perplexity ( docs )
177185 model .learn (docs , a )
178- print model .perplexity (docs )
186+ a = a * 0.95 + 0.01
187+ print opt .I , model .perplexity (docs )
179188
180189 if opt .output :
181190 import cPickle
182191 with open (opt .output , 'wb' ) as f :
183192 cPickle .dump ([model , voca , vocalist ], f )
184193
194+ print ">> BIGRAM(alpha=%f)" % opt .alpha
195+ model = BIGRAM (V , opt .alpha )
196+ model .learn (docs )
197+ print model .perplexity (docs )
198+
199+
185200"""
186201 testids = set(random.sample(ids, int(D * opt.testrate)))
187202 trainids = [id for id in ids if id not in testids]
0 commit comments