
Commit 349678a

Python implementation of a naive Bayes spam filter
1 parent 65807e8 commit 349678a

File tree

1 file changed: +102 -0 lines


06-Naive-Bayes/nbayes.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
#!/usr/bin/python
#
# Cambridge Programmer Study Group
#
# Naive Bayes implementation of a spam filter
# Keeps 1000 text messages aside as a test data set
# Writes incorrectly classified text messages
# to a file for further inspection
#
# Ole Schulz-Trieglaff

from collections import defaultdict
import math
import string

def tokenize(ls):
    # remove some frequent stop words, convert to lower case and
    # strip punctuation characters
    forbidden = ["and", "to", "i", "a", "you", "the", "your", "is"]
    ls = [w.lower() for w in ls]
    ls = [w.translate(None, string.punctuation) for w in ls]
    ls = [w for w in ls if w not in forbidden]
    return ls

def main():

    datafile = "corpus/SMSSpamCollection.txt"
    data = []
    with open(datafile) as input:
        for line in input:
            fields = line.split()
            label = fields[0]            # first token is the label ("spam" or "ham")
            text = tokenize(fields[1:])  # the rest is the message
            data.append([label, text])

    print "Have", len(data), "examples"

    # keep 1000 examples separate as test data, train on the rest
    num_test = 1000
    test = data[:num_test]
    train = data[num_test:]

    # P(word|label); unseen words fall back to a small default
    # probability (a crude form of smoothing)
    word_llhoods = defaultdict(lambda: defaultdict(lambda: 0.0001))
    # P(label)
    prior = defaultdict(float)
    num_train = len(train)
    for d in train:
        label = d[0]
        text = d[1]
        prior[label] += 1
        for t in text:
            word_llhoods[label][t] += 1

    # normalize counts to get probabilities
    for k in prior:
        prior[k] /= num_train

    spam_sum = sum(word_llhoods["spam"].itervalues())
    for w in word_llhoods["spam"]:
        word_llhoods["spam"][w] /= spam_sum
    ham_sum = sum(word_llhoods["ham"].itervalues())
    for w in word_llhoods["ham"]:
        word_llhoods["ham"][w] /= ham_sum

    # debugging
    print "prior=", prior
    maxSpam = sorted(word_llhoods["spam"].iteritems(), key=lambda x: x[1], reverse=True)[0:5]
    print "5 most frequent spam words", maxSpam
    maxHam = sorted(word_llhoods["ham"].iteritems(), key=lambda x: x[1], reverse=True)[0:5]
    print "5 most frequent ham words", maxHam

    # classify the test data
    correct = 0
    mistakesFile = "mistakes"  # write incorrectly classified messages to a file
    with open(mistakesFile, "w") as mistakesOut:
        for d in test:
            label = d[0]
            text = d[1]
            llhood_spam = 0.0
            llhood_ham = 0.0
            for w in text:
                #print w, " ", math.log10(word_llhoods["ham"][w]), " ", math.log10(word_llhoods["spam"][w])
                llhood_spam += math.log10(word_llhoods["spam"][w])
                llhood_ham += math.log10(word_llhoods["ham"][w])

            llhood_spam += math.log10(prior["spam"])
            llhood_ham += math.log10(prior["ham"])

            guess = "spam" if llhood_spam > llhood_ham else "ham"
            if label == guess:
                correct += 1
            else:
                print >> mistakesOut, text
                print >> mistakesOut, "llhood_spam=", llhood_spam
                print >> mistakesOut, "llhood_ham=", llhood_ham
                print >> mistakesOut, "true label=", label

    print "correct={} out of {} test cases".format(correct, num_test)

if __name__ == "__main__":
    main()
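
For reference, the test loop above amounts to the usual naive Bayes decision rule in log space. This is only a sketch of the comparison the code performs, written with the code's base-10 logarithms; \hat{y} is just notation for the predicted label and is not a variable in the script:

\hat{y} =
\begin{cases}
\text{spam}, & \text{if } \log_{10} P(\text{spam}) + \sum_{w \in \text{text}} \log_{10} P(w \mid \text{spam}) > \log_{10} P(\text{ham}) + \sum_{w \in \text{text}} \log_{10} P(w \mid \text{ham}) \\
\text{ham}, & \text{otherwise}
\end{cases}

Words never seen with a given label during training fall back to the 0.0001 default supplied by the inner defaultdict rather than a properly normalized smoothed estimate, so the sums are an approximation of the true log-posteriors rather than an exact Laplace-smoothed model.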
