Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
#13266 Update lstm_seq2seq.py (from 22% to 87% acc)
I added code to apply one-hot encoding of the trailing padding (space characters) at the end of sentences for encoder_input_data, decoder_input_data, and decoder_target_data, and added an accuracy metric to model training. The original code reaches 22% validation accuracy; the proposed code reaches 87%.
  • Loading branch information
tykimos authored Aug 30, 2019
commit c846308ad0d9cb894ac6e08f69baf949d00ee5e7
6 changes: 4 additions & 2 deletions examples/lstm_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,16 @@
# One-hot encode every (source, target) sentence pair into the three
# pre-allocated tensors.  Indexing convention: [sample, timestep, token_id].
# NOTE(review): the scraped diff lost all indentation; structure restored here
# to match the upstream keras example (commit c846308).
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    # Pad the remainder of the encoder sequence with the space token so every
    # trailing timestep is a valid one-hot vector instead of all zeros.
    # (`t` holds the last index from the loop above; assumes input_text is
    # non-empty — TODO confirm upstream filtering guarantees this.)
    encoder_input_data[i, t + 1:, input_token_index[' ']] = 1.
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t, target_token_index[char]] = 1.
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.
    # Pad the decoder input/target tails with the space token as well; the
    # target padding starts at `t` (not t+1) because targets are shifted
    # one step earlier than inputs.
    decoder_input_data[i, t + 1:, target_token_index[' ']] = 1.
    decoder_target_data[i, t:, target_token_index[' ']] = 1.
# Define an input sequence and process it.
# encoder_inputs: variable-length sequence of one-hot vectors over the
# source-language character set (shape: timesteps x num_encoder_tokens).
encoder_inputs = Input(shape=(None, num_encoder_tokens))
# return_state=True exposes the LSTM's final hidden and cell states so they
# can be used to initialize the decoder (standard seq2seq handoff).
encoder = LSTM(latent_dim, return_state=True)
Expand All @@ -145,7 +147,7 @@
# Assemble the training model: maps [encoder_input_data, decoder_input_data]
# to decoder_outputs (teacher forcing).
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Run training.
# NOTE(review): the diff hunk contained both the old and the updated
# compile call; only the updated one is kept.  `metrics=['accuracy']` adds
# per-timestep categorical accuracy reporting without changing training.
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
batch_size=batch_size,
epochs=epochs,
Expand Down