Skip to content

Commit 18d5625

Browse files
committed
add concurrency-friendly map access to fix #8
1 parent ee73c68 commit 18d5625

File tree

2 files changed

+44
-16
lines changed

2 files changed

+44
-16
lines changed

text/bayes.go

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,11 @@ import (
7979
"math"
8080
"os"
8181
"strings"
82+
"sync"
8283

8384
"golang.org/x/text/transform"
8485

85-
"github.com/cdipaolo/goml/base"
86+
"github.com/piazzamp/goml/base"
8687
)
8788

8889
/*
@@ -129,7 +130,7 @@ type NaiveBayes struct {
129130
// Words holds a map of words
130131
// to their corresponding Word
131132
// structure
132-
Words map[string]Word `json:"words"`
133+
Words histogram `json:"words"`
133134

134135
// Count holds the number of times
135136
// class i was seen as Count[i]
@@ -160,6 +161,30 @@ type NaiveBayes struct {
160161
Output io.Writer
161162
}
162163

164+
// histogram allows conncurrency-friendly map access via its
165+
// exported Get and Set methods
166+
type histogram struct {
167+
sync.RWMutex
168+
words map[string]Word
169+
}
170+
171+
// Get looks up a word from h's Word map, it should be used in
172+
// place of a direct map lookup
173+
// the only caveat here is that it will always return the 'success' boolean
174+
func (h *histogram) Get(w string) (Word, bool) {
175+
h.RLock()
176+
result, ok := h.words[w]
177+
h.RUnlock()
178+
return result, ok
179+
}
180+
181+
// Set sets word k's value to v in h's Word map
182+
func (h *histogram) Set(k string, v Word) {
183+
h.Lock()
184+
h.words[k] = v
185+
h.Unlock()
186+
}
187+
163188
// Word holds the structural
164189
// information needed to calculate
165190
// the probability of
@@ -192,7 +217,7 @@ type Word struct {
192217
// comply with the transform.RemoveFunc interface
193218
func NewNaiveBayes(stream <-chan base.TextDatapoint, classes uint8, sanitize func(rune) bool) *NaiveBayes {
194219
return &NaiveBayes{
195-
Words: make(map[string]Word),
220+
Words: histogram{sync.RWMutex{}, make(map[string]Word)},
196221
Count: make([]uint64, classes),
197222
Probabilities: make([]float64, classes),
198223

@@ -211,14 +236,15 @@ func (b *NaiveBayes) Predict(sentence string) uint8 {
211236
sums := make([]float64, len(b.Count))
212237

213238
sentence, _, _ = transform.String(b.sanitize, sentence)
214-
w := strings.Split(strings.ToLower(sentence), " ")
215-
for _, word := range w {
216-
if _, ok := b.Words[word]; !ok {
239+
words := strings.Split(strings.ToLower(sentence), " ")
240+
for _, word := range words {
241+
w, ok := b.Words.Get(word)
242+
if !ok {
217243
continue
218244
}
219245

220246
for i := range sums {
221-
sums[i] += math.Log(float64(b.Words[word].Count[i]+1) / float64(b.Words[word].Seen+b.DictCount))
247+
sums[i] += math.Log(float64(w.Count[i]+1) / float64(w.Seen+b.DictCount))
222248
}
223249
}
224250

@@ -261,14 +287,15 @@ func (b *NaiveBayes) Probability(sentence string) (uint8, float64) {
261287
}
262288

263289
sentence, _, _ = transform.String(b.sanitize, sentence)
264-
w := strings.Split(strings.ToLower(sentence), " ")
265-
for _, word := range w {
266-
if _, ok := b.Words[word]; !ok {
290+
words := strings.Split(strings.ToLower(sentence), " ")
291+
for _, word := range words {
292+
w, ok := b.Words.Get(word)
293+
if !ok {
267294
continue
268295
}
269296

270297
for i := range sums {
271-
sums[i] *= float64(b.Words[word].Count[i]+1) / float64(b.Words[word].Seen+b.DictCount)
298+
sums[i] *= float64(w.Count[i]+1) / float64(w.Seen+b.DictCount)
272299
}
273300
}
274301

@@ -340,7 +367,7 @@ func (b *NaiveBayes) OnlineLearn(errors chan<- error) {
340367
continue
341368
}
342369

343-
w, ok := b.Words[word]
370+
w, ok := b.Words.Get(word)
344371

345372
if !ok {
346373
w = Word{
@@ -354,16 +381,16 @@ func (b *NaiveBayes) OnlineLearn(errors chan<- error) {
354381
w.Count[C]++
355382
w.Seen++
356383

357-
b.Words[word] = w
384+
b.Words.Set(word, w)
358385

359386
seenCount[word] = 1
360387
}
361388

362389
// add to DocsSeen
363390
for term := range seenCount {
364-
tmp := b.Words[term]
391+
tmp, _ := b.Words.Get(term)
365392
tmp.DocsSeen++
366-
b.Words[term] = tmp
393+
b.Words.Set(term, tmp)
367394
}
368395
} else {
369396
fmt.Fprintf(b.Output, "Training Completed.\n%v\n\n", b)

text/tfidf.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,5 +173,6 @@ func TermFrequencies(document []string) Frequencies {
173173
// Look at the TFIDF docs to see more about how
174174
// this is calculated
175175
func (t *TFIDF) InverseDocumentFrequency(word string) float64 {
176-
return math.Log(float64(t.DocumentCount)) - math.Log(float64(t.Words[word].DocsSeen)+1)
176+
w, _ := t.Words.Get(word)
177+
return math.Log(float64(t.DocumentCount)) - math.Log(float64(w.DocsSeen)+1)
177178
}

0 commit comments

Comments
 (0)