Skip to content

Commit b48c286

Browse files
committed
model initialization and test
1 parent e0c5fa0 commit b48c286

File tree

2 files changed

+409
-23
lines changed

2 files changed

+409
-23
lines changed

lda/ldacvb0_cpp/ldacvb0/ldacvb0.hpp

Lines changed: 234 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,14 @@ namespace std {
1919
#include <regex>
2020
#endif
2121
#include <algorithm>
22+
#include <cassert>
2223
#include <string>
2324
#include <unordered_map>
25+
#include <vector>
2426
#include <cybozu/string.hpp>
2527
#include <cybozu/string_operation.hpp>
2628
#include <cybozu/mmap.hpp>
29+
#include <random>
2730

2831
namespace cybozu {
2932
namespace ldacvb0 {
@@ -34,7 +37,13 @@ namespace ldacvb0 {
3437
std::regex rexword("(\\S+)/\\S+");
3538

3639
/*
37-
id and counter of vocabulary
40+
41+
*/
42+
typedef std::vector<double> Vec;
43+
typedef std::vector<Vec> Mat;
44+
45+
/*
46+
id and counter of vocabulary
3847
*/
3948
class IdCount {
4049
public:
@@ -44,8 +53,10 @@ class IdCount {
4453
IdCount(size_t id_, int count_) : id(id_), count(count_) {}
4554
};
4655

56+
57+
4758
/*
48-
identify and count freaquency of vocabulary
59+
identify and count frequency of vocabulary
4960
*/
5061
template <class WORD>
5162
class Vocabularies {
@@ -57,10 +68,10 @@ class Vocabularies {
5768
size_t add(const WORD &word) {
5869
WORD key(word);
5970
normalize(key);
60-
if (voca.find(key) != voca.end()) {
61-
IdCount& x = voca[key];
62-
x.count += 1;
63-
return x.id;
71+
auto x = voca.find(key);
72+
if (x != voca.end()) {
73+
x->second.count += 1;
74+
return x->second.id;
6475
} else {
6576
size_t new_id = vocalist.size();
6677
voca[key] = IdCount(new_id, 1);
@@ -103,28 +114,45 @@ class Vocabularies {
103114
}
104115
};
105116

106-
typedef std::vector<size_t> Document;
117+
/*
    one entry of a document: a vocabulary id paired with an occurrence count
*/
class Term {
public:
    size_t id;
    int freq;

    // build a term from an explicit id / count pair
    Term(size_t id_, int freq_) : id(id_), freq(freq_) {}

    // the default term has id 0 and count 0
    Term() : Term(0, 0) {}
};
124+
typedef std::vector<Term> Document;
107125

108126
/*
109-
doocument loader
127+
document loader
110128
*/
111129
template <class STRING, class CHAR>
112-
class Documents : std::vector<Document> {
130+
class Documents : public std::vector<Document> {
113131
public:
114132
Vocabularies<STRING> vocabularies;
115133

116134
private:
117135
template <class T> void addeachword(std::regex_iterator<T> i) {
118-
push_back(Document());
119-
Document& doc = back();
120-
136+
std::unordered_map<size_t, int> count;
121137
std::regex_iterator<T> iend;
122138
for (; i != iend; ++i) {
123139
const std::string& w = (*i)[1].str();
124140
char c = w[0];
125141
if (c < 'A' || (c > 'Z' && c < 'a') || c > 'z') continue;
126142
size_t id = vocabularies.add(w);
127-
doc.push_back(id);
143+
auto x = count.find(id);
144+
if (x != count.end()) {
145+
x->second += 1;
146+
} else {
147+
count[id] = 1;
148+
}
149+
}
150+
151+
push_back(Document());
152+
Document& doc = back();
153+
auto j = count.begin(), jend = count.end();
154+
for (;j!=jend;++j) {
155+
doc.push_back(Term(j->first, j->second));
128156
}
129157
}
130158

@@ -150,15 +178,15 @@ class Documents : std::vector<Document> {
150178

151179

152180
/*
153-
call the procedure for each word (mmap)
154-
*/
181+
call the procedure for each word (mmap)
182+
*/
155183
// Stub of the raw character range (mmap) variant: the body is empty, so `func`
// is never invoked and nothing in [p, end) is read.
template <class T, class CHAR>
void eachwords(const CHAR* p, const CHAR* end, T& func) {
    (void)p;
    (void)end;
    (void)func;
}
158186

159187
/*
160-
call the procedure for each word (std::string, cybozu::String)
161-
*/
188+
call the procedure for each word (std::string, cybozu::String)
189+
*/
162190
template <class T, class STRING>
163191
void eachwords(STRING st, T& func) {
164192
std::regex_token_iterator<STRING::iterator> i( st.begin(), st.end(), rexword ), iend;
@@ -180,16 +208,199 @@ template <class T> void loadeachwords(std::string filename, T& func) {
180208
}
181209
}
182210

183-
class LDA_CVB0 {
211+
212+
213+
214+
215+
216+
/*
217+
218+
*/
219+
/*
    Re-estimate the K-element vector gamma_k of one (document, word) pair and
    add its contribution to the buffer accumulators.

    gamma_k   : in/out, K values; on return they are normalized to sum to 1
    i_wk_buf  : start of the flat word-major output array (row w begins at w * K)
    i_jk_buf  : start of the K output values of the current document
    i_k_buf   : start of the K output totals
    i_wk      : start of the flat word-major input array (row w begins at w * K)
    i_jk      : start of the K input values of the current document
    i_k       : start of the K input totals
    w         : word id selecting the row of the word-major arrays
    freq      : weight applied to gamma_k when adding to the buffers
    K         : number of elements processed from every range

    Declared inline because it is defined in a header. All iterators are taken
    by value, so the caller may pass temporaries such as vec.begin(); the
    caller's iterators are never advanced.
*/
inline void update_for_word(
    Vec& gamma_k,
    Vec::iterator i_wk_buf, Vec::iterator i_jk_buf, Vec::iterator i_k_buf,
    Vec::const_iterator i_wk, Vec::const_iterator i_jk, Vec::const_iterator i_k,
    const size_t w, const int freq, const int K
) {
    // move both word-major cursors to the row of word w
    i_wk += w * K;
    i_wk_buf += w * K;

    // pass 1: unnormalized values; the previous gamma is subtracted from each
    // input count before they are combined
    double sum_gamma = 0;
    auto i_gamma = gamma_k.begin();
    for (int k = 0; k < K; ++k, ++i_gamma, ++i_wk, ++i_jk, ++i_k) {
        const double gamma = *i_gamma;
        const double new_gamma = (*i_wk - gamma) * (*i_jk - gamma) / (*i_k - gamma);
        *i_gamma = new_gamma;
        sum_gamma += new_gamma;
    }

    // pass 2: normalize in place, then add gamma * freq to the three buffers
    // NOTE(review): sum_gamma == 0 is not guarded and would produce NaN/inf
    i_gamma = gamma_k.begin();
    for (int k = 0; k < K; ++k, ++i_gamma, ++i_wk_buf, ++i_jk_buf, ++i_k_buf) {
        const double gamma = *i_gamma / sum_gamma;
        *i_gamma = gamma;
        const double weighted = gamma * freq;
        *i_wk_buf += weighted;
        *i_jk_buf += weighted;
        *i_k_buf += weighted;
    }
}
246+
247+
248+
/*
    Size the three count arrays for a corpus.

    n_wk : resized to V * K, V being docs.vocabularies.size()
    n_jk : resized to M * K, M being docs.size()
    n_k  : resized to K

    Only the sizes are set; resize() value-initializes newly added elements
    and leaves existing ones untouched. Declared inline because it is defined
    in a header.
*/
inline void parameter_init(Vec& n_wk, Vec& n_jk, Vec& n_k, const Documents<std::string, char>& docs, const int K) {
    const size_t M = docs.size();
    const size_t V = docs.vocabularies.size();
    // explicit conversion: the products below are computed in size_t
    const size_t topics = static_cast<size_t>(K);
    n_wk.resize(V * topics);
    n_jk.resize(M * topics);
    n_k.resize(topics);
}
256+
257+
class dirichlet_distribution {
184258
private:
185-
int K_;
259+
std::mt19937 generator;
260+
public:
261+
dirichlet_distribution() {
262+
263+
}
264+
265+
dirichlet_distribution(unsigned long seed) {
266+
if (seed>0) {
267+
generator.seed(seed);
268+
} else {
269+
dirichlet_distribution();
270+
}
271+
}
272+
void draw(Vec& vec, const int K, const double alpha) {
273+
std::gamma_distribution<double> distribution(alpha, 1.0);
274+
if (vec.size() != K) vec.resize(K);
275+
double sum = 0;
276+
auto i = vec.begin(), iend = vec.end();
277+
for (;i!=iend;++i) {
278+
double x = distribution(generator);
279+
sum += x;
280+
*i = x;
281+
}
282+
for (i=vec.begin();i!=iend;++i) {
283+
*i /= sum;
284+
}
285+
}
286+
void draw(Vec& vec, const Vec& alpha) {
287+
if (vec.size() != alpha.size()) vec.resize(alpha.size());
288+
double sum = 0;
289+
auto i = vec.begin(), iend = vec.end();
290+
auto a = alpha.begin();
291+
for (;i!=iend;++i,++a) {
292+
std::gamma_distribution<double> distribution(*a, 1.0);
293+
double x = distribution(generator);
294+
sum += x;
295+
*i = x;
296+
}
297+
for (i=vec.begin();i!=iend;++i) {
298+
*i /= sum;
299+
}
300+
}
301+
};
302+
303+
/*
304+
305+
*/
306+
class LDA_CVB0 {
307+
public:
308+
int K_, V_;
186309
double alpha_;
187310
double beta_;
188-
public:
189-
LDA_CVB0(int K, double alpha, double beta) :
190-
K_(K), alpha_(alpha), beta_(beta) {
311+
Vec n_wk1, n_wk2, n_jk1, n_jk2, n_k1, n_k2;
312+
Vec &n_wk, &n_wk_buf, &n_jk, &n_jk_buf, &n_k, &n_k_buf;
313+
Mat gamma_jik;
314+
const Documents<std::string, char>& docs_;
315+
LDA_CVB0(int K, int V, double alpha, double beta, const Documents<std::string, char>& docs) :
316+
K_(K), V_(V), alpha_(alpha), beta_(beta), docs_(docs),
317+
n_wk(n_wk1), n_wk_buf(n_wk2), n_jk(n_jk1), n_jk_buf(n_jk2), n_k(n_k1), n_k_buf(n_k2) {
318+
parameter_init(n_wk1, n_jk1, n_k1, docs, K);
319+
parameter_init(n_wk2, n_jk2, n_k2, docs, K);
320+
321+
std::fill(n_wk.begin(), n_wk.end(), beta_);
322+
std::fill(n_jk.begin(), n_jk.end(), alpha_);
323+
std::fill(n_k.begin(), n_k.end(), beta_ * V_);
324+
325+
cybozu::ldacvb0::dirichlet_distribution dd(1U);
326+
327+
auto j = docs_.begin(), jend = docs_.end();
328+
auto j_jk = n_jk.begin();
329+
for (;j!=jend;++j) {
330+
auto i = j->begin(), iend = j->end();
331+
for (;i!=iend;++i) {
332+
size_t w = i->id;
333+
int freq = i->freq;
334+
335+
Vec aph(K);
336+
auto aend = aph.end();
337+
double sum = 0;
338+
{
339+
auto i_wk = n_wk.begin() + w * K;
340+
auto i_jk = j_jk;
341+
auto i_k = n_k.begin();
342+
for (auto ai = aph.begin();ai!=aend;++ai,++i_wk,++i_jk,++i_k) {
343+
sum += *ai = *i_wk * *i_jk / *i_k;
344+
}
345+
}
346+
sum = alpha / sum;
347+
for (auto ai = aph.begin(); ai != aend; ++ai) *ai *= sum;
348+
349+
gamma_jik.push_back(Vec());
350+
Vec& gamma = gamma_jik.back();
351+
dd.draw(gamma, aph);
352+
353+
auto gi = gamma.begin(), gend = gamma.end();
354+
auto i_wk = n_wk.begin() + w * K;
355+
auto i_jk = j_jk;
356+
auto i_k = n_k.begin();
357+
for (;gi!=gend;++gi,++i_wk,++i_jk,++i_k) {
358+
double g = *gi * freq;
359+
*i_wk += g;
360+
*i_jk += g;
361+
*i_k += g;
362+
}
363+
}
364+
j_jk+=K;
365+
}
366+
}
367+
368+
void learn() {
369+
std::fill(n_wk_buf.begin(), n_wk_buf.end(), beta_);
370+
std::fill(n_jk_buf.begin(), n_jk_buf.end(), alpha_);
371+
std::fill(n_k_buf.begin(), n_k_buf.end(), beta_ * V_);
372+
373+
auto gamma_k = gamma_jik.begin();
374+
auto j = docs_.begin(), jend = docs_.end();
375+
auto j_jk = n_jk.begin();
376+
auto j_jk_buf = n_jk_buf.begin();
377+
for (;j!=jend;++j) {
378+
auto i = j->begin(), iend = j->end();
379+
for (;i!=iend;++i) {
380+
381+
size_t w = i->id;
382+
int freq = i->freq;
383+
384+
update_for_word(
385+
*gamma_k,
386+
n_wk.begin(), j_jk, n_k.begin(),
387+
n_wk_buf.begin(), j_jk_buf, n_k_buf.begin(),
388+
w, freq, K_
389+
);
390+
391+
++gamma_k;
392+
}
393+
j_jk += K_;
394+
j_jk_buf += K_;
395+
396+
}
397+
191398

192-
}
399+
400+
n_wk, n_wk_buf = n_wk_buf, n_wk;
401+
n_jk, n_jk_buf = n_jk_buf, n_jk;
402+
n_k, n_k_buf = n_k_buf, n_k;
403+
}
193404
};
194405

195406
} }

0 commit comments

Comments
 (0)