@@ -79,10 +79,11 @@ import (
7979 "math"
8080 "os"
8181 "strings"
82+ "sync"
8283
8384 "golang.org/x/text/transform"
8485
85- "github.com/cdipaolo /goml/base"
86+ "github.com/piazzamp /goml/base"
8687)
8788
8889/*
@@ -129,7 +130,7 @@ type NaiveBayes struct {
129130 // Words holds a map of words
130131 // to their corresponding Word
131132 // structure
132- Words map [ string ] Word `json:"words"`
133+ Words histogram `json:"words"`
133134
134135 // Count holds the number of times
135136 // class i was seen as Count[i]
@@ -160,6 +161,30 @@ type NaiveBayes struct {
160161 Output io.Writer
161162}
162163
164+ // histogram allows conncurrency-friendly map access via its
165+ // exported Get and Set methods
166+ type histogram struct {
167+ sync.RWMutex
168+ words map [string ]Word
169+ }
170+
171+ // Get looks up a word from h's Word map, it should be used in
172+ // place of a direct map lookup
173+ // the only caveat here is that it will always return the 'success' boolean
174+ func (h * histogram ) Get (w string ) (Word , bool ) {
175+ h .RLock ()
176+ result , ok := h .words [w ]
177+ h .RUnlock ()
178+ return result , ok
179+ }
180+
181+ // Set sets word k's value to v in h's Word map
182+ func (h * histogram ) Set (k string , v Word ) {
183+ h .Lock ()
184+ h .words [k ] = v
185+ h .Unlock ()
186+ }
187+
163188// Word holds the structural
164189// information needed to calculate
165190// the probability of
@@ -192,7 +217,7 @@ type Word struct {
192217// comply with the transform.RemoveFunc interface
193218func NewNaiveBayes (stream <- chan base.TextDatapoint , classes uint8 , sanitize func (rune ) bool ) * NaiveBayes {
194219 return & NaiveBayes {
195- Words : make (map [string ]Word ),
220+ Words : histogram {sync. RWMutex {}, make (map [string ]Word )} ,
196221 Count : make ([]uint64 , classes ),
197222 Probabilities : make ([]float64 , classes ),
198223
@@ -211,14 +236,15 @@ func (b *NaiveBayes) Predict(sentence string) uint8 {
211236 sums := make ([]float64 , len (b .Count ))
212237
213238 sentence , _ , _ = transform .String (b .sanitize , sentence )
214- w := strings .Split (strings .ToLower (sentence ), " " )
215- for _ , word := range w {
216- if _ , ok := b .Words [word ]; ! ok {
239+ words := strings .Split (strings .ToLower (sentence ), " " )
240+ for _ , word := range words {
241+ w , ok := b .Words .Get (word )
242+ if ! ok {
217243 continue
218244 }
219245
220246 for i := range sums {
221- sums [i ] += math .Log (float64 (b . Words [ word ]. Count [i ]+ 1 ) / float64 (b . Words [ word ] .Seen + b .DictCount ))
247+ sums [i ] += math .Log (float64 (w . Count [i ]+ 1 ) / float64 (w .Seen + b .DictCount ))
222248 }
223249 }
224250
@@ -261,14 +287,15 @@ func (b *NaiveBayes) Probability(sentence string) (uint8, float64) {
261287 }
262288
263289 sentence , _ , _ = transform .String (b .sanitize , sentence )
264- w := strings .Split (strings .ToLower (sentence ), " " )
265- for _ , word := range w {
266- if _ , ok := b .Words [word ]; ! ok {
290+ words := strings .Split (strings .ToLower (sentence ), " " )
291+ for _ , word := range words {
292+ w , ok := b .Words .Get (word )
293+ if ! ok {
267294 continue
268295 }
269296
270297 for i := range sums {
271- sums [i ] *= float64 (b . Words [ word ]. Count [i ]+ 1 ) / float64 (b . Words [ word ] .Seen + b .DictCount )
298+ sums [i ] *= float64 (w . Count [i ]+ 1 ) / float64 (w .Seen + b .DictCount )
272299 }
273300 }
274301
@@ -340,7 +367,7 @@ func (b *NaiveBayes) OnlineLearn(errors chan<- error) {
340367 continue
341368 }
342369
343- w , ok := b .Words [ word ]
370+ w , ok := b .Words . Get ( word )
344371
345372 if ! ok {
346373 w = Word {
@@ -354,16 +381,16 @@ func (b *NaiveBayes) OnlineLearn(errors chan<- error) {
354381 w .Count [C ]++
355382 w .Seen ++
356383
357- b .Words [ word ] = w
384+ b .Words . Set ( word , w )
358385
359386 seenCount [word ] = 1
360387 }
361388
362389 // add to DocsSeen
363390 for term := range seenCount {
364- tmp := b .Words [ term ]
391+ tmp , _ := b .Words . Get ( term )
365392 tmp .DocsSeen ++
366- b .Words [ term ] = tmp
393+ b .Words . Set ( term , tmp )
367394 }
368395 } else {
369396 fmt .Fprintf (b .Output , "Training Completed.\n %v\n \n " , b )
0 commit comments