
Commit 716cac3

committed: modify seq2seq
1 parent 5842b6d commit 716cac3

File tree

1 file changed: +5 -3 lines

  • chapter8_Application/seq2seq-translation


chapter8_Application/seq2seq-translation/train.py

Lines changed: 5 additions & 3 deletions
@@ -20,7 +20,7 @@
 input_size = lang_dataset.input_lang_words
 hidden_size = 256
 output_size = lang_dataset.output_lang_words
-total_epoch = 100
+total_epoch = 20

 encoder = EncoderRNN(input_size, hidden_size)
 decoder = DecoderRNN(hidden_size, output_size, n_layers=2)
@@ -43,13 +43,14 @@ def showPlot(points):
 def train(encoder, decoder, total_epoch, use_attn):

     param = list(encoder.parameters()) + list(decoder.parameters())
-    optimizer = optim.SGD(param, lr=1e-2)
+    optimizer = optim.Adam(param, lr=1e-2)
     criterion = nn.NLLLoss()
     plot_losses = []
     for epoch in range(total_epoch):
         since = time.time()
         running_loss = 0
         print_loss_total = 0
+        total_loss = 0
         for i, data in enumerate(lang_dataloader):
             in_lang, out_lang = data
             if torch.cuda.is_available():
@@ -104,6 +105,7 @@ def train(encoder, decoder, total_epoch, use_attn):
             optimizer.step()
             running_loss += loss.data[0]
             print_loss_total += loss.data[0]
+            total_loss += loss.data[0]
             if (i + 1) % 5000 == 0:
                 print('{}/{}, Loss:{:.6f}'.format(
                     i + 1, len(lang_dataloader), running_loss / 5000))
@@ -114,7 +116,7 @@ def train(encoder, decoder, total_epoch, use_attn):
                 print_loss_total = 0
         during = time.time() - since
         print('Finish {}/{} , Loss:{:.6f}, Time:{:.0f}s'.format(
-            epoch + 1, total_epoch, running_loss / len(lang_dataset), during))
+            epoch + 1, total_epoch, total_loss / len(lang_dataset), during))
         print()
     showPlot(plot_losses)

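Net effect of the three changes: total_epoch drops from 100 to 20, the optimizer switches from SGD to Adam at the same 1e-2 learning rate, and a new total_loss accumulator (reset only at the start of each epoch) feeds the end-of-epoch print, so the reported figure is the loss summed over the whole epoch divided by len(lang_dataset) rather than the old running_loss value. Below is a minimal, self-contained sketch of that bookkeeping; the nn.Linear encoder/decoder, the synthetic lang_dataloader/lang_dataset, and the log_softmax + NLLLoss step are stand-ins invented for the sketch (not the repo's EncoderRNN/DecoderRNN or language-pair dataset), and loss.item() stands in for the pre-0.4 loss.data[0] that train.py uses.

    # Sketch of the per-epoch loss bookkeeping after this commit.
    # All model/data objects below are toy stand-ins so the snippet runs on its own.
    import time
    import torch
    import torch.nn as nn
    import torch.optim as optim

    encoder = nn.Linear(8, 16)                     # stand-in for EncoderRNN
    decoder = nn.Linear(16, 4)                     # stand-in for DecoderRNN
    lang_dataloader = [(torch.randn(1, 8), torch.randint(0, 4, (1,)))
                       for _ in range(320)]        # synthetic batch-size-1 pairs
    lang_dataset = [None] * len(lang_dataloader)   # only len(lang_dataset) is used below
    total_epoch = 20                               # was 100 before this commit

    param = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = optim.Adam(param, lr=1e-2)         # was optim.SGD before this commit
    criterion = nn.NLLLoss()

    for epoch in range(total_epoch):
        since = time.time()
        running_loss = 0
        print_loss_total = 0
        total_loss = 0                             # new: accumulates over the whole epoch
        for i, (in_lang, out_lang) in enumerate(lang_dataloader):
            optimizer.zero_grad()
            log_probs = nn.functional.log_softmax(decoder(encoder(in_lang)), dim=1)
            loss = criterion(log_probs, out_lang)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()            # train.py uses loss.data[0] (pre-0.4 PyTorch)
            print_loss_total += loss.item()
            total_loss += loss.item()              # new accumulator, never reset mid-epoch
        during = time.time() - since
        print('Finish {}/{} , Loss:{:.6f}, Time:{:.0f}s'.format(
            epoch + 1, total_epoch, total_loss / len(lang_dataset), during))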