|
7 | 7 |
|
8 | 8 | Task Number | FB LSTM Baseline | Keras QA |
9 | 9 | --- | --- | --- |
10 | | -QA1 - Single Supporting Fact | 50 | 52.1 |
11 | | -QA2 - Two Supporting Facts | 20 | 37.0 |
| 10 | +QA1 - Single Supporting Fact | 50 | 100.0 |
| 11 | +QA2 - Two Supporting Facts | 20 | 50.0 |
12 | 12 | QA3 - Three Supporting Facts | 20 | 20.5 |
13 | 13 | QA4 - Two Arg. Relations | 61 | 62.9 |
14 | 14 | QA5 - Three Arg. Relations | 70 | 61.9 |
|
34 | 34 | Notes: |
35 | 35 |
|
36 | 36 | - With default word, sentence, and query vector sizes, the GRU model achieves: |
37 | | - - 52.1% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU) |
38 | | - - 37.0% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU) |
| 37 | + - 100% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU) |
| 38 | + - 50% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU) |
39 | 39 | In comparison, the Facebook paper achieves 50% and 20% for the LSTM baseline. |
40 | 40 |
|
41 | 41 | - The task does not traditionally parse the question separately. This likely |
@@ -138,12 +138,12 @@ def vectorize_stories(data, word_idx, story_maxlen, query_maxlen): |
138 | 138 | Y.append(y) |
139 | 139 | return pad_sequences(X, maxlen=story_maxlen), pad_sequences(Xq, maxlen=query_maxlen), np.array(Y) |
140 | 140 |
|
141 | | -RNN = recurrent.GRU |
| 141 | +RNN = recurrent.LSTM |
142 | 142 | EMBED_HIDDEN_SIZE = 50 |
143 | 143 | SENT_HIDDEN_SIZE = 100 |
144 | 144 | QUERY_HIDDEN_SIZE = 100 |
145 | 145 | BATCH_SIZE = 32 |
146 | | -EPOCHS = 20 |
| 146 | +EPOCHS = 40 |
147 | 147 | print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(RNN, EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, QUERY_HIDDEN_SIZE)) |
148 | 148 |
|
149 | 149 | path = get_file('babi-tasks-v1-2.tar.gz', origin='http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz') |
@@ -178,15 +178,19 @@ def vectorize_stories(data, word_idx, story_maxlen, query_maxlen): |
178 | 178 | print('Build model...') |
179 | 179 |
|
180 | 180 | sentrnn = Sequential() |
181 | | -sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, mask_zero=True)) |
182 | | -sentrnn.add(RNN(SENT_HIDDEN_SIZE, return_sequences=False)) |
| 181 | +sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=story_maxlen, mask_zero=True)) |
| 182 | +sentrnn.add(Dropout(0.3)) |
183 | 183 |
|
184 | 184 | qrnn = Sequential() |
185 | | -qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE)) |
186 | | -qrnn.add(RNN(QUERY_HIDDEN_SIZE, return_sequences=False)) |
| 185 | +qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, input_length=query_maxlen)) |
| 186 | +qrnn.add(Dropout(0.3)) |
| 187 | +qrnn.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False)) |
| 188 | +qrnn.add(RepeatVector(story_maxlen)) |
187 | 189 |
|
188 | 190 | model = Sequential() |
189 | | -model.add(Merge([sentrnn, qrnn], mode='concat')) |
| 191 | +model.add(Merge([sentrnn, qrnn], mode='sum')) |
| 192 | +model.add(RNN(EMBED_HIDDEN_SIZE, return_sequences=False)) |
| 193 | +model.add(Dropout(0.3)) |
190 | 194 | model.add(Dense(vocab_size, activation='softmax')) |
191 | 195 |
|
192 | 196 | model.compile(optimizer='adam', loss='categorical_crossentropy', class_mode='categorical') |
|
0 commit comments