22import argparse
33import torch
44import torch .nn as nn
5+ from torch import cuda
56from torch .autograd import Variable
67import math
78import time
8687 See README for specific formatting instructions.""" )
8788
8889# GPU
89- parser .add_argument ('-cuda ' , action = 'store_true' ,
90+ parser .add_argument ('-gpu ' , default = [], nargs = '+' , type = int ,
9091 help = "Use CUDA" )
9192
9293parser .add_argument ('-log_interval' , type = int , default = 50 ,
9596# help="Seed for random initialization")
9697
9798opt = parser .parse_args ()
99+ opt .cuda = len (opt .gpu )
100+
98101print (opt )
99102
100103if torch .cuda .is_available () and not opt .cuda :
101104 print ("WARNING: You have a CUDA device, so you should probably run with -cuda" )
102105
106+ if opt .cuda :
107+ cuda .set_device (opt .gpu [0 ])
103108
104109def NMTCriterion (vocabSize ):
105110 weight = torch .ones (vocabSize )
@@ -117,7 +122,7 @@ def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
117122
118123 batch_size = outputs .size (1 )
119124 outputs_split = torch .split (outputs , opt .max_generator_batches )
120- targets_split = torch .split (targets , opt .max_generator_batches )
125+ targets_split = torch .split (targets . contiguous () , opt .max_generator_batches )
121126 for out_t , targ_t in zip (outputs_split , targets_split ):
122127 out_t = out_t .view (- 1 , out_t .size (2 ))
123128 pred_t = generator (out_t )
@@ -136,9 +141,9 @@ def eval(model, criterion, data):
136141
137142 model .eval ()
138143 for i in range (len (data )):
139- batch = data [i ]
144+ batch = [ x . transpose ( 0 , 1 ) for x in data [i ]] # must be batch first for gather/scatter in DataParallel
140145 outputs = model (batch ) # FIXME volatile
141- targets = batch [1 ][1 :] # exclude <s> from targets
146+ targets = batch [1 ][:, 1 :] # exclude <s> from targets
142147 loss , _ = memoryEfficientLoss (
143148 outputs , targets , model .generator , criterion , eval = True )
144149 total_loss += loss
@@ -155,6 +160,7 @@ def trainModel(model, trainData, validData, dataset, optim):
155160 # define criterion of each GPU
156161 criterion = NMTCriterion (dataset ['dicts' ]['tgt' ].size ())
157162
163+ start_time = time .time ()
158164 def trainEpoch (epoch ):
159165
160166 # shuffle mini batch order
@@ -167,10 +173,11 @@ def trainEpoch(epoch):
167173
168174 batchIdx = batchOrder [i ] if epoch >= opt .curriculum else i
169175 batch = trainData [batchIdx ]
176+ batch = [x .transpose (0 , 1 ) for x in batch ] # must be batch first for gather/scatter in DataParallel
170177
171178 model .zero_grad ()
172179 outputs = model (batch )
173- targets = batch [1 ][1 :] # exclude <s> from targets
180+ targets = batch [1 ][:, 1 :] # exclude <s> from targets
174181 loss , gradOutput = memoryEfficientLoss (
175182 outputs , targets , model .generator , criterion )
176183
@@ -185,10 +192,11 @@ def trainEpoch(epoch):
185192 total_words += num_words
186193 report_words += num_words
187194 if i % opt .log_interval == 0 and i > 0 :
188- print ("Epoch %2d, %5d/%5d batches; perplexity: %6.2f; %3.0f tokens/s" %
195+ print ("Epoch %2d, %5d/%5d batches; perplexity: %6.2f; %3.0f tokens/s; %6.0f s elapsed " %
189196 (epoch , i , len (trainData ),
190197 math .exp (report_loss / report_words ),
191- report_words / (time .time ()- start )))
198+ report_words / (time .time ()- start ),
199+ time .time ()- start_time ))
192200
193201 report_loss = report_words = 0
194202 start = time .time ()
@@ -249,7 +257,15 @@ def main():
249257 generator = nn .Sequential (
250258 nn .Linear (opt .rnn_size , dicts ['tgt' ].size ()),
251259 nn .LogSoftmax ())
260+ generator = nn .DataParallel (generator , device_ids = opt .gpu )
252261 model = onmt .Models .NMTModel (encoder , decoder , generator )
262+ model = nn .DataParallel (model , device_ids = opt .gpu )
263+ if opt .cuda :
264+ model .cuda ()
265+ else :
266+ model .cpu ()
267+
268+ model .generator = generator
253269
254270 for p in model .parameters ():
255271 p .data .uniform_ (- opt .param_init , opt .param_init )
@@ -263,14 +279,13 @@ def main():
263279 print ('Loading from checkpoint at %s' % opt .train_from )
264280 checkpoint = torch .load (opt .train_from )
265281 model = checkpoint ['model' ]
282+ if opt .cuda :
283+ model .cuda ()
284+ else :
285+ model .cpu ()
266286 optim = checkpoint ['optim' ]
267287 opt .start_epoch = checkpoint ['epoch' ] + 1
268288
269- if opt .cuda :
270- model .cuda ()
271- else :
272- model .cpu ()
273-
274289 nParams = sum ([p .nelement () for p in model .parameters ()])
275290 print ('* number of parameters: %d' % nParams )
276291
0 commit comments