PointNet++ Explained, with Code
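An earlier article, "PointNet: A Deep Learning Model for 3D Point Set Classification and Segmentation", analyzed how PointNet classifies and segments 3D point clouds. One shortcoming of PointNet is that it cannot capture local features, which makes complex scenes hard for it to analyze. PointNet++ improves on this in two main ways so that the network extracts local features better. First, it uses metric space distances to define local regions of the point set and applies PointNet to them recursively, so the network learns features at progressively larger local scales. Second, point sets are often sampled non-uniformly, and assuming uniform density degrades performance, so the authors propose a density-adaptive feature extraction method. Together, these two changes let the network learn features more efficiently and more robustly.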
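(Added 2021-01-27) The PointNet author gave a talk in 2021, "3D Object Detection: Development and Future". Readers interested in 3D object detection may want to watch it.
Video: "In the PointNet author's own words: 3D object detection algorithms and future directions in 90 minutes"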
Addendum: the PointNet++ author also gave a talk, "Deep Learning on Point Clouds and Its Application to 3D Scene Understanding", which covers PointNet++ in detail (bilibili video: 將門創投 | 斯坦福大學在讀博士生祁芮中臺:點云上的深度學習及其在三維場景理解中的應用).
Contents
1. Limitations of PointNet
2. PointNet++ network structure
2.1 Sample layer
2.2 Grouping layer
2.3 PointNet layer
2.4 Handling non-uniform point density
2.5 Point Feature Propagation for Set Segmentation
2.6 Classification
2.7 Part Segmentation
2.8 Scene Segmentation
3. References
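1. Limitations of PointNet
In convolutional networks, a 3D CNN works much like a 2D CNN: it extracts features hierarchically over multiple levels, and it inherits the translation invariance of convolution.
PointNet, by contrast, maps each point from a low-dimensional space to a high-dimensional one to learn per-point features, then max-pools the high-dimensional features of all points into a single global feature. In essence it operates either on one point or on all points, so there is no notion of local context. The lack of local context also limits translation invariance (world coordinates versus local coordinates): translating a point cloud changes every coordinate, so every per-point feature and the global feature change as well. For a single object this is manageable, because the object can be translated to the origin and scaled to fit inside a unit sphere. In a scene with several objects there is no obvious choice: which object should be normalized?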
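The translation issue can be illustrated with a few lines of plain Python. This is only a toy sketch: the three points and the helper names `translate` and `center` are made up for illustration. It shows that absolute coordinates change under a shift, while coordinates taken relative to a reference point (here the centroid) do not.

```python
def translate(points, offset):
    """Shift every point by the same offset vector."""
    return [tuple(p + o for p, o in zip(pt, offset)) for pt in points]


def center(points):
    """Express points relative to their centroid."""
    n = len(points)
    dims = len(points[0])
    centroid = tuple(sum(pt[i] for pt in points) / n for i in range(dims))
    return [tuple(p - c for p, c in zip(pt, centroid)) for pt in points]


cloud = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
shifted = translate(cloud, (10.0, -5.0, 3.0))

# Absolute coordinates differ after the shift ...
print(cloud == shifted)  # False

# ... but centroid-relative coordinates agree up to rounding error.
max_diff = max(
    abs(a - b)
    for p, q in zip(center(cloud), center(shifted))
    for a, b in zip(p, q)
)
print(max_diff < 1e-9)  # True
```

A network fed absolute coordinates sees two different inputs here, whereas a network fed relative coordinates sees the same one.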
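In PointNet++, the authors use the distance metric of the underlying space to partition the point set into overlapping local regions. On top of this partition, the network first extracts local (shallow) features from the geometry of small neighborhoods, then enlarges the neighborhoods and builds higher-level features from those local features, until it obtains a global feature for the whole point set. This mirrors feature extraction in a CNN: low-level features come first, and as the receptive field grows the features become higher-level.
PointNet++ has to solve two key problems: first, how to partition the point set into regions; second, how to extract local features from each region with a feature extractor. The two are coupled, because applying one feature extractor to different regions requires every region to have the same structure. The CNN analogy helps again: in a CNN the convolution kernel is the basic feature extractor, and it always acts on an (n, n) pixel patch. A 3D point set likewise needs sub-regions of identical structure, plus a matching region feature extractor.
The authors use PointNet as the feature extractor, which leaves the question of how to partition the point set into regions of identical structure. They define each region as a neighborhood ball, determined by a centroid and a radius. The centroids are chosen with the farthest point sampling (FPS) algorithm.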
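2. PointNet++ network structure
PointNet++ extends PointNet with a hierarchical structure, so that the network produces higher-level features over progressively larger regions.
Each set abstraction level of the network consists of three parts: a Sampling layer, a Grouping layer and a PointNet layer.
· Sample layer: samples the input points, selecting a number of them as centroids;
· Grouping layer: uses the centroids from the previous step to divide the point set into regions;
· PointNet layer: encodes each of these regions into a feature vector.
The input to each set abstraction level is an $N \times (d + C)$ matrix, where N is the number of input points, d is the coordinate dimension and C is the feature dimension. The output is an $N' \times (d + C')$ matrix, where N' is the number of output points, d is the unchanged coordinate dimension and C' is the new feature dimension. The role and implementation of each layer are described below.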
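As a quick way to keep track of these sizes, the following sketch computes the shapes that one set abstraction level passes along. The function name `set_abstraction_shapes` is hypothetical, and the input size of 1024 points is an assumption chosen for illustration; the values 512, 64 and [64, 64, 128] are borrowed from the first layer of the part segmentation network in section 2.7.

```python
def set_abstraction_shapes(n_in, d, c_in, n_out, k, mlp):
    """Return (input, grouped, output) shapes for one set abstraction level.

    n_in:  number of input points N
    d:     coordinate dimension
    c_in:  input feature dimension C
    n_out: number of sampled centroids N'
    k:     number of points per local region K
    mlp:   output widths of the PointNet layers; the last one is C'
    """
    input_shape = (n_in, d + c_in)
    grouped_shape = (n_out, k, d + c_in)  # after sampling and grouping
    output_shape = (n_out, d + mlp[-1])   # after PointNet and max pooling
    return input_shape, grouped_shape, output_shape


shapes = set_abstraction_shapes(n_in=1024, d=3, c_in=0, n_out=512, k=64, mlp=[64, 64, 128])
print(shapes)  # ((1024, 3), (512, 64, 3), (512, 131))
```

The grouped shape is an intermediate result: the PointNet layer collapses the K axis by max pooling, which is why it disappears in the output.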
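2.1 Sample layer
Farthest point sampling (FPS) is used to select N' points. As for why this method is used, the paper notes that, compared with random sampling, it covers the whole point set better. How many centroids to select is a hyperparameter that depends on the scale of the data.
The FPS algorithm works as follows: pick an initial point (at random) as the first sample; for every point, keep track of its distance to the nearest sample chosen so far; choose as the next sample the point whose tracked distance is largest; repeat until N' points have been selected.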
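To make the greedy selection concrete, here is a pure-Python sketch of FPS on five 2D points. It is a simplified illustration, not the batched implementation used in the article: the function name `farthest_point_sampling` and the fixed `start` index are assumptions (the PyTorch version picks the first point at random).

```python
def farthest_point_sampling(points, k, start=0):
    """Return the indices of k points chosen greedily by farthest point sampling."""
    # Squared distance from each point to its nearest already-selected point.
    min_dist = [float("inf")] * len(points)
    selected = []
    current = start
    for _ in range(k):
        selected.append(current)
        cx = points[current]
        for i, p in enumerate(points):
            d = sum((a - b) ** 2 for a, b in zip(p, cx))
            if d < min_dist[i]:
                min_dist[i] = d
        # The next sample is the point farthest from the selected set.
        current = max(range(len(points)), key=min_dist.__getitem__)
    return selected


points = [(0.0, 0.0), (1.0, 0.0), (10.0, 0.0), (10.0, 1.0), (5.0, 5.0)]
print(farthest_point_sampling(points, 3))  # [0, 3, 4]
```

Starting from point 0, the farthest point is (10, 1); after that, (5, 5) is farther from both chosen points than either of their close neighbors, so the three samples spread over the whole set instead of clustering.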
The Python implementation is:

import torch


def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)  # indices of the sampled points, (B, npoint)
    distance = torch.ones(B, N).to(device) * 1e10  # squared distance from each point to its nearest sampled point, (B, N)
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)  # current farthest point; the first one is picked at random
    batch_indices = torch.arange(B, dtype=torch.long).to(device)  # batch index array
    for i in range(npoint):
        centroids[:, i] = farthest  # record the i-th sampled point
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)  # xyz coordinates of that point
        dist = torch.sum((xyz - centroid) ** 2, -1)  # squared Euclidean distance from every point to it
        mask = dist < distance
        distance[mask] = dist[mask]  # keep, for each point, the minimum distance to all points sampled so far
        farthest = torch.max(distance, -1)[1]  # index of the point farthest from the sampled set
    return centroids

2.2 Grouping layer
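This layer uses ball query to build N' local regions, one around each point chosen by the sample layer; the grouped output has size $N' \times K \times (d + C)$. According to the paper, two hyperparameters are involved: the number of points K in each region and the query radius r. The radius dominates: points are searched within a ball of radius r, and K is the upper limit on how many of them are kept.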
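The following pure-Python sketch shows ball query for a single centroid. The function name `ball_query` and the sample points are made up for illustration. It follows the same convention as the batched PyTorch code in this section: keep at most K in-ball indices in index order, and pad a short group by repeating the first in-ball index so every region has exactly K entries.

```python
def ball_query(points, center, radius, k):
    """Return exactly k indices of points within `radius` of `center`.

    If fewer than k points fall inside the ball, the first in-ball index
    is repeated as padding. Returns [] if the ball is empty, which cannot
    happen when the center is itself one of the points.
    """
    r2 = radius ** 2
    inside = [
        i for i, p in enumerate(points)
        if sum((a - b) ** 2 for a, b in zip(p, center)) <= r2
    ]
    inside = inside[:k]
    if not inside:
        return []
    return inside + [inside[0]] * (k - len(inside))


points = [(0.0, 0.0), (0.1, 0.0), (0.5, 0.0), (0.15, 0.1), (2.0, 2.0)]
print(ball_query(points, center=(0.0, 0.0), radius=0.2, k=4))  # [0, 1, 3, 0]
print(ball_query(points, center=(0.0, 0.0), radius=0.2, k=2))  # [0, 1]
```

Padding with a repeated index keeps the region tensor rectangular; because the PointNet layer ends in max pooling, duplicated points do not change the region feature.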
The code is:

def square_distance(src, dst):
    """
    Calculate squared Euclidean distance between each two points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N  # mark points outside the ball with the sentinel value N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]  # in-ball indices sort first; keep at most nsample
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]  # pad short groups with the first in-ball index
    return group_idx

2.3 PointNet layer
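This layer is a PointNet. Its input is the set of N' local regions, of size $N' \times K \times (d + C)$, and its output has size $N' \times (d + C')$. Note that before a region is fed to the network, the coordinates of its points are converted to coordinates relative to the region's centroid. The authors point out that using relative coordinates captures point-to-point relations within the local region. This completes the set abstraction level.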
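The effect of relative coordinates and max pooling can be seen in a toy example. The sketch below is not a trained PointNet: `toy_pointnet` uses hand-written per-point features as a stand-in for the learned shared MLP, and all names and values are assumptions chosen for illustration. It shows that the region feature is unchanged when the region is moved elsewhere in the scene or when its points are reordered.

```python
def to_local(region, centroid):
    """Convert region points to coordinates relative to the centroid."""
    return [tuple(a - c for a, c in zip(p, centroid)) for p in region]


def toy_pointnet(region):
    """Per-point features followed by a channel-wise max pool.

    The per-point features are hand-written, standing in for a learned MLP.
    """
    per_point = [(abs(x), abs(y), x * x + y * y) for x, y in region]
    return tuple(max(channel) for channel in zip(*per_point))


region = [(1.0, 1.0), (1.5, 1.0), (1.0, 2.0)]
centroid = (1.0, 1.0)
feature = toy_pointnet(to_local(region, centroid))
print(feature)  # (0.5, 1.0, 1.0)

# The same region placed elsewhere in the scene gives the same feature.
moved = [(9.0, -3.0), (9.5, -3.0), (9.0, -2.0)]
print(toy_pointnet(to_local(moved, (9.0, -3.0))) == feature)  # True

# Reordering the points does not matter either, because max is symmetric.
print(toy_pointnet(to_local(region[::-1], centroid)) == feature)  # True
```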
The set abstraction implementation is shown next. The helpers sample_and_group, sample_and_group_all and index_points come from the repository listed in the references and are not reproduced in this article.

import torch.nn as nn
import torch.nn.functional as F


class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))

        new_points = torch.max(new_points, 2)[0]  # max pool over the points of each region
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points

2.4 Handling non-uniform point density
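When point density is non-uniform, generating every sub-region with the same radius r leaves some regions with too few sampled points.
The authors identify this as a problem that must be solved and propose two methods: multi-scale grouping (MSG) and multi-resolution grouping (MRG). The paper illustrates both in a schematic figure.
The two methods are described in turn below.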
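The first is multi-scale grouping (MSG). For a given centroid, if three scales are used, three regions are drawn around that centroid, each with its own radius and its own number of points. The regions at different scales are fed to different PointNets for feature extraction, and the results are concatenated to form the feature of that centroid. In other words, MSG amounts to running several hierarchical structures in parallel: the centroids are shared, but the region extents differ. Each branch's PointNet has its own input and output sizes, and the outputs of the branches are concatenated after the PointNets.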
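Because the per-scale features are concatenated, the output width of an MSG layer is the sum of the last MLP widths of its branches. The sketch below checks this against the classification network in section 2.6, where the second and third layers are declared with input widths 320 and 640 + 3. The helper name `msg_out_channels` is hypothetical.

```python
def msg_out_channels(mlp_list):
    """Output feature width of an MSG layer: sum of the last width of each branch."""
    return sum(mlp[-1] for mlp in mlp_list)


# Branch MLPs of sa1 and sa2 in the classification network of section 2.6.
sa1_mlps = [[32, 32, 64], [64, 64, 128], [64, 96, 128]]
sa2_mlps = [[64, 64, 128], [128, 128, 256], [128, 128, 256]]

print(msg_out_channels(sa1_mlps))  # 320
print(msg_out_channels(sa2_mlps))  # 640
```

The extra 3 in the declared input width 640 + 3 accounts for the xyz coordinates that are concatenated to the features before the final, group-all set abstraction layer.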
The MSG implementation is:

class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3  # features plus relative xyz
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
        new_points_list = []
        for i, radius in enumerate(self.radius_list):  # one branch per scale
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz.view(B, S, 1, C)  # coordinates relative to the centroid
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points = F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
            new_points_list.append(new_points)

        new_xyz = new_xyz.permute(0, 2, 1)
        new_points_concat = torch.cat(new_points_list, dim=1)  # concatenate the scales
        return new_xyz, new_points_concat

The second is multi-resolution grouping (MRG). MSG is computationally expensive, because it runs a PointNet over large neighborhoods for every centroid, so the authors propose MRG as a cheaper alternative. In MRG the feature of a region at a given level is the concatenation of two vectors: one summarizes the features of its sub-regions from the level below, and the other is obtained by running a single PointNet directly on all raw points in the region. The idea is somewhat similar to the skip connections in ResNet.
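In this part the authors also describe random input dropout (DP): during training, points are randomly dropped from each point set before it is fed to the network. For each training point set a dropout ratio θ is sampled uniformly from [0, p], and each point is then dropped independently with probability θ. The paper sets p = 0.95, so 95% is the upper bound on the dropout ratio, not a fixed sampling rate.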
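A minimal sketch of random input dropout in plain Python follows. The function name `random_input_dropout` is hypothetical, and one detail is an assumption rather than something stated in the article: instead of deleting dropped points, the sketch overwrites them with the first point so the point count stays fixed, which is a common way to keep batch shapes constant.

```python
import random


def random_input_dropout(points, max_ratio=0.95, rng=None):
    """Randomly drop points, with a dropout ratio sampled from [0, max_ratio].

    Dropped points are replaced by a copy of the first point (an assumption
    made here so the output keeps the same length as the input).
    """
    if rng is None:
        rng = random.Random()
    theta = rng.uniform(0.0, max_ratio)  # dropout ratio for this point set
    return [points[0] if rng.random() < theta else p for p in points]


cloud = [(float(i), 0.0, 0.0) for i in range(100)]
augmented = random_input_dropout(cloud, rng=random.Random(0))

print(len(augmented) == len(cloud))      # True: the size is unchanged
print(set(augmented) <= set(cloud))      # True: no new points are invented
print(random_input_dropout(cloud, max_ratio=0.0) == cloud)  # True: ratio 0 drops nothing
```

Because θ is resampled for every point set, the network sees inputs ranging from nearly complete to very sparse, which is what makes it robust to varying density.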
2.5 Point Feature Propagation for Set Segmentation
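For point cloud segmentation, the subsampled point set must also be upsampled back to the original number of points. This is done with a hierarchical interpolation scheme.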
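The interpolation is an inverse-distance-weighted average over the k nearest known points: $f(x) = \frac{\sum_{i=1}^{k} w_i(x) f_i}{\sum_{i=1}^{k} w_i(x)}$ with $w_i(x) = 1 / d(x, x_i)^p$. The feature propagation code in this section uses k = 3 and weights of one over the squared distance, that is p = 2. The pure-Python sketch below does the same for a single query point; the function name `interpolate_feature` and the sample data are made up for illustration.

```python
def interpolate_feature(query, known_xyz, known_feats, k=3, eps=1e-10):
    """Inverse squared-distance interpolation from the k nearest known points."""
    sq = []
    for i, p in enumerate(known_xyz):
        d2 = sum((a - b) ** 2 for a, b in zip(query, p))
        sq.append((max(d2, eps), i))  # clamp to avoid division by zero
    nearest = sorted(sq)[:k]
    weights = [1.0 / d2 for d2, _ in nearest]
    total = sum(weights)
    channels = len(known_feats[0])
    return [
        sum(w * known_feats[i][c] for w, (_, i) in zip(weights, nearest)) / total
        for c in range(channels)
    ]


known_xyz = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (0.0, 2.0, 0.0), (10.0, 10.0, 10.0)]
known_feats = [[1.0], [3.0], [5.0], [100.0]]

# Squared distances to the three nearest points are 1, 1 and 5, so the weights
# are 1, 1 and 0.2 and the result is (1*1 + 1*3 + 0.2*5) / 2.2.
value = interpolate_feature((1.0, 0.0, 0.0), known_xyz, known_feats)
print(abs(value[0] - 5.0 / 2.2) < 1e-9)  # True
```

A query that coincides with a known point gets a very large weight for that point, so its interpolated feature is essentially that point's feature.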
The feature propagation implementation is:

class PointNetFeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            # A single coarse point: broadcast its feature to all N points.
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # three nearest neighbors, [B, N, 3]
            dists[dists < 1e-10] = 1e-10  # avoid division by zero
            weight = 1.0 / dists  # inverse squared-distance weights, [B, N, 3]
            weight = weight / torch.sum(weight, dim=-1).view(B, N, 1)  # normalized, [B, N, 3]
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)

        if points1 is not None:
            # Skip connection: concatenate features from the matching encoder level.
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points

2.6 Classification
class PointNet2ClsMsg(nn.Module):
    def __init__(self):
        super(PointNet2ClsMsg, self).__init__()
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], 0, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320, [[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True)
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, 40)

    def forward(self, xyz):
        B, _, _ = xyz.shape
        l1_xyz, l1_points = self.sa1(xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        x = l3_points.view(B, 1024)
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        x = F.log_softmax(x, -1)
        return x

2.7 Part Segmentation
class PointNet2PartSeg(nn.Module):
    def __init__(self, num_classes):
        super(PointNet2PartSeg, self).__init__()
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=64, in_channel=3, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
        self.fp3 = PointNetFeaturePropagation(in_channel=1280, mlp=[256, 256])
        self.fp2 = PointNetFeaturePropagation(in_channel=384, mlp=[256, 128])
        self.fp1 = PointNetFeaturePropagation(in_channel=128, mlp=[128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz):
        # Set Abstraction layers
        l1_xyz, l1_points = self.sa1(xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        # Feature Propagation layers
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(xyz, l1_xyz, None, l1_points)
        # FC layers
        feat = F.relu(self.bn1(self.conv1(l0_points)))
        x = self.drop1(feat)
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        return x, feat

2.8 Scene Segmentation
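The in_channel values passed to the feature propagation layers in the segmentation networks can be sanity-checked with simple arithmetic: each layer receives the skip features of its encoder level concatenated with the features interpolated from the coarser level. The sketch below reproduces the numbers used in the part segmentation and scene segmentation networks of sections 2.7 and 2.8. The helper name `decoder_in_channels` is hypothetical.

```python
def decoder_in_channels(encoder_out, fp_mlps):
    """Input widths of the feature propagation layers, coarsest level first.

    encoder_out: feature widths [l0, l1, ..., lL] of the encoder levels,
                 with l0 = 0 when the raw points carry no skip features
    fp_mlps:     MLP widths of the feature propagation layers, coarsest first
    """
    widths = []
    coarse = encoder_out[-1]
    for skip, mlp in zip(reversed(encoder_out[:-1]), fp_mlps):
        widths.append(skip + coarse)  # skip features + interpolated features
        coarse = mlp[-1]              # output width feeds the next finer level
    return widths


# Scene segmentation: sa1..sa4 output 64, 128, 256, 512 channels.
sem_seg = decoder_in_channels(
    [0, 64, 128, 256, 512],
    [[256, 256], [256, 256], [256, 128], [128, 128, 128]],
)
print(sem_seg)  # [768, 384, 320, 128]

# Part segmentation: sa1..sa3 output 128, 256, 1024 channels.
part_seg = decoder_in_channels(
    [0, 128, 256, 1024],
    [[256, 256], [256, 128], [128, 128, 128]],
)
print(part_seg)  # [1280, 384, 128]
```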
class PointNet2SemSeg(nn.Module):
    def __init__(self, num_classes):
        super(PointNet2SemSeg, self).__init__()
        self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 3, [32, 32, 64], False)
        self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False)
        self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False)
        self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False)
        self.fp4 = PointNetFeaturePropagation(768, [256, 256])
        self.fp3 = PointNetFeaturePropagation(384, [256, 256])
        self.fp2 = PointNetFeaturePropagation(320, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz):
        l1_xyz, l1_points = self.sa1(xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        l4_xyz, l4_points = self.sa4(l3_xyz, l3_points)

        l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(xyz, l1_xyz, None, l1_points)

        x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        return x

3. References
PointNet++ author's talk: 將門創投 | 斯坦福大學在讀博士生祁芮中臺:點云上的深度學習及其在三維場景理解中的應用 (bilibili)
PointNet++ project page: PointNet++
PointNet++ code: https://github.com/yanx27/Pointnet_Pointnet2_pytorch
Reading notes on the PyTorch implementation of PointNet++: PointNet++的pytorch實現代碼閱讀
Text version of the PointNet++ author's video talk: PointNet++作者的視頻講解文字版, 一杯明月, 博客園