
Commit 4b2f5a3

Merge pull request #12 from jrbarron/tokenizer
Add support for specifying a custom tokenizer on the NaiveBayes model
2 parents 86c1fda + 732a627

File tree: 3 files changed, +84 −8 lines

  text/bayes.go
  text/bayes_test.go
  text/tfidf.go

text/bayes.go

Lines changed: 23 additions & 5 deletions
@@ -156,11 +156,23 @@ type NaiveBayes struct {
     // stream holds the datastream
     stream <-chan base.TextDatapoint
 
+    // tokenizer is used by a model
+    // to split the input into tokens
+    tokenize Tokenizer
+
     // Output is the io.Writer used for logging
     // and printing. Defaults to os.Stdout.
     Output io.Writer `json:"-"`
 }
 
+// Tokenizer accepts a sentence as input and breaks
+// it down into a slice of tokens
+type Tokenizer func(string) []string
+
+func spaceTokenizer(input string) []string {
+    return strings.Split(strings.ToLower(input), " ")
+}
+
 // concurrentMap allows concurrency-friendly map
 // access via its exported Get and Set methods
 type concurrentMap struct {

@@ -236,6 +248,7 @@ func NewNaiveBayes(stream <-chan base.TextDatapoint, classes uint8, sanitize fun
 
         sanitize: transform.RemoveFunc(sanitize),
         stream:   stream,
+        tokenize: spaceTokenizer,
 
         Output: os.Stdout,
     }

@@ -249,7 +262,7 @@ func (b *NaiveBayes) Predict(sentence string) uint8 {
     sums := make([]float64, len(b.Count))
 
     sentence, _, _ = transform.String(b.sanitize, sentence)
-    words := strings.Split(strings.ToLower(sentence), " ")
+    words := b.tokenize(sentence)
     for _, word := range words {
         w, ok := b.Words.Get(word)
         if !ok {

@@ -300,7 +313,7 @@ func (b *NaiveBayes) Probability(sentence string) (uint8, float64) {
     }
 
     sentence, _, _ = transform.String(b.sanitize, sentence)
-    words := strings.Split(strings.ToLower(sentence), " ")
+    words := b.tokenize(sentence)
     for _, word := range words {
         w, ok := b.Words.Get(word)
         if !ok {

@@ -353,9 +366,7 @@ func (b *NaiveBayes) OnlineLearn(errors chan<- error) {
         if more {
             // sanitize and break up document
             sanitized, _, _ := transform.String(b.sanitize, point.X)
-            sanitized = strings.ToLower(sanitized)
-
-            words := strings.Split(sanitized, " ")
+            words := b.tokenize(sanitized)
 
             C := int(point.Y)
 
@@ -425,6 +436,13 @@ func (b *NaiveBayes) UpdateSanitize(sanitize func(rune) bool) {
     b.sanitize = transform.RemoveFunc(sanitize)
 }
 
+// UpdateTokenizer updates NaiveBayes model's tokenizer function.
+// The default implementation will convert the input to lower
+// case and split on the space character.
+func (b *NaiveBayes) UpdateTokenizer(tokenizer Tokenizer) {
+    b.tokenize = tokenizer
+}
+
 // String implements the fmt interface for clean printing. Here
 // we're using it to print the model as the equation h(θ)=...
 // where h is the perceptron hypothesis model.
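For context, a minimal sketch of how a caller might use the new hook. NewNaiveBayes, UpdateTokenizer, OnlineLearn, Predict, and base.TextDatapoint come from this diff and the test below; the comma tokenizer, the class count, and the sample datapoints are only illustrative. It assumes the repo's base package plus the standard library strings package are imported.

    // Sketch only (not part of this commit).
    stream := make(chan base.TextDatapoint, 100)
    errs := make(chan error)

    model := NewNaiveBayes(stream, 2, func(rune) bool { return false })
    model.UpdateTokenizer(func(input string) []string {
        // hypothetical tokenizer: split on commas instead of spaces
        return strings.Split(strings.ToLower(input), ",")
    })

    go model.OnlineLearn(errs)
    stream <- base.TextDatapoint{X: "I,love,the,city", Y: 1}
    stream <- base.TextDatapoint{X: "I,hate,Los,Angeles", Y: 0}
    close(stream)
    for range errs {
        // drain errors until the channel is closed, as in TestTokenizer below
    }

    _ = model.Predict("love,the,city") // expected class: 1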

text/bayes_test.go

Lines changed: 59 additions & 0 deletions
@@ -356,3 +356,62 @@ func TestPersistNaiveBayesShouldPass1(t *testing.T) {
     assert.EqualValues(t, 1, class, "Class should be 1")
     assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is positive - Given %v", p)
 }
+
+func TestTokenizer(t *testing.T) {
+    stream := make(chan base.TextDatapoint, 100)
+    errors := make(chan error)
+
+    // This is a somewhat contrived test case since splitting on commas is
+    // probably not very useful, but it is designed to purely test the
+    // tokenizer. A more useful, but too complicated test case would be to use
+    // a tokenizer that does something like porter stemming.
+    model := NewNaiveBayes(stream, 3, func(rune) bool {
+        // do not filter out commas
+        return false
+    })
+    model.UpdateTokenizer(func(input string) []string {
+        return strings.Split(strings.ToLower(input), ",")
+    })
+
+    go model.OnlineLearn(errors)
+
+    stream <- base.TextDatapoint{
+        X: "I,love,the,city",
+        Y: 1,
+    }
+
+    stream <- base.TextDatapoint{
+        X: "I,hate,Los,Angeles",
+        Y: 0,
+    }
+
+    stream <- base.TextDatapoint{
+        X: "My,mother,is,not,a,nice,lady",
+        Y: 0,
+    }
+
+    close(stream)
+
+    for {
+        err, more := <-errors
+        if more {
+            fmt.Printf("Error passed: %v", err)
+        } else {
+            // training is done!
+            break
+        }
+    }
+
+    // now you can predict like normal
+    class := model.Predict("My,mo~~~ther,is,in,Los,Angeles") // 0
+    assert.EqualValues(t, 0, class, "Class should be 0")
+
+    // test small document classification
+    class, p := model.Probability("Mother,Los,Angeles")
+    assert.EqualValues(t, 0, class, "Class should be 0")
+    assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is negative - Given %v", p)
+
+    class, p = model.Probability("Love,the,CiTy")
+    assert.EqualValues(t, 1, class, "Class should be 1")
+    assert.True(t, p > 0.75, "There should be a greater than 75 percent chance the document is positive - Given %v", p)
+}

text/tfidf.go

Lines changed: 2 additions & 3 deletions
@@ -3,7 +3,6 @@ package text
 import (
     "math"
     "sort"
-    "strings"
 
     "golang.org/x/text/transform"
 )

@@ -81,7 +80,7 @@ func (f Frequencies) Swap(i, j int) {
 // this is calculated
 func (t *TFIDF) TFIDF(word string, sentence string) float64 {
     sentence, _, _ = transform.String(t.sanitize, sentence)
-    document := strings.Split(strings.ToLower(sentence), " ")
+    document := t.tokenize(sentence)
 
     return t.TermFrequency(word, document) * t.InverseDocumentFrequency(word)
 }

@@ -96,7 +95,7 @@ func (t *TFIDF) TFIDF(word string, sentence string) float64 {
 // by importance
 func (t *TFIDF) MostImportantWords(sentence string, n int) Frequencies {
     sentence, _, _ = transform.String(t.sanitize, sentence)
-    document := strings.Split(strings.ToLower(sentence), " ")
+    document := t.tokenize(sentence)
 
     freq := TermFrequencies(document)
     for i := range freq {
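Since the TFIDF methods above read t.sanitize and t.tokenize directly, a custom tokenizer set on the underlying NaiveBayes model also drives the TF-IDF calculations. A hedged sketch, continuing from the example after the bayes.go diff and assuming TFIDF is declared as a conversion-compatible type of NaiveBayes (e.g. type TFIDF NaiveBayes), which this diff does not show:

    // Sketch only: assumes TFIDF shares the NaiveBayes struct layout, so the
    // tokenize field set via UpdateTokenizer carries over after conversion.
    model.UpdateTokenizer(strings.Fields) // hypothetical: split on any whitespace run
    tf := TFIDF(*model)                   // model should already be trained (see above)

    _ = tf.TFIDF("city", "I love the city")         // weight of "city" in the sentence
    _ = tf.MostImportantWords("I love the city", 2) // top 2 terms by weight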
