Skip to content

Commit 5787367

Browse files
committed
add modelnet script
1 parent 3c7e2cd commit 5787367

File tree

4 files changed

+154
-29
lines changed

4 files changed

+154
-29
lines changed

misc/modelnet_id.txt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
airplane 0
2+
bathtub 1
3+
bed 2
4+
bench 3
5+
bookshelf 4
6+
bottle 5
7+
bowl 6
8+
car 7
9+
chair 8
10+
cone 9
11+
cup 10
12+
curtain 11
13+
desk 12
14+
door 13
15+
dresser 14
16+
flower_pot 15
17+
glass_box 16
18+
guitar 17
19+
keyboard 18
20+
lamp 19
21+
laptop 20
22+
mantel 21
23+
monitor 22
24+
night_stand 23
25+
person 24
26+
piano 25
27+
plant 26
28+
radio 27
29+
range_hood 28
30+
sink 29
31+
sofa 30
32+
stairs 31
33+
stool 32
34+
table 33
35+
tent 34
36+
toilet 35
37+
tv_stand 36
38+
vase 37
39+
wardrobe 38
40+
xbox 39
File renamed without changes.

pointnet/dataset.py

Lines changed: 84 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88
from tqdm import tqdm
99
import json
10+
from plyfile import PlyData, PlyElement
1011

1112
def get_segmentation_classes(root):
1213
catfile = os.path.join(root, 'synsetoffset2category.txt')
@@ -27,7 +28,7 @@ def get_segmentation_classes(root):
2728
token = (os.path.splitext(os.path.basename(fn))[0])
2829
meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))
2930

30-
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'num_seg_classes.txt'), 'w') as f:
31+
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'w') as f:
3132
for item in cat:
3233
datapath = []
3334
num_seg_classes = 0
@@ -42,6 +43,16 @@ def get_segmentation_classes(root):
4243
print("category {} num segmentation classes {}".format(item, num_seg_classes))
4344
f.write("{}\t{}\n".format(item, num_seg_classes))
4445

46+
def gen_modelnet_id(root):
47+
classes = []
48+
with open(os.path.join(root, 'train.txt'), 'r') as f:
49+
for line in f:
50+
classes.append(line.strip().split('/')[0])
51+
classes = np.unique(classes)
52+
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'w') as f:
53+
for i in range(len(classes)):
54+
f.write('{}\t{}\n'.format(classes[i], i))
55+
4556
class ShapeNetDataset(data.Dataset):
4657
def __init__(self,
4758
root,
@@ -88,7 +99,7 @@ def __init__(self,
8899

89100
self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
90101
print(self.classes)
91-
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'num_seg_classes.txt'), 'r') as f:
102+
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:
92103
for line in f:
93104
ls = line.strip().split()
94105
self.seg_classes[ls[0]] = int(ls[1])
@@ -129,18 +140,76 @@ def __getitem__(self, index):
129140
def __len__(self):
130141
return len(self.datapath)
131142

143+
class ModelNetDataset(data.Dataset):
144+
def __init__(self,
145+
root,
146+
npoints=2500,
147+
split='train',
148+
data_augmentation=True):
149+
self.npoints = npoints
150+
self.root = root
151+
self.split = split
152+
self.data_augmentation = data_augmentation
153+
self.fns = []
154+
with open(os.path.join(root, '{}.txt'.format(self.split)), 'r') as f:
155+
for line in f:
156+
self.fns.append(line.strip())
157+
158+
self.cat = {}
159+
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f:
160+
for line in f:
161+
ls = line.strip().split()
162+
self.cat[ls[0]] = int(ls[1])
163+
164+
print(self.cat)
165+
self.classes = list(self.cat.keys())
166+
167+
def __getitem__(self, index):
168+
fn = self.fns[index]
169+
cls = self.cat[fn.split('/')[0]]
170+
with open(os.path.join(self.root, fn), 'rb') as f:
171+
plydata = PlyData.read(f)
172+
pts = np.vstack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']]).T
173+
choice = np.random.choice(len(pts), self.npoints, replace=True)
174+
point_set = pts[choice, :]
175+
176+
point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0) # center
177+
dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
178+
point_set = point_set / dist # scale
179+
180+
if self.data_augmentation:
181+
theta = np.random.uniform(0, np.pi * 2)
182+
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
183+
point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix) # random rotation
184+
point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter
185+
186+
point_set = torch.from_numpy(point_set.astype(np.float32))
187+
cls = torch.from_numpy(np.array([cls]).astype(np.int64))
188+
return point_set, cls
189+
190+
191+
def __len__(self):
192+
return len(self.fns)
132193

