@@ -156,11 +156,23 @@ type NaiveBayes struct {
156156 // stream holds the datastream
157157 stream <- chan base.TextDatapoint
158158
159+ // tokenize is the function the model uses
160+ // to split its input into tokens
161+ tokenize Tokenizer
162+
159163 // Output is the io.Writer used for logging
160164 // and printing. Defaults to os.Stdout.
161165 Output io.Writer `json:"-"`
162166}
163167
// Tokenizer is the signature of a function that takes a
// sentence and returns the tokens it is made of, in order,
// as a slice of strings.
type Tokenizer func(sentence string) []string
171+
// spaceTokenizer is the default Tokenizer. It lower-cases the
// input and splits it on the space character (' ') only; other
// whitespace such as tabs or newlines is left inside tokens.
// Runs of consecutive spaces, and leading or trailing spaces,
// produce no tokens, so the result never contains an empty
// string. An empty or all-space input yields an empty slice.
func spaceTokenizer(input string) []string {
	// strings.Split would emit "" for every extra space (and
	// [""] for empty input), handing callers empty strings as
	// if they were words. FieldsFunc drops those.
	return strings.FieldsFunc(strings.ToLower(input), func(r rune) bool {
		return r == ' '
	})
}
175+
164176// concurrentMap allows concurrency-friendly map
165177// access via its exported Get and Set methods
166178type concurrentMap struct {
@@ -236,6 +248,7 @@ func NewNaiveBayes(stream <-chan base.TextDatapoint, classes uint8, sanitize fun
236248
237249 sanitize : transform .RemoveFunc (sanitize ),
238250 stream : stream ,
251+ tokenize : spaceTokenizer ,
239252
240253 Output : os .Stdout ,
241254 }
@@ -249,7 +262,7 @@ func (b *NaiveBayes) Predict(sentence string) uint8 {
249262 sums := make ([]float64 , len (b .Count ))
250263
251264 sentence , _ , _ = transform .String (b .sanitize , sentence )
252- words := strings . Split ( strings . ToLower ( sentence ), " " )
265+ words := b . tokenize ( sentence )
253266 for _ , word := range words {
254267 w , ok := b .Words .Get (word )
255268 if ! ok {
@@ -300,7 +313,7 @@ func (b *NaiveBayes) Probability(sentence string) (uint8, float64) {
300313 }
301314
302315 sentence , _ , _ = transform .String (b .sanitize , sentence )
303- words := strings . Split ( strings . ToLower ( sentence ), " " )
316+ words := b . tokenize ( sentence )
304317 for _ , word := range words {
305318 w , ok := b .Words .Get (word )
306319 if ! ok {
@@ -353,9 +366,7 @@ func (b *NaiveBayes) OnlineLearn(errors chan<- error) {
353366 if more {
354367 // sanitize and break up document
355368 sanitized , _ , _ := transform .String (b .sanitize , point .X )
356- sanitized = strings .ToLower (sanitized )
357-
358- words := strings .Split (sanitized , " " )
369+ words := b .tokenize (sanitized )
359370
360371 C := int (point .Y )
361372
@@ -425,6 +436,13 @@ func (b *NaiveBayes) UpdateSanitize(sanitize func(rune) bool) {
425436 b .sanitize = transform .RemoveFunc (sanitize )
426437}
427438
439+ // UpdateTokenizer updates NaiveBayes model's tokenizer function.
440+ // The default implementation will convert the input to lower
441+ // case and split on the space character.
442+ func (b * NaiveBayes ) UpdateTokenizer (tokenizer Tokenizer ) {
443+ b .tokenize = tokenizer
444+ }
445+
428446// String implements the fmt interface for clean printing. Here
429447// we're using it to print the model as the equation h(θ)=...
430448// where h is the perceptron hypothesis model.
0 commit comments