diff --git a/datasets.py b/datasets.py index 52910f452..d96d6753a 100644 --- a/datasets.py +++ b/datasets.py @@ -8,7 +8,6 @@ import json import codecs import numpy as np -import progressbar import sys import torchvision.transforms as transforms import argparse diff --git a/pointnet.py b/pointnet.py index 2bb766723..37091cd84 100644 --- a/pointnet.py +++ b/pointnet.py @@ -19,13 +19,11 @@ class STN3d(nn.Module): - def __init__(self, num_points = 2500): + def __init__(self): super(STN3d, self).__init__() - self.num_points = num_points self.conv1 = torch.nn.Conv1d(3, 64, 1) self.conv2 = torch.nn.Conv1d(64, 128, 1) self.conv3 = torch.nn.Conv1d(128, 1024, 1) - self.mp1 = torch.nn.MaxPool1d(num_points) self.fc1 = nn.Linear(1024, 512) self.fc2 = nn.Linear(512, 256) self.fc3 = nn.Linear(256, 9) @@ -43,7 +41,7 @@ def forward(self, x): x = F.relu(self.bn1(self.conv1(x))) x = F.relu(self.bn2(self.conv2(x))) x = F.relu(self.bn3(self.conv3(x))) - x = self.mp1(x) + x = torch.max(x, 2, keepdim=True)[0] x = x.view(-1, 1024) x = F.relu(self.bn4(self.fc1(x))) @@ -59,20 +57,19 @@ def forward(self, x): class PointNetfeat(nn.Module): - def __init__(self, num_points = 2500, global_feat = True): + def __init__(self, global_feat = True): super(PointNetfeat, self).__init__() - self.stn = STN3d(num_points = num_points) + self.stn = STN3d() self.conv1 = torch.nn.Conv1d(3, 64, 1) self.conv2 = torch.nn.Conv1d(64, 128, 1) self.conv3 = torch.nn.Conv1d(128, 1024, 1) self.bn1 = nn.BatchNorm1d(64) self.bn2 = nn.BatchNorm1d(128) self.bn3 = nn.BatchNorm1d(1024) - self.mp1 = torch.nn.MaxPool1d(num_points) - self.num_points = num_points self.global_feat = global_feat def forward(self, x): batchsize = x.size()[0] + n_pts = x.size()[2] trans = self.stn(x) x = x.transpose(2,1) x = torch.bmm(x, trans) @@ -81,19 +78,18 @@ def forward(self, x): pointfeat = x x = F.relu(self.bn2(self.conv2(x))) x = self.bn3(self.conv3(x)) - x = self.mp1(x) + x = torch.max(x, 2, keepdim=True)[0] x = x.view(-1, 1024) if self.global_feat: return x, trans else: - x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points) + x = x.view(-1, 1024, 1).repeat(1, 1, n_pts) return torch.cat([x, pointfeat], 1), trans class PointNetCls(nn.Module): - def __init__(self, num_points = 2500, k = 2): + def __init__(self, k = 2): super(PointNetCls, self).__init__() - self.num_points = num_points - self.feat = PointNetfeat(num_points, global_feat=True) + self.feat = PointNetfeat(global_feat=True) self.fc1 = nn.Linear(1024, 512) self.fc2 = nn.Linear(512, 256) self.fc3 = nn.Linear(256, k) @@ -105,14 +101,13 @@ def forward(self, x): x = F.relu(self.bn1(self.fc1(x))) x = F.relu(self.bn2(self.fc2(x))) x = self.fc3(x) - return F.log_softmax(x, dim=-1), trans + return F.log_softmax(x, dim=0), trans class PointNetDenseCls(nn.Module): - def __init__(self, num_points = 2500, k = 2): + def __init__(self, k = 2): super(PointNetDenseCls, self).__init__() - self.num_points = num_points self.k = k - self.feat = PointNetfeat(num_points, global_feat=False) + self.feat = PointNetfeat(global_feat=False) self.conv1 = torch.nn.Conv1d(1088, 512, 1) self.conv2 = torch.nn.Conv1d(512, 256, 1) self.conv3 = torch.nn.Conv1d(256, 128, 1) @@ -123,6 +118,7 @@ def __init__(self, num_points = 2500, k = 2): def forward(self, x): batchsize = x.size()[0] + n_pts = x.size()[2] x, trans = self.feat(x) x = F.relu(self.bn1(self.conv1(x))) x = F.relu(self.bn2(self.conv2(x))) @@ -130,7 +126,7 @@ def forward(self, x): x = self.conv4(x) x = x.transpose(2,1).contiguous() x = F.log_softmax(x.view(-1,self.k), dim=-1) - x = x.view(batchsize, self.num_points, self.k) + x = x.view(batchsize, n_pts, self.k) return x, trans