@@ -66,9 +66,8 @@ def main():
6666 if args .augment :
6767 transform_train = transforms .Compose ([
6868 transforms .ToTensor (),
69- transforms .Lambda (lambda x : F .pad (
70- Variable (x .unsqueeze (0 ), requires_grad = False , volatile = True ),
71- (4 ,4 ,4 ,4 ),mode = 'reflect' ).data .squeeze ()),
69+ transforms .Lambda (lambda x : F .pad (x .unsqueeze (0 ),
70+ (4 ,4 ,4 ,4 ),mode = 'reflect' ).squeeze ()),
7271 transforms .ToPILImage (),
7372 transforms .RandomCrop (32 ),
7473 transforms .RandomHorizontalFlip (),
@@ -146,7 +145,7 @@ def main():
146145 'state_dict' : model .state_dict (),
147146 'best_prec1' : best_prec1 ,
148147 }, is_best )
149- print 'Best accuracy: ' , best_prec1
148+ print ( 'Best accuracy: ' , best_prec1 )
150149
151150def train (train_loader , model , criterion , optimizer , epoch ):
152151 """Train for one epoch on the training set"""
@@ -170,8 +169,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
170169
171170 # measure accuracy and record loss
172171 prec1 = accuracy (output .data , target , topk = (1 ,))[0 ]
173- losses .update (loss .data [ 0 ] , input .size (0 ))
174- top1 .update (prec1 [ 0 ] , input .size (0 ))
172+ losses .update (loss .data . item () , input .size (0 ))
173+ top1 .update (prec1 . item () , input .size (0 ))
175174
176175 # compute gradient and do SGD step
177176 optimizer .zero_grad ()
@@ -207,17 +206,18 @@ def validate(val_loader, model, criterion, epoch):
207206 for i , (input , target ) in enumerate (val_loader ):
208207 target = target .cuda (async = True )
209208 input = input .cuda ()
210- input_var = torch .autograd .Variable (input , volatile = True )
211- target_var = torch .autograd .Variable (target , volatile = True )
209+ input_var = torch .autograd .Variable (input )
210+ target_var = torch .autograd .Variable (target )
212211
213212 # compute output
214- output = model (input_var )
213+ with torch .no_grad ():
214+ output = model (input_var )
215215 loss = criterion (output , target_var )
216216
217217 # measure accuracy and record loss
218218 prec1 = accuracy (output .data , target , topk = (1 ,))[0 ]
219- losses .update (loss .data [ 0 ] , input .size (0 ))
220- top1 .update (prec1 [ 0 ] , input .size (0 ))
219+ losses .update (loss .data . item () , input .size (0 ))
220+ top1 .update (prec1 . item () , input .size (0 ))
221221
222222 # measure elapsed time
223223 batch_time .update (time .time () - end )
0 commit comments