Skip to content

Commit f485f7b

Browse files
bmccann authored and soumith committed
multi-gpu via DataParallel
1 parent c84f7df commit f485f7b

File tree

3 files changed

+37
-24
lines changed

3 files changed

+37
-24
lines changed

OpenNMT/onmt/Models.py

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -29,10 +29,9 @@ def __init__(self, opt, dicts):
2929
self.word_lut.weight.copy_(pretrained)
3030

3131
def forward(self, input, hidden=None):
32-
emb = self.word_lut(input)
33-
32+
batch_size = input.size(0) # batch first for multi-gpu compatibility
33+
emb = self.word_lut(input).transpose(0, 1)
3434
if hidden is None:
35-
batch_size = emb.size(1)
3635
h_size = (self.layers * self.num_directions, batch_size, self.hidden_size)
3736
h_0 = Variable(emb.data.new(*h_size).zero_(), requires_grad=False)
3837
c_0 = Variable(emb.data.new(*h_size).zero_(), requires_grad=False)
@@ -46,21 +45,21 @@ class StackedLSTM(nn.Module):
4645
def __init__(self, num_layers, input_size, rnn_size, dropout):
4746
super(StackedLSTM, self).__init__()
4847
self.dropout = nn.Dropout(dropout)
48+
self.num_layers = num_layers
4949

50-
self.layers = []
5150
for i in range(num_layers):
5251
layer = nn.LSTMCell(input_size, rnn_size)
5352
self.add_module('layer_%d' % i, layer)
54-
self.layers += [layer]
5553
input_size = rnn_size
5654

5755
def forward(self, input, hidden):
5856
h_0, c_0 = hidden
5957
h_1, c_1 = [], []
60-
for i, layer in enumerate(self.layers):
58+
for i in range(self.num_layers):
59+
layer = getattr(self, 'layer_%d' % i)
6160
h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
6261
input = h_1_i
63-
if i != len(self.layers):
62+
if i != self.num_layers:
6463
input = self.dropout(input)
6564
h_1 += [h_1_i]
6665
c_1 += [c_1_i]
@@ -99,9 +98,9 @@ def __init__(self, opt, dicts):
9998

10099

101100
def forward(self, input, hidden, context, init_output):
102-
emb = self.word_lut(input)
101+
emb = self.word_lut(input).transpose(0, 1)
103102

104-
batch_size = input.size(1)
103+
batch_size = input.size(0)
105104

106105
h_size = (batch_size, self.hidden_size)
107106
output = Variable(emb.data.new(*h_size).zero_(), requires_grad=False)
@@ -122,7 +121,7 @@ def forward(self, input, hidden, context, init_output):
122121
outputs += [output]
123122

124123
outputs = torch.stack(outputs)
125-
return outputs, hidden, attn
124+
return outputs.transpose(0, 1), hidden, attn
126125

127126

128127
class NMTModel(nn.Module):
@@ -154,7 +153,7 @@ def _fix_enc_hidden(self, h):
154153

155154
def forward(self, input):
156155
src = input[0]
157-
tgt = input[1][:-1] # exclude last target from inputs
156+
tgt = input[1][:, :-1] # exclude last target from inputs
158157
enc_hidden, context = self.encoder(src)
159158
init_output = self.make_init_decoder_output(context)
160159

OpenNMT/preprocess.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
import onmt
22

33
import argparse
4-
import os
54
import torch
65

76
parser = argparse.ArgumentParser(description='preprocess.lua')

OpenNMT/train.py

