1919
2020
2121class STN3d (nn .Module ):
22- def __init__ (self , num_points = 2500 ):
22+ def __init__ (self ):
2323 super (STN3d , self ).__init__ ()
24- self .num_points = num_points
2524 self .conv1 = torch .nn .Conv1d (3 , 64 , 1 )
2625 self .conv2 = torch .nn .Conv1d (64 , 128 , 1 )
2726 self .conv3 = torch .nn .Conv1d (128 , 1024 , 1 )
28- self .mp1 = torch .nn .MaxPool1d (num_points )
2927 self .fc1 = nn .Linear (1024 , 512 )
3028 self .fc2 = nn .Linear (512 , 256 )
3129 self .fc3 = nn .Linear (256 , 9 )
@@ -43,7 +41,7 @@ def forward(self, x):
4341 x = F .relu (self .bn1 (self .conv1 (x )))
4442 x = F .relu (self .bn2 (self .conv2 (x )))
4543 x = F .relu (self .bn3 (self .conv3 (x )))
46- x = self . mp1 ( x )
44+ x = torch . max ( x , 2 , keepdim = True )[ 0 ]
4745 x = x .view (- 1 , 1024 )
4846
4947 x = F .relu (self .bn4 (self .fc1 (x )))
@@ -59,20 +57,19 @@ def forward(self, x):
5957
6058
6159class PointNetfeat (nn .Module ):
62- def __init__ (self , num_points = 2500 , global_feat = True ):
60+ def __init__ (self , global_feat = True ):
6361 super (PointNetfeat , self ).__init__ ()
64- self .stn = STN3d (num_points = num_points )
62+ self .stn = STN3d ()
6563 self .conv1 = torch .nn .Conv1d (3 , 64 , 1 )
6664 self .conv2 = torch .nn .Conv1d (64 , 128 , 1 )
6765 self .conv3 = torch .nn .Conv1d (128 , 1024 , 1 )
6866 self .bn1 = nn .BatchNorm1d (64 )
6967 self .bn2 = nn .BatchNorm1d (128 )
7068 self .bn3 = nn .BatchNorm1d (1024 )
71- self .mp1 = torch .nn .MaxPool1d (num_points )
72- self .num_points = num_points
7369 self .global_feat = global_feat
7470 def forward (self , x ):
7571 batchsize = x .size ()[0 ]
72+ n_pts = x .size ()[2 ]
7673 trans = self .stn (x )
7774 x = x .transpose (2 ,1 )
7875 x = torch .bmm (x , trans )
@@ -81,19 +78,18 @@ def forward(self, x):
8178 pointfeat = x
8279 x = F .relu (self .bn2 (self .conv2 (x )))
8380 x = self .bn3 (self .conv3 (x ))
84- x = self . mp1 ( x )
81+ x = torch . max ( x , 2 , keepdim = True )[ 0 ]
8582 x = x .view (- 1 , 1024 )
8683 if self .global_feat :
8784 return x , trans
8885 else :
89- x = x .view (- 1 , 1024 , 1 ).repeat (1 , 1 , self . num_points )
86+ x = x .view (- 1 , 1024 , 1 ).repeat (1 , 1 , n_pts )
9087 return torch .cat ([x , pointfeat ], 1 ), trans
9188
9289class PointNetCls (nn .Module ):
93- def __init__ (self , num_points = 2500 , k = 2 ):
90+ def __init__ (self , k = 2 ):
9491 super (PointNetCls , self ).__init__ ()
95- self .num_points = num_points
96- self .feat = PointNetfeat (num_points , global_feat = True )
92+ self .feat = PointNetfeat (global_feat = True )
9793 self .fc1 = nn .Linear (1024 , 512 )
9894 self .fc2 = nn .Linear (512 , 256 )
9995 self .fc3 = nn .Linear (256 , k )
@@ -105,14 +101,13 @@ def forward(self, x):
105101 x = F .relu (self .bn1 (self .fc1 (x )))
106102 x = F .relu (self .bn2 (self .fc2 (x )))
107103 x = self .fc3 (x )
108- return F .log_softmax (x , dim = - 1 ), trans
104+ return F .log_softmax (x , dim = 0 ), trans
109105
110106class PointNetDenseCls (nn .Module ):
111- def __init__ (self , num_points = 2500 , k = 2 ):
107+ def __init__ (self , k = 2 ):
112108 super (PointNetDenseCls , self ).__init__ ()
113- self .num_points = num_points
114109 self .k = k
115- self .feat = PointNetfeat (num_points , global_feat = False )
110+ self .feat = PointNetfeat (global_feat = False )
116111 self .conv1 = torch .nn .Conv1d (1088 , 512 , 1 )
117112 self .conv2 = torch .nn .Conv1d (512 , 256 , 1 )
118113 self .conv3 = torch .nn .Conv1d (256 , 128 , 1 )
@@ -123,14 +118,15 @@ def __init__(self, num_points = 2500, k = 2):
123118
124119 def forward (self , x ):
125120 batchsize = x .size ()[0 ]
121+ n_pts = x .size ()[2 ]
126122 x , trans = self .feat (x )
127123 x = F .relu (self .bn1 (self .conv1 (x )))
128124 x = F .relu (self .bn2 (self .conv2 (x )))
129125 x = F .relu (self .bn3 (self .conv3 (x )))
130126 x = self .conv4 (x )
131127 x = x .transpose (2 ,1 ).contiguous ()
132128 x = F .log_softmax (x .view (- 1 ,self .k ), dim = - 1 )
133- x = x .view (batchsize , self . num_points , self .k )
129+ x = x .view (batchsize , n_pts , self .k )
134130 return x , trans
135131
136132
0 commit comments