Skip to content

Commit bf3494c

Browse files
committed
fix import
1 parent e08c96e commit bf3494c

File tree

3 files changed

+42
-42
lines changed

3 files changed

+42
-42
lines changed

datasets.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import progressbar
1212
import sys
1313
import torchvision.transforms as transforms
14-
import utils
1514
import argparse
1615
import json
1716

@@ -22,17 +21,17 @@ def __init__(self, root, npoints = 2500, classification = False, class_choice =
2221
self.root = root
2322
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
2423
self.cat = {}
25-
24+
2625
self.classification = classification
27-
26+
2827
with open(self.catfile, 'r') as f:
2928
for line in f:
3029
ls = line.strip().split()
3130
self.cat[ls[0]] = ls[1]
3231
#print(self.cat)
3332
if not class_choice is None:
3433
self.cat = {k:v for k,v in self.cat.items() if k in class_choice}
35-
34+
3635
self.meta = {}
3736
for item in self.cat:
3837
#print('category', item)
@@ -45,35 +44,36 @@ def __init__(self, root, npoints = 2500, classification = False, class_choice =
4544
fns = fns[:int(len(fns) * 0.9)]
4645
else:
4746
fns = fns[int(len(fns) * 0.9):]
48-
47+
4948
#print(os.path.basename(fns))
5049
for fn in fns:
51-
token = (os.path.splitext(os.path.basename(fn))[0])
50+
token = (os.path.splitext(os.path.basename(fn))[0])
5251
self.meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))
53-
52+
5453
self.datapath = []
5554
for item in self.cat:
5655
for fn in self.meta[item]:
5756
self.datapath.append((item, fn[0], fn[1]))
58-
59-
60-
self.classes = dict(zip(self.cat, range(len(self.cat))))
57+
58+
59+
self.classes = dict(zip(self.cat, range(len(self.cat))))
60+
print(self.classes)
6161
self.num_seg_classes = 0
6262
if not self.classification:
6363
for i in range(len(self.datapath)/50):
6464
l = len(np.unique(np.loadtxt(self.datapath[i][-1]).astype(np.uint8)))
6565
if l > self.num_seg_classes:
6666
self.num_seg_classes = l
6767
#print(self.num_seg_classes)
68-
69-
68+
69+
7070
def __getitem__(self, index):
7171
fn = self.datapath[index]
7272
cls = self.classes[self.datapath[index][0]]
7373
point_set = np.loadtxt(fn[1]).astype(np.float32)
7474
seg = np.loadtxt(fn[2]).astype(np.int64)
7575
#print(point_set.shape, seg.shape)
76-
76+
7777
choice = np.random.choice(len(seg), self.npoints, replace=True)
7878
#resample
7979
point_set = point_set[choice, :]
@@ -85,7 +85,7 @@ def __getitem__(self, index):
8585
return point_set, cls
8686
else:
8787
return point_set, seg
88-
88+
8989
def __len__(self):
9090
return len(self.datapath)
9191

@@ -96,8 +96,8 @@ def __len__(self):
9696
print(len(d))
9797
ps, seg = d[0]
9898
print(ps.size(), ps.type(), seg.size(),seg.type())
99-
99+
100100
d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True)
101101
print(len(d))
102102
ps, cls = d[0]
103-
print(ps.size(), ps.type(), cls.size(),cls.type())
103+
print(ps.size(), ps.type(), cls.size(),cls.type())

pointnet.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import numpy as np
1616
import matplotlib.pyplot as plt
1717
import pdb
18-
import utils
1918
import torch.nn.functional as F
2019

2120

@@ -31,14 +30,14 @@ def __init__(self, num_points = 2500):
3130
self.fc2 = nn.Linear(512, 256)
3231
self.fc3 = nn.Linear(256, 9)
3332
self.relu = nn.ReLU()
34-
33+
3534
self.bn1 = nn.BatchNorm1d(64)
3635
self.bn2 = nn.BatchNorm1d(128)
3736
self.bn3 = nn.BatchNorm1d(1024)
3837
self.bn4 = nn.BatchNorm1d(512)
3938
self.bn5 = nn.BatchNorm1d(256)
40-
41-
39+
40+
4241
def forward(self, x):
4342
batchsize = x.size()[0]
4443
x = F.relu(self.bn1(self.conv1(x)))
@@ -57,12 +56,12 @@ def forward(self, x):
5756
x = x + iden
5857
x = x.view(-1, 3, 3)
5958
return x
60-
61-
59+
60+
6261
class PointNetfeat(nn.Module):
6362
def __init__(self, num_points = 2500, global_feat = True):
6463
super(PointNetfeat, self).__init__()
65-
self.stn = STN3d()
64+
self.stn = STN3d(num_points = num_points)
6665
self.conv1 = torch.nn.Conv1d(3, 64, 1)
6766
self.conv2 = torch.nn.Conv1d(64, 128, 1)
6867
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
@@ -89,7 +88,7 @@ def forward(self, x):
8988
else:
9089
x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
9190
return torch.cat([x, pointfeat], 1), trans
92-
91+
9392
class PointNetCls(nn.Module):
9493
def __init__(self, num_points = 2500, k = 2):
9594
super(PointNetCls, self).__init__()
@@ -121,7 +120,7 @@ def __init__(self, num_points = 2500, k = 2):
121120
self.bn1 = nn.BatchNorm1d(512)
122121
self.bn2 = nn.BatchNorm1d(256)
123122
self.bn3 = nn.BatchNorm1d(128)
124-
123+
125124
def forward(self, x):
126125
batchsize = x.size()[0]
127126
x, trans = self.feat(x)
@@ -133,26 +132,26 @@ def forward(self, x):
133132
x = F.log_softmax(x.view(-1,self.k))
134133
x = x.view(batchsize, self.num_points, self.k)
135134
return x, trans
136-
135+
137136

