Commit 523f07f
add some benchmarking code
1 parent 0ce9a4b commit 523f07f

File tree

7 files changed: +92 -32 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -7,3 +7,4 @@ shapenetcore_partanno_segmentation_benchmark_v0/
 .idea*
 cls/
 seg/
+*.egg-info/

pointnet/__init__.py

Whitespace-only changes.

pointnet/dataset.py

Lines changed: 15 additions & 1 deletion

@@ -12,11 +12,13 @@ def __init__(self,
                  npoints=2500,
                  classification=False,
                  class_choice=None,
-                 train=True):
+                 train=True,
+                 data_augmentation=True):
         self.npoints = npoints
         self.root = root
         self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
         self.cat = {}
+        self.data_augmentation = data_augmentation

         self.classification = classification

@@ -73,10 +75,22 @@ def __getitem__(self, index):
         choice = np.random.choice(len(seg), self.npoints, replace=True)
         # resample
         point_set = point_set[choice, :]
+
+        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
+        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
+        point_set = point_set / dist  # scale
+
+        if self.data_augmentation:
+            theta = np.random.uniform(0, np.pi * 2)
+            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
+            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
+            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter
+
         seg = seg[choice]
         point_set = torch.from_numpy(point_set)
         seg = torch.from_numpy(seg)
         cls = torch.from_numpy(np.array([cls]).astype(np.int64))
+
         if self.classification:
             return point_set, cls
         else:
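
The new `__getitem__` logic normalizes every cloud (center on the centroid, scale the farthest point onto the unit sphere) and then, when `data_augmentation` is on, perturbs it. A minimal standalone sketch of the same three steps on a dummy cloud; the 2500-point shape and the 0.02 jitter sigma come from the diff, the random input is purely illustrative:

import numpy as np

point_set = np.random.rand(2500, 3).astype(np.float32)  # dummy stand-in for one sampled cloud

# center on the centroid, then scale so the farthest point sits on the unit sphere
point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)
point_set = point_set / np.max(np.sqrt(np.sum(point_set ** 2, axis=1)))

# random rotation about the vertical axis: mix the x and z columns, leave y alone
theta = np.random.uniform(0, np.pi * 2)
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],
                            [np.sin(theta), np.cos(theta)]])
point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)

# random jitter: zero-mean Gaussian noise on every coordinate
point_set += np.random.normal(0, 0.02, size=point_set.shape)

Rotating only in the x-z plane assumes the models share a common up axis (y), the usual convention for this ShapeNet part data.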

pointnet/model.py

Lines changed: 3 additions & 2 deletions

@@ -84,16 +84,17 @@ def __init__(self, k = 2):
         self.fc1 = nn.Linear(1024, 512)
         self.fc2 = nn.Linear(512, 256)
         self.fc3 = nn.Linear(256, k)
+        self.dropout = nn.Dropout(p=0.3)
         self.bn1 = nn.BatchNorm1d(512)
         self.bn2 = nn.BatchNorm1d(256)
         self.relu = nn.ReLU()

     def forward(self, x):
         x, trans = self.feat(x)
         x = F.relu(self.bn1(self.fc1(x)))
-        x = F.relu(self.bn2(self.fc2(x)))
+        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
         x = self.fc3(x)
-        return F.log_softmax(x, dim=0), trans
+        return F.log_softmax(x, dim=1), trans

 class PointNetDenseCls(nn.Module):
     def __init__(self, k = 2):
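
The `dim=0` to `dim=1` change is a genuine bug fix, not a cleanup: the classifier's logits have shape `(batch_size, k)`, so the log-softmax must normalize over the `k` classes (dim 1), not over the batch (dim 0). A quick way to see the difference (shapes are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(32, 16)  # (batch_size, k), as produced by fc3

# correct: each row becomes a log-distribution over the k classes
per_sample = F.log_softmax(logits, dim=1)
print(per_sample.exp().sum(dim=1))  # all ones: one distribution per sample

# old behaviour: normalizes each class column across the batch instead,
# which NLL loss cannot interpret as per-sample class probabilities
per_class = F.log_softmax(logits, dim=0)
print(per_class.exp().sum(dim=0))   # ones per column, not per sample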

utils/show_seg.py

Lines changed: 7 additions & 10 deletions

@@ -17,34 +17,31 @@

 parser.add_argument('--model', type=str, default='', help='model path')
 parser.add_argument('--idx', type=int, default=0, help='model index')
-
-
+parser.add_argument('--dataset', type=str, default='', help='dataset path')
+parser.add_argument('--class_choice', type=str, default='', help='class choice')

 opt = parser.parse_args()
 print(opt)

 d = PartDataset(
-    root='shapenetcore_partanno_segmentation_benchmark_v0',
-    class_choice=['Airplane'],
+    root=opt.dataset,
+    class_choice=[opt.class_choice],
     train=False)

 idx = opt.idx

 print("model %d/%d" % (idx, len(d)))
-
 point, seg = d[idx]
 print(point.size(), seg.size())
-
 point_np = point.numpy()

-
-
 cmap = plt.cm.get_cmap("hsv", 10)
 cmap = np.array([cmap(i) for i in range(10)])[:, :3]
 gt = cmap[seg.numpy() - 1, :]

-classifier = PointNetDenseCls(k=4)
-classifier.load_state_dict(torch.load(opt.model))
+state_dict = torch.load(opt.model)
+classifier = PointNetDenseCls(k=state_dict['conv4.weight'].size()[0])
+classifier.load_state_dict(state_dict)
 classifier.eval()

 point = point.transpose(1, 0).contiguous()
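
Instead of a hard-coded `k=4`, the script now recovers the number of part classes from the checkpoint itself: in this repo's `PointNetDenseCls`, `conv4` is the final 1x1 convolution mapping features to per-point class scores, so the leading dimension of its weight tensor equals `k`. The same idea in isolation; the checkpoint path here is hypothetical:

import torch
from pointnet.model import PointNetDenseCls

state_dict = torch.load('seg/seg_model_Chair_24.pth')  # hypothetical checkpoint path
num_classes = state_dict['conv4.weight'].size(0)       # out_channels of the last conv == k
classifier = PointNetDenseCls(k=num_classes)           # rebuild with the matching head size
classifier.load_state_dict(state_dict)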

utils/train_classification.py

Lines changed: 26 additions & 9 deletions

@@ -6,11 +6,10 @@
 import torch.nn.parallel
 import torch.optim as optim
 import torch.utils.data
-from torch.autograd import Variable
 from pointnet.dataset import PartDataset
 from pointnet.model import PointNetCls
 import torch.nn.functional as F
-
+from tqdm import tqdm


 parser = argparse.ArgumentParser()
@@ -21,9 +20,10 @@
 parser.add_argument(
     '--workers', type=int, help='number of data loading workers', default=4)
 parser.add_argument(
-    '--nepoch', type=int, default=25, help='number of epochs to train for')
+    '--nepoch', type=int, default=250, help='number of epochs to train for')
 parser.add_argument('--outf', type=str, default='cls', help='output folder')
 parser.add_argument('--model', type=str, default='', help='model path')
+parser.add_argument('--dataset', type=str, required=True, help="dataset path")

 opt = parser.parse_args()
 print(opt)
@@ -36,7 +36,7 @@
 torch.manual_seed(opt.manualSeed)

 dataset = PartDataset(
-    root='shapenetcore_partanno_segmentation_benchmark_v0',
+    root=opt.dataset,
     classification=True,
     npoints=opt.num_points)
 dataloader = torch.utils.data.DataLoader(
@@ -46,7 +46,7 @@
     num_workers=int(opt.workers))

 test_dataset = PartDataset(
-    root='shapenetcore_partanno_segmentation_benchmark_v0',
+    root=opt.dataset,
     classification=True,
     train=False,
     npoints=opt.num_points)
@@ -65,22 +65,23 @@
 except OSError:
     pass

-
 classifier = PointNetCls(k=num_classes)

 if opt.model != '':
     classifier.load_state_dict(torch.load(opt.model))


-optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
+optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
+scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
 classifier.cuda()

 num_batch = len(dataset) / opt.batchSize

 for epoch in range(opt.nepoch):
+    scheduler.step()
     for i, data in enumerate(dataloader, 0):
         points, target = data
-        points, target = Variable(points), Variable(target[:, 0])
+        target = target[:, 0]
         points = points.transpose(2, 1)
         points, target = points.cuda(), target.cuda()
         optimizer.zero_grad()
@@ -96,7 +97,7 @@
         if i % 10 == 0:
             j, data = next(enumerate(testdataloader, 0))
             points, target = data
-            points, target = Variable(points), Variable(target[:, 0])
+            target = target[:, 0]
             points = points.transpose(2, 1)
             points, target = points.cuda(), target.cuda()
             classifier = classifier.eval()
@@ -107,3 +108,19 @@
             print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item() / float(opt.batchSize)))

     torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))
+
+total_correct = 0
+total_testset = 0
+for i, data in tqdm(enumerate(testdataloader, 0)):
+    points, target = data
+    target = target[:, 0]
+    points = points.transpose(2, 1)
+    points, target = points.cuda(), target.cuda()
+    classifier = classifier.eval()
+    pred, _ = classifier(points)
+    pred_choice = pred.data.max(1)[1]
+    correct = pred_choice.eq(target.data).cpu().sum()
+    total_correct += correct.item()
+    total_testset += points.size()[0]
+
+print("final accuracy {}".format(total_correct / float(total_testset)))

utils/train_segmentation.py

Lines changed: 40 additions & 10 deletions

@@ -6,10 +6,11 @@
 import torch.nn.parallel
 import torch.optim as optim
 import torch.utils.data
-from torch.autograd import Variable
 from pointnet.dataset import PartDataset
 from pointnet.model import PointNetDenseCls
 import torch.nn.functional as F
+from tqdm import tqdm
+import numpy as np


 parser = argparse.ArgumentParser()
@@ -21,6 +22,8 @@
     '--nepoch', type=int, default=25, help='number of epochs to train for')
 parser.add_argument('--outf', type=str, default='seg', help='output folder')
 parser.add_argument('--model', type=str, default='', help='model path')
+parser.add_argument('--dataset', type=str, required=True, help="dataset path")
+parser.add_argument('--class_choice', type=str, default='Chair', help="class_choice")


 opt = parser.parse_args()
@@ -32,19 +35,19 @@
 torch.manual_seed(opt.manualSeed)

 dataset = PartDataset(
-    root='shapenetcore_partanno_segmentation_benchmark_v0',
+    root=opt.dataset,
     classification=False,
-    class_choice=['Chair'])
+    class_choice=[opt.class_choice])
 dataloader = torch.utils.data.DataLoader(
     dataset,
     batch_size=opt.batchSize,
     shuffle=True,
     num_workers=int(opt.workers))

 test_dataset = PartDataset(
-    root='shapenetcore_partanno_segmentation_benchmark_v0',
+    root=opt.dataset,
     classification=False,
-    class_choice=['Chair'],
+    class_choice=[opt.class_choice],
     train=False)
 testdataloader = torch.utils.data.DataLoader(
     test_dataset,
@@ -67,15 +70,16 @@
 if opt.model != '':
     classifier.load_state_dict(torch.load(opt.model))

-optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
+optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
+scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
 classifier.cuda()

 num_batch = len(dataset) / opt.batchSize

 for epoch in range(opt.nepoch):
+    scheduler.step()
     for i, data in enumerate(dataloader, 0):
         points, target = data
-        points, target = Variable(points), Variable(target)
         points = points.transpose(2, 1)
         points, target = points.cuda(), target.cuda()
         optimizer.zero_grad()
@@ -94,17 +98,43 @@
         if i % 10 == 0:
             j, data = next(enumerate(testdataloader, 0))
             points, target = data
-            points, target = Variable(points), Variable(target)
             points = points.transpose(2, 1)
             points, target = points.cuda(), target.cuda()
             classifier = classifier.eval()
             pred, _ = classifier(points)
             pred = pred.view(-1, num_classes)
             target = target.view(-1, 1)[:, 0] - 1
-
             loss = F.nll_loss(pred, target)
             pred_choice = pred.data.max(1)[1]
             correct = pred_choice.eq(target.data).cpu().sum()
             print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item() / float(opt.batchSize * 2500)))

-    torch.save(classifier.state_dict(), '%s/seg_model_%d.pth' % (opt.outf, epoch))
+    torch.save(classifier.state_dict(), '%s/seg_model_%s_%d.pth' % (opt.outf, opt.class_choice, epoch))
+
+## benchmark mIOU
+shape_ious = []
+for i, data in tqdm(enumerate(testdataloader, 0)):
+    points, target = data
+    points = points.transpose(2, 1)
+    points, target = points.cuda(), target.cuda()
+    classifier = classifier.eval()
+    pred, _ = classifier(points)
+    pred_choice = pred.data.max(2)[1]
+
+    pred_np = pred_choice.cpu().data.numpy()
+    target_np = target.cpu().data.numpy() - 1
+
+    for shape_idx in range(target_np.shape[0]):
+        parts = np.unique(target_np[shape_idx])
+        part_ious = []
+        for part in parts:
+            I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
+            U = np.sum(np.logical_or(pred_np[shape_idx] == part, target_np[shape_idx] == part))
+            if U == 0:
+                iou = 0
+            else:
+                iou = I / float(U)
+            part_ious.append(iou)
+        shape_ious.append(np.mean(part_ious))
+
+print("mIOU for class {}: {}".format(opt.class_choice, np.mean(shape_ious)))
