@@ -19,11 +19,14 @@ namespace std {
1919#include < regex>
2020#endif
2121#include < algorithm>
22+ #include < cassert>
2223#include < string>
2324#include < unordered_map>
25+ #include < vector>
2426#include < cybozu/string.hpp>
2527#include < cybozu/string_operation.hpp>
2628#include < cybozu/mmap.hpp>
29+ #include < random>
2730
2831namespace cybozu {
2932namespace ldacvb0 {
@@ -34,7 +37,13 @@ namespace ldacvb0 {
3437std::regex rexword (" (\\ S+)/\\ S+" );
3538
3639/*
37- id and counter of vocabulary
40+
41+ */
42+ typedef std::vector<double > Vec;
43+ typedef std::vector<Vec> Mat;
44+
45+ /*
46+ id and counter of vocabulary
3847*/
3948class IdCount {
4049public:
@@ -44,8 +53,10 @@ class IdCount {
4453 IdCount (size_t id_, int count_) : id(id_), count(count_) {}
4554};
4655
56+
57+
4758/*
48- identify and count freaquency of vocabulary
59+ identify and count freaquency of vocabulary
4960*/
5061template <class WORD >
5162class Vocabularies {
@@ -57,10 +68,10 @@ class Vocabularies {
5768 size_t add (const WORD &word) {
5869 WORD key (word);
5970 normalize (key);
60- if (voca. find (key) ! = voca.end ()) {
61- IdCount& x = voca[key];
62- x.count += 1 ;
63- return x.id ;
71+ auto x = voca.find (key);
72+ if (x ! = voca. end ()) {
73+ x-> second .count += 1 ;
74+ return x-> second .id ;
6475 } else {
6576 size_t new_id = vocalist.size ();
6677 voca[key] = IdCount (new_id, 1 );
@@ -103,28 +114,45 @@ class Vocabularies {
103114 }
104115};
105116
106- typedef std::vector<size_t > Document;
117+ class Term {
118+ public:
119+ size_t id;
120+ int freq;
121+ Term () : id(0 ), freq(0 ) {}
122+ Term (size_t id_, int freq_) : id(id_), freq(freq_) {}
123+ };
124+ typedef std::vector<Term> Document;
107125
108126/*
109- doocument loader
127+ doocument loader
110128*/
111129template <class STRING , class CHAR >
112- class Documents : std::vector<Document> {
130+ class Documents : public std ::vector<Document> {
113131public:
114132 Vocabularies<STRING> vocabularies;
115133
116134private:
117135 template <class T > void addeachword (std::regex_iterator<T> i) {
118- push_back (Document ());
119- Document& doc = back ();
120-
136+ std::unordered_map<size_t , int > count;
121137 std::regex_iterator<T> iend;
122138 for (; i != iend; ++i) {
123139 const std::string& w = (*i)[1 ].str ();
124140 char c = w[0 ];
125141 if (c < ' A' || (c > ' Z' && c < ' a' ) || c > ' z' ) continue ;
126142 size_t id = vocabularies.add (w);
127- doc.push_back (id);
143+ auto x = count.find (id);
144+ if (x != count.end ()) {
145+ x->second += 1 ;
146+ } else {
147+ count[id] = 1 ;
148+ }
149+ }
150+
151+ push_back (Document ());
152+ Document& doc = back ();
153+ auto j = count.begin (), jend = count.end ();
154+ for (;j!=jend;++j) {
155+ doc.push_back (Term (j->first , j->second ));
128156 }
129157 }
130158
@@ -150,15 +178,15 @@ class Documents : std::vector<Document> {
150178
151179
152180/*
153- call the procedure for each word (mmap)
154- */
181+ call the procedure for each word (mmap)
182+ */
155183template <class T , class CHAR >
156184void eachwords (const CHAR* p, const CHAR* end, T& func) {
157185};
158186
159187/*
160- call the procedure for each word (std::string, cybozu::String)
161- */
188+ call the procedure for each word (std::string, cybozu::String)
189+ */
162190template <class T , class STRING >
163191void eachwords (STRING st, T& func) {
164192 std::regex_token_iterator<STRING::iterator> i ( st.begin (), st.end (), rexword ), iend;
@@ -180,16 +208,199 @@ template <class T> void loadeachwords(std::string filename, T& func) {
180208 }
181209}
182210
183- class LDA_CVB0 {
211+
212+
213+
214+
215+
216+ /*
217+
218+ */
219+ void update_for_word (
220+ Vec& gamma_k,
221+ Vec::iterator i_wk_buf, Vec::iterator i_jk_buf, Vec::iterator i_k_buf,
222+ Vec::const_iterator& i_wk, Vec::const_iterator i_jk, Vec::const_iterator i_k,
223+ const size_t w, const int freq, const int K
224+ ) {
225+ i_wk += w * K;
226+ i_wk_buf += w * K;
227+
228+ auto i_gamma = gamma_k.begin ();
229+ double sum_gamma = 0 ;
230+ for (int k=0 ;k<K;++k) {
231+ double gamma = *i_gamma;
232+ double new_gamma = (*i_wk++ - gamma) * (*i_jk++ - gamma) / (*i_k++ - gamma);
233+ *i_gamma++ = new_gamma;
234+ sum_gamma += new_gamma;
235+ }
236+ i_gamma = gamma_k.begin ();
237+ for (int k=0 ;k<K;++k) {
238+ double gamma = *i_gamma / sum_gamma;
239+ *i_gamma++ = gamma;
240+ gamma *= freq;
241+ *i_wk_buf++ += gamma;
242+ *i_jk_buf++ += gamma;
243+ *i_k_buf++ += gamma;
244+ }
245+ }
246+
247+
248+ void parameter_init (Vec& n_wk, Vec& n_jk, Vec& n_k, const Documents<std::string, char >& docs, const int K) {
249+ const size_t M = docs.size ();
250+ const size_t V = docs.vocabularies .size ();
251+ n_wk.resize (V*K);
252+ n_jk.resize (M*K);
253+ n_k.resize (K);
254+
255+ }
256+
257+ class dirichlet_distribution {
184258private:
185- int K_;
259+ std::mt19937 generator;
260+ public:
261+ dirichlet_distribution () {
262+
263+ }
264+
265+ dirichlet_distribution (unsigned long seed) {
266+ if (seed>0 ) {
267+ generator.seed (seed);
268+ } else {
269+ dirichlet_distribution ();
270+ }
271+ }
272+ void draw (Vec& vec, const int K, const double alpha) {
273+ std::gamma_distribution<double > distribution (alpha, 1.0 );
274+ if (vec.size () != K) vec.resize (K);
275+ double sum = 0 ;
276+ auto i = vec.begin (), iend = vec.end ();
277+ for (;i!=iend;++i) {
278+ double x = distribution (generator);
279+ sum += x;
280+ *i = x;
281+ }
282+ for (i=vec.begin ();i!=iend;++i) {
283+ *i /= sum;
284+ }
285+ }
286+ void draw (Vec& vec, const Vec& alpha) {
287+ if (vec.size () != alpha.size ()) vec.resize (alpha.size ());
288+ double sum = 0 ;
289+ auto i = vec.begin (), iend = vec.end ();
290+ auto a = alpha.begin ();
291+ for (;i!=iend;++i,++a) {
292+ std::gamma_distribution<double > distribution (*a, 1.0 );
293+ double x = distribution (generator);
294+ sum += x;
295+ *i = x;
296+ }
297+ for (i=vec.begin ();i!=iend;++i) {
298+ *i /= sum;
299+ }
300+ }
301+ };
302+
303+ /*
304+
305+ */
306+ class LDA_CVB0 {
307+ public:
308+ int K_, V_;
186309 double alpha_;
187310 double beta_;
188- public:
189- LDA_CVB0 (int K, double alpha, double beta) :
190- K_ (K), alpha_(alpha), beta_(beta) {
311+ Vec n_wk1, n_wk2, n_jk1, n_jk2, n_k1, n_k2;
312+ Vec &n_wk, &n_wk_buf, &n_jk, &n_jk_buf, &n_k, &n_k_buf;
313+ Mat gamma_jik;
314+ const Documents<std::string, char >& docs_;
315+ LDA_CVB0 (int K, int V, double alpha, double beta, const Documents<std::string, char >& docs) :
316+ K_ (K), V_(V), alpha_(alpha), beta_(beta), docs_(docs),
317+ n_wk (n_wk1), n_wk_buf(n_wk2), n_jk(n_jk1), n_jk_buf(n_jk2), n_k(n_k1), n_k_buf(n_k2) {
318+ parameter_init (n_wk1, n_jk1, n_k1, docs, K);
319+ parameter_init (n_wk2, n_jk2, n_k2, docs, K);
320+
321+ std::fill (n_wk.begin (), n_wk.end (), beta_);
322+ std::fill (n_jk.begin (), n_jk.end (), alpha_);
323+ std::fill (n_k.begin (), n_k.end (), beta_ * V_);
324+
325+ cybozu::ldacvb0::dirichlet_distribution dd (1U );
326+
327+ auto j = docs_.begin (), jend = docs_.end ();
328+ auto j_jk = n_jk.begin ();
329+ for (;j!=jend;++j) {
330+ auto i = j->begin (), iend = j->end ();
331+ for (;i!=iend;++i) {
332+ size_t w = i->id ;
333+ int freq = i->freq ;
334+
335+ Vec aph (K);
336+ auto aend = aph.end ();
337+ double sum = 0 ;
338+ {
339+ auto i_wk = n_wk.begin () + w * K;
340+ auto i_jk = j_jk;
341+ auto i_k = n_k.begin ();
342+ for (auto ai = aph.begin ();ai!=aend;++ai,++i_wk,++i_jk,++i_k) {
343+ sum += *ai = *i_wk * *i_jk / *i_k;
344+ }
345+ }
346+ sum = alpha / sum;
347+ for (auto ai = aph.begin (); ai != aend; ++ai) *ai *= sum;
348+
349+ gamma_jik.push_back (Vec ());
350+ Vec& gamma = gamma_jik.back ();
351+ dd.draw (gamma, aph);
352+
353+ auto gi = gamma.begin (), gend = gamma.end ();
354+ auto i_wk = n_wk.begin () + w * K;
355+ auto i_jk = j_jk;
356+ auto i_k = n_k.begin ();
357+ for (;gi!=gend;++gi,++i_wk,++i_jk,++i_k) {
358+ double g = *gi * freq;
359+ *i_wk += g;
360+ *i_jk += g;
361+ *i_k += g;
362+ }
363+ }
364+ j_jk+=K;
365+ }
366+ }
367+
368+ void learn () {
369+ std::fill (n_wk_buf.begin (), n_wk_buf.end (), beta_);
370+ std::fill (n_jk_buf.begin (), n_jk_buf.end (), alpha_);
371+ std::fill (n_k_buf.begin (), n_k_buf.end (), beta_ * V_);
372+
373+ auto gamma_k = gamma_jik.begin ();
374+ auto j = docs_.begin (), jend = docs_.end ();
375+ auto j_jk = n_jk.begin ();
376+ auto j_jk_buf = n_jk_buf.begin ();
377+ for (;j!=jend;++j) {
378+ auto i = j->begin (), iend = j->end ();
379+ for (;i!=iend;++i) {
380+
381+ size_t w = i->id ;
382+ int freq = i->freq ;
383+
384+ update_for_word (
385+ *gamma_k,
386+ n_wk.begin (), j_jk, n_k.begin (),
387+ n_wk_buf.begin (), j_jk_buf, n_k_buf.begin (),
388+ w, freq, K_
389+ );
390+
391+ ++gamma_k;
392+ }
393+ j_jk += K_;
394+ j_jk_buf += K_;
395+
396+ }
397+
191398
192- }
399+
400+ n_wk, n_wk_buf = n_wk_buf, n_wk;
401+ n_jk, n_jk_buf = n_jk_buf, n_jk;
402+ n_k, n_k_buf = n_k_buf, n_k;
403+ }
193404};
194405
195406} }
0 commit comments