Skip to content

Commit ff32133

Browse files
committed
add show classification results
1 parent bf3494c commit ff32133

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

show_cls.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from __future__ import print_function
2+
import argparse
3+
import os
4+
import random
5+
import numpy as np
6+
import torch
7+
import torch.nn as nn
8+
import torch.nn.parallel
9+
import torch.backends.cudnn as cudnn
10+
import torch.optim as optim
11+
import torch.utils.data
12+
import torchvision.datasets as dset
13+
import torchvision.transforms as transforms
14+
import torchvision.utils as vutils
15+
from torch.autograd import Variable
16+
from datasets import PartDataset
17+
from pointnet import PointNetCls
18+
import torch.nn.functional as F
19+
import matplotlib.pyplot as plt
20+
21+
22+
#showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500)))
23+
24+
parser = argparse.ArgumentParser()
25+
26+
parser.add_argument('--model', type=str, default = '', help='model path')
27+
28+
29+
opt = parser.parse_args()
30+
print (opt)
31+
32+
test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0' , train = False, classification = True)
33+
34+
testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle = False)
35+
36+
37+
classifier = PointNetCls(k = len(test_dataset.classes))
38+
classifier.cuda()
39+
classifier.load_state_dict(torch.load(opt.model))
40+
classifier.eval()
41+
42+
for i, data in enumerate(testdataloader, 0):
43+
points, target = data
44+
points, target = Variable(points), Variable(target[:,0])
45+
points = points.transpose(2,1)
46+
points, target = points.cuda(), target.cuda()
47+
pred, _ = classifier(points)
48+
loss = F.nll_loss(pred, target)
49+
pred_choice = pred.data.max(1)[1]
50+
correct = pred_choice.eq(target.data).cpu().sum()
51+
print('i:%d loss: %f accuracy: %f' %(i, loss.data[0], correct/float(32)))

0 commit comments

Comments
 (0)