Skip to content

Commit b0fd29c

Browse files
committed
Merge pull request #36 from ajkumarnv/master
naive bayes implementation by Ajay
2 parents c53e4ab + a9ddc66 commit b0fd29c

File tree

1 file changed

+133
-0
lines changed

1 file changed

+133
-0
lines changed

06-Naive-Bayes/bayes_msg_filter.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
2+
from __future__ import division
3+
from collections import defaultdict
4+
import pickle
5+
import re
6+
import sys
7+
import math
8+
9+
# NOTE(review): module-level counter that no function below reads or writes
# (test() uses its own local "count") — appears to be dead state; confirm
# before removing.
count = 0
10+
11+
12+
def rr():
    """Return the tiny floor probability used for words never seen in training.

    Keeps log10() finite when an unseen word is looked up in a defaultdict.
    """
    return 1e-24
14+
15+
def dd():
    """Build a fresh word->probability map whose missing keys default to the floor.

    Used as the outer defaultdict's factory so data_dict["ham"] / data_dict["spam"]
    spring into existence on first access.
    """
    inner = defaultdict(rr)
    return inner
17+
18+
def train_line(line, data_dict, prior):
    """Tally one labelled training line into the model counts.

    Args:
        line: one corpus record, "<label><TAB><message text>"; label is
            expected to be "ham" or "spam" (case-insensitive).
        data_dict: nested mapping label -> word -> count; mutated in place.
        prior: mapping label -> number of lines seen; mutated in place.
    """
    # maxsplit=1 keeps any additional tabs inside the message text instead of
    # raising ValueError when unpacking (the original split on every tab).
    label, text = line.split("\t", 1)
    label = label.lower()
    prior[label] += 1
    # Tokenize on any non-word character; empty tokens are skipped.
    for word in re.split(r"[^\w]|[\s]", text):
        if word == "":
            continue
        data_dict[label][word.lower()] += 1
27+
28+
29+
def normalize_prob(data_dict, prior):
    """Convert raw counts into probabilities, in place.

    Each per-label word count becomes P(word | label); each prior count
    becomes P(label).

    Args:
        data_dict: mapping with "ham" and "spam" word-count maps; mutated.
        prior: mapping with "ham" and "spam" line counts; mutated.
    """
    for label in ("ham", "spam"):
        # .values() works on both Python 2 and 3 (the original used the
        # Python-2-only itervalues()).
        total_words = sum(data_dict[label].values())
        for word in data_dict[label]:
            data_dict[label][word] /= total_words

    # BUG FIX: compute the denominator once.  The original normalized
    # prior["ham"] first and then divided prior["spam"] by the sum of the
    # *already normalized* ham prior plus the raw spam count, so the spam
    # prior came out wrong (priors no longer summed to 1).
    total_lines = prior["ham"] + prior["spam"]
    prior["ham"] /= total_lines
    prior["spam"] /= total_lines
39+
40+
def classify_msg(msg, data_dict, prior):
    """Classify a message as "spam" or "ham" with log-space Naive Bayes.

    Args:
        msg: raw message text.
        data_dict: mapping label -> (word -> P(word|label)); the inner maps
            are defaultdicts whose floor value keeps log10() finite for
            words never seen in training.
        prior: mapping label -> P(label).

    Returns:
        "spam" if the spam log-probability exceeds the ham one, else "ham"
        (ties go to "ham", matching the original behavior).
    """
    spam_log = 0.0
    ham_log = 0.0

    # BUG FIX: the original pattern was "[^\w]|[/s]" — the character class
    # [/s] matches the literal characters "/" and "s", so every "s" split a
    # word in two.  Use the same tokenizer as training: split on non-word
    # characters.
    for word in re.split(r"[^\w]|[\s]", msg):
        word = word.lower()
        if word == "":
            continue
        ham_log += math.log10(data_dict["ham"][word])
        spam_log += math.log10(data_dict["spam"][word])

    ham_log += math.log10(prior["ham"])
    spam_log += math.log10(prior["spam"])

    return "spam" if spam_log > ham_log else "ham"
63+
64+
65+
66+
def train(filename):
    """Train the filter on a tab-separated corpus and pickle the model.

    Args:
        filename: path to a corpus file of "<label>\\t<text>" lines.

    Side effects:
        Writes the tuple (data_dict, prior) to "m_brain" in the current
        working directory.
    """
    data_dict = defaultdict(dd)
    prior = defaultdict(rr)
    with open(filename, "r") as inp:
        for line in inp:
            train_line(line, data_dict, prior)

    # BUG FIX: normalize_prob takes the model as arguments; the original
    # called it with none, which raises TypeError at runtime.
    normalize_prob(data_dict, prior)
    with open("m_brain", "wb") as fout:
        pickle.dump((data_dict, prior), fout)
77+
78+
79+
def test(filename):
    """Evaluate the pickled model ("m_brain") against a labelled corpus.

    Args:
        filename: path to a corpus file of "<label>\\t<text>" lines.

    Side effects:
        Prints the success count and percentage to stdout.
    """
    # BUG FIX: pickle data is binary — open with "rb", not "r".  Text mode
    # fails outright on Python 3 and can corrupt the stream on Windows.
    with open("m_brain", "rb") as fin:
        (data_dict, prior) = pickle.load(fin)

    count = 0
    success = 0
    with open(filename, "r") as inp:
        for line in inp:
            count += 1
            # maxsplit=1 tolerates tabs inside the message text.
            label, text = line.split("\t", 1)
            guess = classify_msg(text, data_dict, prior)
            if label.lower() == guess:
                success += 1

    # Single-argument print() is valid on both Python 2 and 3.
    print("Success rate = {}/{}".format(success, count))
    if count:  # guard: an empty corpus would otherwise ZeroDivisionError
        print("Success rate %= {}".format(100.0 * success / count))
97+
98+
def train_and_test():
    """Train on the first 5550 corpus lines and report accuracy on the rest.

    Reads "corpus/SMSSpamCollection.txt" relative to the working directory
    and prints the success count and percentage to stdout.
    """
    data_dict = defaultdict(dd)
    prior = defaultdict(rr)

    with open("corpus/SMSSpamCollection.txt") as inp:
        lines = inp.readlines()

    train_len = 5550  # fixed train/test split used by the original author
    train_data = lines[:train_len]
    test_data = lines[train_len:]

    for line in train_data:
        train_line(line, data_dict, prior)
    normalize_prob(data_dict, prior)

    success = 0
    for line in test_data:
        label, text = line.split("\t", 1)
        guess = classify_msg(text, data_dict, prior)
        if label.lower() == guess:
            success += 1

    # Single-argument print() is valid on both Python 2 and 3.
    print("Success rate = {}/{}".format(success, len(test_data)))
    if test_data:  # guard: a corpus shorter than train_len has no test lines
        print("Success rate %= {}".format(100.0 * success / len(test_data)))
120+
121+
def main():
    """Dispatch on the command line.

    Usage:
        script -train|--t <corpus>   train and pickle the model
        script -run|--r  <corpus>    evaluate the pickled model
        script                       (anything else) train-and-test split run
    """
    if len(sys.argv) == 3:
        mode = sys.argv[1]
        if mode in ("-train", "--t"):
            train(sys.argv[2])
        elif mode in ("-run", "--r"):
            test(sys.argv[2])
    else:
        train_and_test()
129+
130+
131+
132+
# Run the CLI entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)