pointnet.pytorch代码解析

pointnet.pytorch代码解析

代码运行

Training

cd utils
python train_classification.py --dataset <dataset path> --nepoch=<number epochs> --dataset_type <modelnet40 | shapenet>
python train_segmentation.py --dataset <dataset path> --nepoch=<number epochs> 

运行结果

  1. Classification on ShapeNet

    epoch = 10 Overall Acc
    Original implementation N/A
    this implementation(无 feature transform) 95.6
    this implementation(有 feature transform) 92.97
  2. Segmentation on ShapeNet

dataset代码

  1. 读取的数据格式

    ShapeNetDataset():默认读取分割数据,返回值d:点云个数*(点云数据ps,标签seg)

    数据ps:torch.Size([2500, 3]) torch.FloatTensor ,一个点云有2500个点,每个点3个特征

    标签seg:torch.Size([2500]) torch.LongTensor,每个点都有一个标签

    代码及注释如下:

    if __name__ == '__main__':
    dataset = sys.argv[1]           # 运行命令中传入的第一个参数
    datapath = sys.argv[2]          # 运行命令中传入的第二个参数
    
    if dataset == 'shapenet':
        # 读取标签为Chair的分割数据
        d = ShapeNetDataset(root = datapath, class_choice = ['Chair'])
        print(len(d))   #2658,共有2658个Chair点云
        ps, seg = d[0]
        print(ps.size(), ps.type(), seg.size(),seg.type())
        # torch.Size([2500, 3]) torch.FloatTensor ,第一个点云有2500个点,每个点3个特征
        # torch.Size([2500]) torch.LongTensor,每个点都有一个标签
    
        d = ShapeNetDataset(root = datapath, classification = True)
        print(len(d))
        ps, cls = d[0]
        print(ps.size(), ps.type(), cls.size(),cls.type())
        # torch.Size([2500, 3]) torch.FloatTensor torch.Size([1]) torch.LongTensor,每个点云一个标签
        # get_segmentation_classes(datapath)
    
  2. 数据读取

model代码

  1. 网络整体结构

    if __name__ == '__main__':
        # input transform
        sim_data = Variable(torch.rand(32,3,2500))          # 32个点云,3个特征,2500个点
        trans = STN3d()
        out = trans(sim_data)                               # stn torch.Size([32, 3, 3]),返回3x3的输入变换矩阵
        print('stn', out.size())
        print('loss', feature_transform_regularizer(out))
    
        # feature transform
        sim_data_64d = Variable(torch.rand(32, 64, 2500))
        trans = STNkd(k=64)
        out = trans(sim_data_64d)                           # stn64d torch.Size([32, 64, 64]),返回64x64的特征变换矩阵
        print('stn64d', out.size())
        print('loss', feature_transform_regularizer(out))
    
        # global feat
        pointfeat = PointNetfeat(global_feat=True)
        out, _, _ = pointfeat(sim_data)                     # global feat torch.Size([32, 1024]),32个点云,每个有1024维全局特征
        print('global feat', out.size())
    
        # point feat
        pointfeat = PointNetfeat(global_feat=False)
        out, _, _ = pointfeat(sim_data)                     # point feat torch.Size([32, 1088, 2500]),2500个点,每个点有1024+64维特征
        print('point feat', out.size())
    
        # Classification
        cls = PointNetCls(k = 5)
        out, _, _ = cls(sim_data)                           # class torch.Size([32, 5]),global feat经过全连接层,得到在5个类别上的概率信息
        print('class', out.size())
    
        # Segmentation 
        seg = PointNetDenseCls(k = 3)
        out, _, _ = seg(sim_data)                           # seg torch.Size([32, 2500, 3]),point feat经过一维卷积,得到在3个类别上概率信息
        print('seg', out.size())
    
  2. PointNetfeat特征提取网络

    class PointNetfeat(nn.Module):
        '''
        点云的特征提取网络:global feature 和 point features
        '''
        def __init__(self, global_feat = True, feature_transform = False):
            super(PointNetfeat, self).__init__()
            self.stn = STN3d()
            self.conv1 = torch.nn.Conv1d(3, 64, 1)
            self.conv2 = torch.nn.Conv1d(64, 128, 1)
            self.conv3 = torch.nn.Conv1d(128, 1024, 1)
            self.bn1 = nn.BatchNorm1d(64)
            self.bn2 = nn.BatchNorm1d(128)
            self.bn3 = nn.BatchNorm1d(1024)
            self.global_feat = global_feat
            self.feature_transform = feature_transform
            if self.feature_transform:
                self.fstn = STNkd(k=64)
    
        def forward(self, x):
            n_pts = x.size()[2]
            trans = self.stn(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans)                 # 乘以3x3变换矩阵
            x = x.transpose(2, 1)
            x = F.relu(self.bn1(self.conv1(x)))
    
            if self.feature_transform:              # 特征变换,64x64矩阵
                trans_feat = self.fstn(x)
                x = x.transpose(2,1)
                x = torch.bmm(x, trans_feat)
                x = x.transpose(2,1)
            else:
                trans_feat = None
    
            pointfeat = x                           # nx64的点特征
            x = F.relu(self.bn2(self.conv2(x)))
            x = self.bn3(self.conv3(x))
            x = torch.max(x, 2, keepdim=True)[0]    # Maxpool
            x = x.view(-1, 1024)
            if self.global_feat:
                return x, trans, trans_feat         # x:mx1x1024的global feature,两个变换矩阵
            else:
                x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
                return torch.cat([x, pointfeat], 1), trans, trans_feat      # global feature+point features = nx1088的点特征矩阵    
    
posted @ 2021-01-04 17:47  Dawn嗯  阅读(681)  评论(0)    收藏  举报