|
| 1 | +#!/usr/bin/python |
| 2 | + |
| 3 | +# Cambridge Programmer Study Group |
| 4 | +# |
| 5 | +# Naive bayes implementation of a spam filter |
| 6 | +# Keeps 1000 text messages as test data set |
| 7 | +# Write incorrectly classified text messages |
| 8 | +# to a file for further inspection |
| 9 | +# |
| 10 | +# Ole Schulz-Trieglaff |
| 11 | + |
| 12 | +from collections import defaultdict |
| 13 | +import math |
| 14 | +import string |
| 15 | + |
def tokenize(ls):
    """Normalize a list of raw word tokens for classification.

    Lower-cases each word, strips punctuation characters, and drops a
    small stop-word list of very frequent English words.

    Args:
        ls: list of raw word strings (one SMS message, pre-split on
            whitespace).

    Returns:
        List of cleaned tokens.  Words consisting solely of punctuation
        become empty strings and are kept, matching the original pass.
    """
    # frozenset gives O(1) membership tests instead of scanning a list
    # once per word.
    forbidden = frozenset(["and", "to", "i", "a", "you", "the", "your", "is"])
    punct = frozenset(string.punctuation)
    cleaned = []
    for w in ls:
        # BUGFIX: the original used w.translate(None, string.punctuation),
        # which is Python-2-only (raises TypeError on Python 3).  Filtering
        # characters through join works identically on both versions.
        w = "".join(ch for ch in w.lower() if ch not in punct)
        if w not in forbidden:
            cleaned.append(w)
    return cleaned
| 24 | + |
| 25 | +def main(): |
| 26 | + |
| 27 | + datafile = "corpus/SMSSpamCollection.txt" |
| 28 | + data = [] |
| 29 | + with open(datafile) as input: |
| 30 | + for line in input: |
| 31 | + fields = line.split() |
| 32 | + label = fields[0] |
| 33 | + text = tokenize(fields[1:]) |
| 34 | + data.append([label,text]) |
| 35 | + |
| 36 | + print "Have",len(data)," examples" |
| 37 | + |
| 38 | + # let's keep 1000 examples separate as test data |
| 39 | + num_test = 1000 |
| 40 | + test = data[:num_test] |
| 41 | + train = data[(num_test+1):] |
| 42 | + |
| 43 | + # P(word|label) |
| 44 | + word_llhoods = defaultdict(lambda: defaultdict(lambda: 0.0001)) |
| 45 | + # P(label) |
| 46 | + prior = defaultdict(float) |
| 47 | + num_train = len(train) |
| 48 | + for d in train: |
| 49 | + label = d[0] |
| 50 | + text = d[1] |
| 51 | + prior[label]+=1 |
| 52 | + for t in text: |
| 53 | + word_llhoods[label][t]+=1 |
| 54 | + |
| 55 | + # normalize to get probabilities |
| 56 | + for k in prior: |
| 57 | + prior[k] /= num_train |
| 58 | + |
| 59 | + spam_sum = sum(word_llhoods["spam"].itervalues()) |
| 60 | + for w in word_llhoods["spam"]: |
| 61 | + word_llhoods["spam"][w] /= spam_sum |
| 62 | + ham_sum = sum(word_llhoods["ham"].itervalues()) |
| 63 | + for w in word_llhoods["ham"]: |
| 64 | + word_llhoods["ham"][w] /= ham_sum |
| 65 | + |
| 66 | + # debugging |
| 67 | + print "prior=",prior |
| 68 | + maxSpam = sorted(word_llhoods["spam"].iteritems(), key=lambda x: x[1])[0:5] |
| 69 | + print "5 most freqent spam word",maxSpam |
| 70 | + maxHam = sorted(word_llhoods["ham"].iteritems(), key=lambda x: x[1])[0:5] |
| 71 | + print "5 most frequent ham word",maxHam |
| 72 | + |
| 73 | + # read test data |
| 74 | + correct = 0 |
| 75 | + mistakesFile = "mistakes" # write incorrectly classified messages to a file |
| 76 | + with open(mistakesFile,"w") as mistakesOut: |
| 77 | + for d in test: |
| 78 | + label = d[0] |
| 79 | + text = d[1] |
| 80 | + llhood_spam = 0.0 |
| 81 | + llhood_ham = 0.0 |
| 82 | + for w in text: |
| 83 | + #print w," ",math.log10(word_llhoods["ham"][w])," ", math.log10(word_llhoods["spam"][w]) |
| 84 | + llhood_spam += math.log10(word_llhoods["spam"][w]) |
| 85 | + llhood_ham += math.log10(word_llhoods["ham"][w]) |
| 86 | + |
| 87 | + llhood_spam += math.log10(prior["spam"]) |
| 88 | + llhood_ham += math.log10(prior["ham"]) |
| 89 | + |
| 90 | + guess = "spam" if llhood_spam > llhood_ham else "ham" |
| 91 | + if label == guess: |
| 92 | + correct+=1 |
| 93 | + else: |
| 94 | + print >> mistakesOut, text |
| 95 | + print >> mistakesOut, "llhood_spam=",llhood_spam |
| 96 | + print >> mistakesOut, "llhood_ham=",llhood_ham |
| 97 | + print >> mistakesOut, "true label=",label |
| 98 | + |
| 99 | + print "correct={} out of {} test cases".format(correct,num_test) |
| 100 | + |
# Script entry point: run the spam-filter experiment only when executed
# directly (not when imported as a module).
if __name__ == "__main__":
    main()
0 commit comments