diff --git a/examples/babi_rnn.py b/examples/babi_rnn.py index 2499447e4005..2b08cb8f4d83 100644 --- a/examples/babi_rnn.py +++ b/examples/babi_rnn.py @@ -1,5 +1,6 @@ from __future__ import absolute_import from __future__ import print_function +from functools import reduce import re import tarfile @@ -21,6 +22,29 @@ "Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks" http://arxiv.org/abs/1502.05698 +Task Number | FB LSTM Baseline | Keras QA +--- | --- | --- +QA1 - Single Supporting Fact | 50 | 52.1 +QA2 - Two Supporting Facts | 20 | 37.0 +QA3 - Three Supporting Facts | 20 | 20.5 +QA4 - Two Arg. Relations | 61 | 62.9 +QA5 - Three Arg. Relations | 70 | 61.9 +QA6 - Yes/No Questions | 48 | 50.7 +QA7 - Counting | 49 | 78.9 +QA8 - Lists/Sets | 45 | 77.2 +QA9 - Simple Negation | 64 | 64.0 +QA10 - Indefinite Knowledge | 44 | 47.7 +QA11 - Basic Coreference | 72 | 74.9 +QA12 - Conjunction | 74 | 76.4 +QA13 - Compound Coreference | 94 | 94.4 +QA14 - Time Reasoning | 27 | 34.8 +QA15 - Basic Deduction | 21 | 32.4 +QA16 - Basic Induction | 23 | 50.6 +QA17 - Positional Reasoning | 51 | 49.1 +QA18 - Size Reasoning | 52 | 90.8 +QA19 - Path Finding | 8 | 9.0 +QA20 - Agent's Motivations | 91 | 90.7 + For the resources related to the bAbI project, refer to: https://research.facebook.com/researchers/1543934539189348 @@ -67,7 +91,7 @@ def parse_stories(lines, only_supporting=False): data = [] story = [] for line in lines: - line = line.strip() + line = line.decode('utf-8').strip() nid, line = line.split(' ', 1) nid = int(nid) if nid == 1: @@ -137,7 +161,7 @@ def vectorize_stories(data): train = get_stories(tar.extractfile(challenge.format('train'))) test = get_stories(tar.extractfile(challenge.format('test'))) -vocab = sorted(reduce(lambda x, y: x | y, (set(story + q) for story, q, answer in train + test))) +vocab = sorted(reduce(lambda x, y: x | y, (set(story + q + [answer]) for story, q, answer in train + test))) # Reserve 0 for masking via pad_sequences vocab_size = len(vocab) + 1 word_idx = dict((c, i + 1) for i, c in enumerate(vocab))