# Command-line options for the classification training run.
parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
# Fixed help text: it previously said 'input batch size' (copy-paste error).
parser.add_argument('--num_points', type=int, default=2500, help='number of points per point cloud')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='cls', help='output folder')
# Seed both RNGs so a run is reproducible given opt.manualSeed
# (set earlier in the file, outside this view).
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# Training split: classification mode, point count taken from --num_points.
dataset = PartDataset(root='shapenetcore_partanno_segmentation_benchmark_v0',
                      classification=True, npoints=opt.num_points)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

# Test split mirrors the training configuration (train=False).
test_dataset = PartDataset(root='shapenetcore_partanno_segmentation_benchmark_v0',
                           classification=True, train=False, npoints=opt.num_points)
testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize,
                                             shuffle=True, num_workers=int(opt.workers))
5455 pass
5556
5657
57- classifier = PointNetCls (k = num_classes )
58+ classifier = PointNetCls (k = num_classes , num_points = opt . num_points )
5859
5960
6061if opt .model != '' :
6162 classifier .load_state_dict (torch .load (opt .model ))
62-
63-
63+
64+
6465optimizer = optim .SGD (classifier .parameters (), lr = 0.01 , momentum = 0.9 )
6566classifier .cuda ()
6667
7071 for i , data in enumerate (dataloader , 0 ):
7172 points , target = data
7273 points , target = Variable (points ), Variable (target [:,0 ])
73- points = points .transpose (2 ,1 )
74- points , target = points .cuda (), target .cuda ()
74+ points = points .transpose (2 ,1 )
75+ points , target = points .cuda (), target .cuda ()
7576 optimizer .zero_grad ()
7677 pred , _ = classifier (points )
7778 loss = F .nll_loss (pred , target )
8081 pred_choice = pred .data .max (1 )[1 ]
8182 correct = pred_choice .eq (target .data ).cpu ().sum ()
8283 print ('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch , i , num_batch , loss .data [0 ], correct / float (opt .batchSize )))
83-
84+
8485 if i % 10 == 0 :
8586 j , data = enumerate (testdataloader , 0 ).next ()
8687 points , target = data
8788 points , target = Variable (points ), Variable (target [:,0 ])
88- points = points .transpose (2 ,1 )
89- points , target = points .cuda (), target .cuda ()
89+ points = points .transpose (2 ,1 )
90+ points , target = points .cuda (), target .cuda ()
9091 pred , _ = classifier (points )
9192 loss = F .nll_loss (pred , target )
9293 pred_choice = pred .data .max (1 )[1 ]
9394 correct = pred_choice .eq (target .data ).cpu ().sum ()
9495 print ('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch , i , num_batch , blue ('test' ), loss .data [0 ], correct / float (opt .batchSize )))
95-
96- torch .save (classifier .state_dict (), '%s/cls_model_%d.pth' % (opt .outf , epoch ))
96+
97+ torch .save (classifier .state_dict (), '%s/cls_model_%d.pth' % (opt .outf , epoch ))
0 commit comments