138137
if __name__ == '__main__':
139138
sim_data = Variable(torch.rand(32,3,2500))
140139
trans = STN3d()
141140
out = trans(sim_data)
142141
print('stn', out.size())
143-
142+
144143
pointfeat = PointNetfeat(global_feat=True)
145144
out, _ = pointfeat(sim_data)
146145
print('global feat', out.size())
147146

148147
pointfeat = PointNetfeat(global_feat=False)
149148
out, _ = pointfeat(sim_data)
150149
print('point feat', out.size())
151-
150+
152151
cls = PointNetCls(k = 5)
153152
out, _ = cls(sim_data)
154153
print('class', out.size())
155-
154+
156155
seg = PointNetDenseCls(k = 3)
157156
out, _ = seg(sim_data)
158-
print('seg', out.size())
157+
print('seg', out.size())

train_classification.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
parser = argparse.ArgumentParser()
2323
parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
24+
parser.add_argument('--num_points', type=int, default=2500, help='input batch size')
2425
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
2526
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
2627
parser.add_argument('--outf', type=str, default='cls', help='output folder')
@@ -36,11 +37,11 @@
3637
random.seed(opt.manualSeed)
3738
torch.manual_seed(opt.manualSeed)
3839

39-
dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True)
40+
dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, npoints = opt.num_points)
4041
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
4142
shuffle=True, num_workers=int(opt.workers))
4243

43-
test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, train = False)
44+
test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, train = False, npoints = opt.num_points)
4445
testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize,
4546
shuffle=True, num_workers=int(opt.workers))
4647

@@ -54,13 +55,13 @@
5455
pass
5556

5657

57-
classifier = PointNetCls(k = num_classes)
58+
classifier = PointNetCls(k = num_classes, num_points = opt.num_points)
5859

5960

6061
if opt.model != '':
6162
classifier.load_state_dict(torch.load(opt.model))
62-
63-
63+
64+
6465
optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
6566
classifier.cuda()
6667

@@ -70,8 +71,8 @@
7071
for i, data in enumerate(dataloader, 0):
7172
points, target = data
7273
points, target = Variable(points), Variable(target[:,0])
73-
points = points.transpose(2,1)
74-
points, target = points.cuda(), target.cuda()
74+
points = points.transpose(2,1)
75+
points, target = points.cuda(), target.cuda()
7576
optimizer.zero_grad()
7677
pred, _ = classifier(points)
7778
loss = F.nll_loss(pred, target)
@@ -80,17 +81,17 @@
8081
pred_choice = pred.data.max(1)[1]
8182
correct = pred_choice.eq(target.data).cpu().sum()
8283
print('[%d: %d/%d] train loss: %f accuracy: %f' %(epoch, i, num_batch, loss.data[0], correct/float(opt.batchSize)))
83-
84+
8485
if i % 10 == 0:
8586
j, data = enumerate(testdataloader, 0).next()
8687
points, target = data
8788
points, target = Variable(points), Variable(target[:,0])
88-
points = points.transpose(2,1)
89-
points, target = points.cuda(), target.cuda()
89+
points = points.transpose(2,1)
90+
points, target = points.cuda(), target.cuda()
9091
pred, _ = classifier(points)
9192
loss = F.nll_loss(pred, target)
9293
pred_choice = pred.data.max(1)[1]
9394
correct = pred_choice.eq(target.data).cpu().sum()
9495
print('[%d: %d/%d] %s loss: %f accuracy: %f' %(epoch, i, num_batch, blue('test'), loss.data[0], correct/float(opt.batchSize)))
95-
96-
torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))
96+
97+
torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))

0 commit comments

Comments
 (0)