Commit 22c091a (1 parent: d20fe64)

bAbI: Doubling the FB LSTM baseline :)

1 file changed (+15, -11 lines)

examples/babi_rnn.py

@@ -7,8 +7,8 @@
 
 Task Number | FB LSTM Baseline | Keras QA
 --- | --- | ---
-QA1 - Single Supporting Fact | 50 | 52.1
-QA2 - Two Supporting Facts | 20 | 37.0
+QA1 - Single Supporting Fact | 50 | 100.0
+QA2 - Two Supporting Facts | 20 | 50.0
 QA3 - Three Supporting Facts | 20 | 20.5
 QA4 - Two Arg. Relations | 61 | 62.9
 QA5 - Three Arg. Relations | 70 | 61.9
@@ -34,8 +34,8 @@
 Notes:
 
 - With default word, sentence, and query vector sizes, the GRU model achieves:
-  - 52.1% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU)
-  - 37.0% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU)
+  - 100% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU)
+  - 50% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU)
 In comparison, the Facebook paper achieves 50% and 20% for the LSTM baseline.
 
 - The task does not traditionally parse the question separately. This likely
@@ -138,12 +138,12 @@ def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
         Y.append(y)
     return pad_sequences(X, maxlen=story_maxlen), pad_sequences(Xq, maxlen=query_maxlen), np.array(Y)
 
-RNN = recurrent.GRU
+RNN = recurrent.LSTM
 EMBED_HIDDEN_SIZE = 50
 SENT_HIDDEN_SIZE = 100
 QUERY_HIDDEN_SIZE = 100
 BATCH_SIZE = 32
-EPOCHS = 20
+EPOCHS = 40
 print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(RNN, EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, QUERY_HIDDEN_SIZE))
 
 path = get_file('babi-tasks-v1-2.tar.gz', origin='http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz')
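The `return pad_sequences(...)` line in the hunk above relies on Keras' default "pre" padding and truncation, which left-pads short stories with zeros (the index masked by `mask_zero=True`) and keeps the most recent tokens of long ones. A minimal NumPy re-implementation sketch of that behavior (the helper name is made up; it only mimics the defaults used here):

```python
import numpy as np

def pad_sequences_sketch(seqs, maxlen):
    # Left-pad (or left-truncate) each integer sequence to length maxlen,
    # mirroring keras.preprocessing.sequence.pad_sequences defaults
    # (padding='pre', truncating='pre', value=0).
    out = np.zeros((len(seqs), maxlen), dtype=int)
    for i, seq in enumerate(seqs):
        trunc = seq[-maxlen:]              # keep the most recent tokens
        out[i, maxlen - len(trunc):] = trunc
    return out

print(pad_sequences_sketch([[1, 2], [3, 4, 5, 6, 7]], maxlen=4))
# [[0 0 1 2]
#  [4 5 6 7]]
```

Because padding goes on the left, the end of every story lines up at the last timestep, which is where a `return_sequences=False` RNN reads its final state.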
@@ -178,15 +178,19 @@ def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
 print('Build model...')
 
 sentrnn = Sequential()
-sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, mask_zero=True))
-sentrnn.add(RNN(SENT_HIDDEN_SIZE, return_sequences=False))
+sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=story_maxlen, mask_zero=True))
+sentrnn.add(Dropout(0.3))
 
 qrnn = Sequential()
-qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE))
-qrnn.add(RNN(QUERY_HIDDEN_SIZE, return_sequences=False))
+qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=query_maxlen))
+qrnn.add(Dropout(0.3))
+qrnn.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False))
+qrnn.add(RepeatVector(story_maxlen))
 
 model = Sequential()
-model.add(Merge([sentrnn, qrnn], mode='concat'))
+model.add(Merge([sentrnn, qrnn], mode='sum'))
+model.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False))
+model.add(Dropout(0.3))
 model.add(Dense(vocab_size, activation='softmax'))
 
 model.compile(optimizer='adam', loss='categorical_crossentropy', class_mode='categorical')
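The last hunk's core change is the merge: the query branch now ends in an RNN followed by `RepeatVector(story_maxlen)`, so its single vector is tiled across story timesteps and added elementwise to the story embeddings (`Merge(..., mode='sum')`). A small NumPy sketch of what that combination computes, with made-up dimensions (batch=2, story_maxlen=4, EMBED_HIDDEN_SIZE=3):

```python
import numpy as np

# Hypothetical branch outputs with made-up dimensions.
story_seq = np.arange(24, dtype=float).reshape(2, 4, 3)  # sentrnn: (batch, story_maxlen, embed)
query_vec = np.ones((2, 3))                              # qrnn's RNN: one vector per sample

# RepeatVector(story_maxlen): tile the query vector over every timestep.
query_rep = np.repeat(query_vec[:, np.newaxis, :], 4, axis=1)  # (2, 4, 3)

# Merge(mode='sum'): elementwise addition of the two branches, so the
# downstream RNN sees story and query information at every timestep.
merged = story_seq + query_rep
print(merged.shape)  # (2, 4, 3)
```

Note that `mode='sum'` requires both branches to share the embedding width, which is why the query RNN and the readout RNN were switched to `EMBED_HIDDEN_SIZE` in this commit.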
