Skip to content

Commit ae12d25

Browse files
committed
updated for PyTorch 0.4.0 and Python 3
1 parent 2e6245b commit ae12d25

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

train.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,8 @@ def main():
6666
if args.augment:
6767
transform_train = transforms.Compose([
6868
transforms.ToTensor(),
69-
transforms.Lambda(lambda x: F.pad(
70-
Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
71-
(4,4,4,4),mode='reflect').data.squeeze()),
69+
transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
70+
(4,4,4,4),mode='reflect').squeeze()),
7271
transforms.ToPILImage(),
7372
transforms.RandomCrop(32),
7473
transforms.RandomHorizontalFlip(),
@@ -146,7 +145,7 @@ def main():
146145
'state_dict': model.state_dict(),
147146
'best_prec1': best_prec1,
148147
}, is_best)
149-
print 'Best accuracy: ', best_prec1
148+
print('Best accuracy: ', best_prec1)
150149

151150
def train(train_loader, model, criterion, optimizer, epoch):
152151
"""Train for one epoch on the training set"""
@@ -170,8 +169,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
170169

171170
# measure accuracy and record loss
172171
prec1 = accuracy(output.data, target, topk=(1,))[0]
173-
losses.update(loss.data[0], input.size(0))
174-
top1.update(prec1[0], input.size(0))
172+
losses.update(loss.data.item(), input.size(0))
173+
top1.update(prec1.item(), input.size(0))
175174

176175
# compute gradient and do SGD step
177176
optimizer.zero_grad()
@@ -207,17 +206,18 @@ def validate(val_loader, model, criterion, epoch):
207206
for i, (input, target) in enumerate(val_loader):
208207
target = target.cuda(async=True)
209208
input = input.cuda()
210-
input_var = torch.autograd.Variable(input, volatile=True)
211-
target_var = torch.autograd.Variable(target, volatile=True)
209+
input_var = torch.autograd.Variable(input)
210+
target_var = torch.autograd.Variable(target)
212211

213212
# compute output
214-
output = model(input_var)
213+
with torch.no_grad():
214+
output = model(input_var)
215215
loss = criterion(output, target_var)
216216

217217
# measure accuracy and record loss
218218
prec1 = accuracy(output.data, target, topk=(1,))[0]
219-
losses.update(loss.data[0], input.size(0))
220-
top1.update(prec1[0], input.size(0))
219+
losses.update(loss.data.item(), input.size(0))
220+
top1.update(prec1.item(), input.size(0))
221221

222222
# measure elapsed time
223223
batch_time.update(time.time() - end)

wideresnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0
3636
self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
3737
def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
3838
layers = []
39-
for i in range(nb_layers):
39+
for i in range(int(nb_layers)):
4040
layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
4141
return nn.Sequential(*layers)
4242
def forward(self, x):

0 commit comments

Comments
 (0)