#!/usr/bin/python3
"""Word-frequency spam/ham classifier for tab-separated message files.

Recovered from the patch "Add Girish's and Javi's (partial) solution"
(commit 0d507160, file 06-Naive-Bayes/spam_bayes.py).

Every input line has the form ``<label><TAB><message>`` where the label is
``ham`` or ``spam``.

Usage:
    spam_bayes.py --train FILE   count words per class and save the counts
    spam_bayes.py --test FILE    classify each line using the saved counts
    spam_bayes.py FILE           train on FILE, then test on the same FILE
"""

import collections
import pickle
import re
import sys

# Default location of the pickled (ham_words, spam_words) tuple.
BRAIN_PATH = "brain"

# Compiled once so training and testing share exactly the same tokenisation.
_SEPARATORS = re.compile(
    r"(?: |,|\.|-|–|:|;|&|=|\+|#|\(|\)|\||<|…|\^|\[|\])+"
)
_LEADING_QUOTES = re.compile(r"^(?:'|\"|“)+")
_TRAILING_APOSTROPHE = re.compile(r"(\w)'$")

_LABELS = ("ham", "spam")


def _tokenize(text):
    """Yield the lower-cased, non-empty words of *text*."""
    # Strip only the line terminator; slicing off the last character would
    # eat a real letter when the final line of a file has no newline.
    for raw in _SEPARATORS.split(text.rstrip("\r\n")):
        word = _LEADING_QUOTES.sub("", raw)
        word = _TRAILING_APOSTROPHE.sub(r"\1", word)
        word = word.lower()
        if word:
            yield word


def _score(text, ham_words, spam_words):
    """Return ``(probability_spam, probability_ham)`` for one message."""
    # Both scores start from the same 0.5 prior for every message.
    probability_ham = 0.5
    probability_spam = 0.5
    for word in _tokenize(text):
        # .get() keeps lookups from inserting keys into a defaultdict.
        ham_instances = ham_words.get(word, 0)
        spam_instances = spam_words.get(word, 0)
        total_instances = ham_instances + spam_instances
        if total_instances == 0:
            # A word never seen in training carries no evidence.
            continue
        probability_ham *= ham_instances / total_instances
        probability_spam *= spam_instances / total_instances
    return probability_spam, probability_ham


def classify_words(text, word_dict):
    """Add the word counts of *text* to *word_dict*.

    Args:
        text: One message, with or without its trailing newline.
        word_dict: Mapping of word -> count, updated in place.  Any mutable
            mapping works; a ``defaultdict`` is not required.
    """
    for word in _tokenize(text):
        word_dict[word] = word_dict.get(word, 0) + 1


def train(fname, brain_path=BRAIN_PATH):
    """Count the words of every labelled line in *fname* and pickle them.

    Args:
        fname: Path of the training file.
        brain_path: Where to write the pickled ``(ham_words, spam_words)``.

    Raises:
        RuntimeError: If a line's label is neither ``ham`` nor ``spam``.
    """
    ham_words = collections.defaultdict(int)
    spam_words = collections.defaultdict(int)
    counters = {"ham": ham_words, "spam": spam_words}

    with open(fname, encoding="utf-8") as fin:
        for line in fin:
            # partition() splits on the first tab only, so a message that
            # itself contains tabs is kept whole.
            label, _, text = line.partition("\t")
            if label not in counters:
                raise RuntimeError("Unkwnown line: {}".format(line))
            classify_words(text, counters[label])

    with open(brain_path, "wb") as fout:
        pickle.dump((ham_words, spam_words), fout)


def test(fname, brain_path=BRAIN_PATH):
    """Classify every line of *fname* and report how each guess fared.

    Args:
        fname: Path of the labelled file to evaluate.
        brain_path: Pickle written by :func:`train`.

    Returns:
        ``(correct, total)``: number of correct guesses and of lines seen.

    Raises:
        RuntimeError: If a line's label is neither ``ham`` nor ``spam``.
    """
    with open(brain_path, "rb") as fin:
        # SECURITY: unpickling runs arbitrary code; only load a brain file
        # that was produced locally by train().
        ham_words, spam_words = pickle.load(fin)

    n_correct = 0
    n_total = 0
    with open(fname, encoding="utf-8") as fin:
        for line in fin:
            true_clas, _, text = line.partition("\t")
            if true_clas not in _LABELS:
                raise RuntimeError("Unkwnown line: {}".format(line))

            # Scores are computed per message; carrying them over from one
            # line to the next would drive both to zero almost immediately.
            probability_spam, probability_ham = _score(
                text, ham_words, spam_words
            )
            guess = "spam" if probability_spam > probability_ham else "ham"

            n_total += 1
            if guess == true_clas:
                n_correct += 1
                correct = "Correct"
            else:
                # Keep going: one wrong guess should not abort the run.
                correct = "incorrect"

            print("Our guess: {} (spam: {}, ham: {}), {}".format(
                guess, probability_spam, probability_ham, correct))

    return n_correct, n_total


# The name starts with "test" but this is not a pytest test function.
test.__test__ = False


def main():
    """Dispatch on the command line; see the module docstring for usage."""
    args = sys.argv[1:]
    if not args or (args[0] in ("--train", "--test") and len(args) < 2):
        sys.exit("usage: {} [--train|--test] FILE".format(sys.argv[0]))

    if args[0] == "--train":
        train(args[1])
    elif args[0] == "--test":
        test(args[1])
    else:
        train(args[0])
        test(args[0])


if __name__ == "__main__":
    main()