133194
if __name__ == '__main__':
134-
datapath = sys.argv[1]
135-
print('test')
136-
d = ShapeNetDataset(root = datapath, class_choice = ['Chair'])
137-
print(len(d))
138-
ps, seg = d[0]
139-
print(ps.size(), ps.type(), seg.size(),seg.type())
140-
141-
d = ShapeNetDataset(root = datapath, classification = True)
142-
print(len(d))
143-
ps, cls = d[0]
144-
print(ps.size(), ps.type(), cls.size(),cls.type())
145-
146-
#get_segmentation_classes(datapath)
195+
dataset = sys.argv[1]
196+
datapath = sys.argv[2]
197+
198+
if dataset == 'shapenet':
199+
d = ShapeNetDataset(root = datapath, class_choice = ['Chair'])
200+
print(len(d))
201+
ps, seg = d[0]
202+
print(ps.size(), ps.type(), seg.size(),seg.type())
203+
204+
d = ShapeNetDataset(root = datapath, classification = True)
205+
print(len(d))
206+
ps, cls = d[0]
207+
print(ps.size(), ps.type(), cls.size(),cls.type())
208+
# get_segmentation_classes(datapath)
209+
210+
if dataset == 'modelnet':
211+
gen_modelnet_id(datapath)
212+
d = ModelNetDataset(root=datapath)
213+
print(len(d))
214+
print(d[0])
215+

utils/train_classification.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch.nn.parallel
77
import torch.optim as optim
88
import torch.utils.data
9-
from pointnet.dataset import ShapeNetDataset
9+
from pointnet.dataset import ShapeNetDataset, ModelNetDataset
1010
from pointnet.model import PointNetCls
1111
import torch.nn.functional as F
1212
from tqdm import tqdm
@@ -24,6 +24,7 @@
2424
parser.add_argument('--outf', type=str, default='cls', help='output folder')
2525
parser.add_argument('--model', type=str, default='', help='model path')
2626
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
27+
parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
2728

2829
opt = parser.parse_args()
2930
print(opt)
@@ -35,26 +36,41 @@
3536
random.seed(opt.manualSeed)
3637
torch.manual_seed(opt.manualSeed)
3738

38-
dataset = ShapeNetDataset(
39-
root=opt.dataset,
40-
classification=True,
41-
npoints=opt.num_points)
39+
if opt.dataset_type == 'shapenet':
40+
dataset = ShapeNetDataset(
41+
root=opt.dataset,
42+
classification=True,
43+
npoints=opt.num_points)
44+
45+
test_dataset = ShapeNetDataset(
46+
root=opt.dataset,
47+
classification=True,
48+
split='test',
49+
npoints=opt.num_points)
50+
elif opt.dataset_type == 'modelnet40':
51+
dataset = ModelNetDataset(
52+
root=opt.dataset,
53+
npoints=opt.num_points)
54+
55+
test_dataset = ModelNetDataset(
56+
root=opt.dataset,
57+
split='test',
58+
npoints=opt.num_points)
59+
else:
60+
exit('wrong dataset type')
61+
62+
4263
dataloader = torch.utils.data.DataLoader(
4364
dataset,
4465
batch_size=opt.batchSize,
4566
shuffle=True,
4667
num_workers=int(opt.workers))
4768

48-
test_dataset = ShapeNetDataset(
49-
root=opt.dataset,
50-
classification=True,
51-
split='test',
52-
npoints=opt.num_points)
5369
testdataloader = torch.utils.data.DataLoader(
54-
test_dataset,
55-
batch_size=opt.batchSize,
56-
shuffle=True,
57-
num_workers=int(opt.workers))
70+
test_dataset,
71+
batch_size=opt.batchSize,
72+
shuffle=True,
73+
num_workers=int(opt.workers))
5874

5975
print(len(dataset), len(test_dataset))
6076
num_classes = len(dataset.classes)

0 commit comments

Comments
 (0)