Lines changed: 27 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import argparse
33
import torch
44
import torch.nn as nn
5+
from torch import cuda
56
from torch.autograd import Variable
67
import math
78
import time
@@ -86,7 +87,7 @@
8687
See README for specific formatting instructions.""")
8788

8889
# GPU
89-
parser.add_argument('-cuda', action='store_true',
90+
parser.add_argument('-gpu', default=[], nargs='+', type=int,
9091
help="Use CUDA")
9192

9293
parser.add_argument('-log_interval', type=int, default=50,
@@ -95,11 +96,15 @@
9596
# help="Seed for random initialization")
9697

9798
opt = parser.parse_args()
99+
opt.cuda = len(opt.gpu)
100+
98101
print(opt)
99102

100103
if torch.cuda.is_available() and not opt.cuda:
101104
print("WARNING: You have a CUDA device, so you should probably run with -cuda")
102105

106+
if opt.cuda:
107+
cuda.set_device(opt.gpu[0])
103108

104109
def NMTCriterion(vocabSize):
105110
weight = torch.ones(vocabSize)
@@ -117,7 +122,7 @@ def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
117122

118123
batch_size = outputs.size(1)
119124
outputs_split = torch.split(outputs, opt.max_generator_batches)
120-
targets_split = torch.split(targets, opt.max_generator_batches)
125+
targets_split = torch.split(targets.contiguous(), opt.max_generator_batches)
121126
for out_t, targ_t in zip(outputs_split, targets_split):
122127
out_t = out_t.view(-1, out_t.size(2))
123128
pred_t = generator(out_t)
@@ -136,9 +141,9 @@ def eval(model, criterion, data):
136141

137142
model.eval()
138143
for i in range(len(data)):
139-
batch = data[i]
144+
batch = [x.transpose(0, 1) for x in data[i]] # must be batch first for gather/scatter in DataParallel
140145
outputs = model(batch) # FIXME volatile
141-
targets = batch[1][1:] # exclude <s> from targets
146+
targets = batch[1][:, 1:] # exclude <s> from targets
142147
loss, _ = memoryEfficientLoss(
143148
outputs, targets, model.generator, criterion, eval=True)
144149
total_loss += loss
@@ -155,6 +160,7 @@ def trainModel(model, trainData, validData, dataset, optim):
155160
# define criterion of each GPU
156161
criterion = NMTCriterion(dataset['dicts']['tgt'].size())
157162

163+
start_time = time.time()
158164
def trainEpoch(epoch):
159165

160166
# shuffle mini batch order
@@ -167,10 +173,11 @@ def trainEpoch(epoch):
167173

168174
batchIdx = batchOrder[i] if epoch >= opt.curriculum else i
169175
batch = trainData[batchIdx]
176+
batch = [x.transpose(0, 1) for x in batch] # must be batch first for gather/scatter in DataParallel
170177

171178
model.zero_grad()
172179
outputs = model(batch)
173-
targets = batch[1][1:] # exclude <s> from targets
180+
targets = batch[1][:, 1:] # exclude <s> from targets
174181
loss, gradOutput = memoryEfficientLoss(
175182
outputs, targets, model.generator, criterion)
176183

@@ -185,10 +192,11 @@ def trainEpoch(epoch):
185192
total_words += num_words
186193
report_words += num_words
187194
if i % opt.log_interval == 0 and i > 0:
188-
print("Epoch %2d, %5d/%5d batches; perplexity: %6.2f; %3.0f tokens/s" %
195+
print("Epoch %2d, %5d/%5d batches; perplexity: %6.2f; %3.0f tokens/s; %6.0f s elapsed" %
189196
(epoch, i, len(trainData),
190197
math.exp(report_loss / report_words),
191-
report_words/(time.time()-start)))
198+
report_words/(time.time()-start),
199+
time.time()-start_time))
192200

193201
report_loss = report_words = 0
194202
start = time.time()
@@ -249,7 +257,15 @@ def main():
249257
generator = nn.Sequential(
250258
nn.Linear(opt.rnn_size, dicts['tgt'].size()),
251259
nn.LogSoftmax())
260+
generator = nn.DataParallel(generator, device_ids=opt.gpu)
252261
model = onmt.Models.NMTModel(encoder, decoder, generator)
262+
model = nn.DataParallel(model, device_ids=opt.gpu)
263+
if opt.cuda:
264+
model.cuda()
265+
else:
266+
model.cpu()
267+
268+
model.generator = generator
253269

254270
for p in model.parameters():
255271
p.data.uniform_(-opt.param_init, opt.param_init)
@@ -263,14 +279,13 @@ def main():
263279
print('Loading from checkpoint at %s' % opt.train_from)
264280
checkpoint = torch.load(opt.train_from)
265281
model = checkpoint['model']
282+
if opt.cuda:
283+
model.cuda()
284+
else:
285+
model.cpu()
266286
optim = checkpoint['optim']
267287
opt.start_epoch = checkpoint['epoch'] + 1
268288

269-
if opt.cuda:
270-
model.cuda()
271-
else:
272-
model.cpu()
273-
274289
nParams = sum([p.nelement() for p in model.parameters()])
275290
print('* number of parameters: %d' % nParams)
276291

0 commit comments

Comments
 (0)