77import sys
88from tqdm import tqdm
99import json
10+ from plyfile import PlyData , PlyElement
1011
1112def get_segmentation_classes (root ):
1213 catfile = os .path .join (root , 'synsetoffset2category.txt' )
@@ -27,7 +28,7 @@ def get_segmentation_classes(root):
2728 token = (os .path .splitext (os .path .basename (fn ))[0 ])
2829 meta [item ].append ((os .path .join (dir_point , token + '.pts' ), os .path .join (dir_seg , token + '.seg' )))
2930
30- with open (os .path .join (os .path .dirname (os .path .realpath (__file__ )), 'num_seg_classes.txt' ), 'w' ) as f :
31+ with open (os .path .join (os .path .dirname (os .path .realpath (__file__ )), '../misc/ num_seg_classes.txt' ), 'w' ) as f :
3132 for item in cat :
3233 datapath = []
3334 num_seg_classes = 0
@@ -42,6 +43,16 @@ def get_segmentation_classes(root):
4243 print ("category {} num segmentation classes {}" .format (item , num_seg_classes ))
4344 f .write ("{}\t {}\n " .format (item , num_seg_classes ))
4445
46+ def gen_modelnet_id (root ):
47+ classes = []
48+ with open (os .path .join (root , 'train.txt' ), 'r' ) as f :
49+ for line in f :
50+ classes .append (line .strip ().split ('/' )[0 ])
51+ classes = np .unique (classes )
52+ with open (os .path .join (os .path .dirname (os .path .realpath (__file__ )), '../misc/modelnet_id.txt' ), 'w' ) as f :
53+ for i in range (len (classes )):
54+ f .write ('{}\t {}\n ' .format (classes [i ], i ))
55+
4556class ShapeNetDataset (data .Dataset ):
4657 def __init__ (self ,
4758 root ,
@@ -88,7 +99,7 @@ def __init__(self,
8899
89100 self .classes = dict (zip (sorted (self .cat ), range (len (self .cat ))))
90101 print (self .classes )
91- with open (os .path .join (os .path .dirname (os .path .realpath (__file__ )), 'num_seg_classes.txt' ), 'r' ) as f :
102+ with open (os .path .join (os .path .dirname (os .path .realpath (__file__ )), '../misc/ num_seg_classes.txt' ), 'r' ) as f :
92103 for line in f :
93104 ls = line .strip ().split ()
94105 self .seg_classes [ls [0 ]] = int (ls [1 ])
@@ -129,18 +140,76 @@ def __getitem__(self, index):
129140 def __len__ (self ):
130141 return len (self .datapath )
131142
143+ class ModelNetDataset (data .Dataset ):
144+ def __init__ (self ,
145+ root ,
146+ npoints = 2500 ,
147+ split = 'train' ,
148+ data_augmentation = True ):
149+ self .npoints = npoints
150+ self .root = root
151+ self .split = split
152+ self .data_augmentation = data_augmentation
153+ self .fns = []
154+ with open (os .path .join (root , '{}.txt' .format (self .split )), 'r' ) as f :
155+ for line in f :
156+ self .fns .append (line .strip ())
157+
158+ self .cat = {}
159+ with open (os .path .join (os .path .dirname (os .path .realpath (__file__ )), '../misc/modelnet_id.txt' ), 'r' ) as f :
160+ for line in f :
161+ ls = line .strip ().split ()
162+ self .cat [ls [0 ]] = int (ls [1 ])
163+
164+ print (self .cat )
165+ self .classes = list (self .cat .keys ())
166+
167+ def __getitem__ (self , index ):
168+ fn = self .fns [index ]
169+ cls = self .cat [fn .split ('/' )[0 ]]
170+ with open (os .path .join (self .root , fn ), 'rb' ) as f :
171+ plydata = PlyData .read (f )
172+ pts = np .vstack ([plydata ['vertex' ]['x' ], plydata ['vertex' ]['y' ], plydata ['vertex' ]['z' ]]).T
173+ choice = np .random .choice (len (pts ), self .npoints , replace = True )
174+ point_set = pts [choice , :]
175+
176+ point_set = point_set - np .expand_dims (np .mean (point_set , axis = 0 ), 0 ) # center
177+ dist = np .max (np .sqrt (np .sum (point_set ** 2 , axis = 1 )), 0 )
178+ point_set = point_set / dist # scale
179+
180+ if self .data_augmentation :
181+ theta = np .random .uniform (0 , np .pi * 2 )
182+ rotation_matrix = np .array ([[np .cos (theta ), - np .sin (theta )], [np .sin (theta ), np .cos (theta )]])
183+ point_set [:, [0 , 2 ]] = point_set [:, [0 , 2 ]].dot (rotation_matrix ) # random rotation
184+ point_set += np .random .normal (0 , 0.02 , size = point_set .shape ) # random jitter
185+
186+ point_set = torch .from_numpy (point_set .astype (np .float32 ))
187+ cls = torch .from_numpy (np .array ([cls ]).astype (np .int64 ))
188+ return point_set , cls
189+
190+
191+ def __len__ (self ):
192+ return len (self .fns )
132193
133194if __name__ == '__main__' :
134- datapath = sys .argv [1 ]
135- print ('test' )
136- d = ShapeNetDataset (root = datapath , class_choice = ['Chair' ])
137- print (len (d ))
138- ps , seg = d [0 ]
139- print (ps .size (), ps .type (), seg .size (),seg .type ())
140-
141- d = ShapeNetDataset (root = datapath , classification = True )
142- print (len (d ))
143- ps , cls = d [0 ]
144- print (ps .size (), ps .type (), cls .size (),cls .type ())
145-
146- #get_segmentation_classes(datapath)
195+ dataset = sys .argv [1 ]
196+ datapath = sys .argv [2 ]
197+
198+ if dataset == 'shapenet' :
199+ d = ShapeNetDataset (root = datapath , class_choice = ['Chair' ])
200+ print (len (d ))
201+ ps , seg = d [0 ]
202+ print (ps .size (), ps .type (), seg .size (),seg .type ())
203+
204+ d = ShapeNetDataset (root = datapath , classification = True )
205+ print (len (d ))
206+ ps , cls = d [0 ]
207+ print (ps .size (), ps .type (), cls .size (),cls .type ())
208+ # get_segmentation_classes(datapath)
209+
210+ if dataset == 'modelnet' :
211+ gen_modelnet_id (datapath )
212+ d = ModelNetDataset (root = datapath )
213+ print (len (d ))
214+ print (d [0 ])
215+
0 commit comments