Commit 1a572b1

Merge pull request #501 from Smerity/master
Fixes and full results for bAbi RNN example
2 parents e42f738 + 342f2bc commit 1a572b1

1 file changed: 26 additions, 2 deletions
examples/babi_rnn.py

Lines changed: 26 additions & 2 deletions
@@ -1,5 +1,6 @@
 from __future__ import absolute_import
 from __future__ import print_function
+from functools import reduce
 import re
 import tarfile

@@ -21,6 +22,29 @@
 "Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks"
 http://arxiv.org/abs/1502.05698

+Task Number | FB LSTM Baseline | Keras QA
+--- | --- | ---
+QA1 - Single Supporting Fact | 50 | 52.1
+QA2 - Two Supporting Facts | 20 | 37.0
+QA3 - Three Supporting Facts | 20 | 20.5
+QA4 - Two Arg. Relations | 61 | 62.9
+QA5 - Three Arg. Relations | 70 | 61.9
+QA6 - Yes/No Questions | 48 | 50.7
+QA7 - Counting | 49 | 78.9
+QA8 - Lists/Sets | 45 | 77.2
+QA9 - Simple Negation | 64 | 64.0
+QA10 - Indefinite Knowledge | 44 | 47.7
+QA11 - Basic Coreference | 72 | 74.9
+QA12 - Conjunction | 74 | 76.4
+QA13 - Compound Coreference | 94 | 94.4
+QA14 - Time Reasoning | 27 | 34.8
+QA15 - Basic Deduction | 21 | 32.4
+QA16 - Basic Induction | 23 | 50.6
+QA17 - Positional Reasoning | 51 | 49.1
+QA18 - Size Reasoning | 52 | 90.8
+QA19 - Path Finding | 8 | 9.0
+QA20 - Agent's Motivations | 91 | 90.7
+
 For the resources related to the bAbI project, refer to:
 https://research.facebook.com/researchers/1543934539189348

@@ -67,7 +91,7 @@ def parse_stories(lines, only_supporting=False):
     data = []
     story = []
     for line in lines:
-        line = line.strip()
+        line = line.decode('utf-8').strip()
         nid, line = line.split(' ', 1)
         nid = int(nid)
         if nid == 1:
@@ -137,7 +161,7 @@ def vectorize_stories(data):
 train = get_stories(tar.extractfile(challenge.format('train')))
 test = get_stories(tar.extractfile(challenge.format('test')))

-vocab = sorted(reduce(lambda x, y: x | y, (set(story + q) for story, q, answer in train + test)))
+vocab = sorted(reduce(lambda x, y: x | y, (set(story + q + [answer]) for story, q, answer in train + test)))
 # Reserve 0 for masking via pad_sequences
 vocab_size = len(vocab) + 1
 word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
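
Note on the two code fixes in this diff: the commit itself does not explain them, so the following is a minimal sketch of the likely motivation, not the author's own test case. It assumes Python 3, where lines read from `tar.extractfile(...)` are `bytes`, and it uses hypothetical toy data plus two hypothetical helpers (`parse_line`, `build_vocab`) that are not part of `babi_rnn.py`. The first helper shows why decoding precedes `split(' ', 1)`; the second shows how an answer token such as "yes" can be absent from the vocabulary unless `[answer]` is added to each set.

```python
from functools import reduce

# Hypothetical raw lines in the bAbI layout, as bytes (Python 3 tarfile reads).
raw_lines = [
    b"1 Mary moved to the bathroom.\n",
    b"2 Is Mary in the bathroom? \tyes\t1\n",
]


def parse_line(line):
    """Decode one raw line, then split off the leading sentence id.

    On Python 3, calling split(' ', 1) directly on bytes raises TypeError,
    because the separator is a str; decoding first avoids that.
    """
    line = line.decode('utf-8').strip()
    nid, text = line.split(' ', 1)
    return int(nid), text


# Hypothetical (story, question, answer) samples, already tokenized.
samples = [
    (["Mary", "moved", "to", "the", "bathroom", "."],
     ["Is", "Mary", "in", "the", "bathroom", "?"],
     "yes"),
    (["John", "went", "to", "the", "hallway", "."],
     ["Is", "John", "in", "the", "kitchen", "?"],
     "no"),
]


def build_vocab(samples, include_answers):
    """Union all tokens; optionally include each sample's answer token."""
    token_sets = (
        set(story + q + ([answer] if include_answers else []))
        for story, q, answer in samples
    )
    return sorted(reduce(lambda x, y: x | y, token_sets))


old_vocab = build_vocab(samples, include_answers=False)
new_vocab = build_vocab(samples, include_answers=True)

# Index 0 stays reserved for padding, mirroring the diff's word_idx.
word_idx = dict((c, i + 1) for i, c in enumerate(new_vocab))

print(parse_line(raw_lines[0]))
print("yes" in old_vocab, "yes" in new_vocab)
```

With the old vocabulary, looking up `word_idx[answer]` for "yes" or "no" would raise `KeyError` in this toy setup, since those tokens never occur in any story or question.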
