PointNet:3D点集分类与分割深度学习模型
之前的一篇博客《動手學無人駕駛(4):基于激光雷達點云數據3D目標檢測》里介紹到了如何基于PointRCNN模型來進行3D目標檢測,作者使用的主干網是PointNet++,而PointNet++又是基于PointNet來實現的。今天寫的這篇博客就是對PointNet網絡進行詳細介紹。
(2021-1-27日補充):這是PointNet作者2021年分享的報告《3D物體檢測發展與未來》,對3D物體檢測感興趣的朋友可以看看。
【PointNet作者親述】90分鐘帶你了解3D物體檢測算法和未來方向!
補充:下面的視頻是PointNet作者分享的報告《點云上的深度學習及其在三維場景理解中的應用》,里面有詳細介紹PointNet(https://www.bilibili.com/video/BV1As411377S/?spm_id_from=333.788.videocard.1)。
將門創投 | 斯坦福大學在讀博士生祁芮中臺:點云上的深度學習及其在三維場景理解中的應用
1.PointNet論文解讀
前言
PointNet網絡
2.PointNet源碼
參考資料
1.PointNet論文解讀
前言
隨著大數據和深度學習的興起,涌現了許許多多的3D應用,與此同時需要一種數據驅動的方式去理解和處理三維數據,這就是:3D deep learning。
三維數據本身有一定的復雜性,2D圖像可以輕易的表示成矩陣,3D表達形式主要分為以下幾種:
- point cloud :深度傳感器掃描得到的深度數據,點云。
- Mesh:三角面片在計算機圖形學中渲染和建模話會很有用。
- Volumetric:將空間劃分成三維網格,柵格化。
- Multi-View:用多個角度的圖片表示物體。
Point cloud 是一種非常適合于3D場景理解的數據,原因是:
- 點云是非常接近原始傳感器的數據集,激光雷達掃描之后就是點云,原始的數據可以做端到端的深度學習。
- 點云在表達形式上是比較簡單的,一組點。相比較來說Mesh需要選擇面片類型和如何連接;網格需要選擇多大的網格,分辨?? 率;圖像的選擇,需要選擇拍攝的角度,但是表達是不全面的。
最近才有一些方法研究直接在點云上進行特征學習,之前的大部分工作都是集中在手工設計點云數據的。這些特征都是針對特定任務,有不同的假設,新的任務很難優化特征。
但是點云數據是一種不規則的數據,在空間上和數量上可以任意分布,之前的研究者在點云上會先把它轉化成一個規則的數據,比如柵格讓其均勻分布,然后再用3D CNN來處理柵格數據。3D CNN復雜度相當的高,是三次方的增長,所以分辨率不高,相比圖像是很低的。
但是如果考慮不計復雜度的柵格,會導致大量的柵格都是空白,智能掃描到表面,內部都是空白的。所以柵格并不是對3D點云很好的一種表達方式,也有人考慮過,用3D點云數據投影到2D平面上用2D cnn 進行訓練,這樣會損失3D的信息。 還要決定的投影的角度。點云中提取手工的特征,再接FC,這么做有很大的局限性
我們能否直接用一種在點云上學習的方法?
PointNet網絡
我們的目標是提出一種端到端的點云多任務處理框架,包括目標分類,目標零件分類以及場景語義解析。
點云輸入數據處理
點云是數據的表達點的集合,網絡模型應對點云的排列方式不敏感,如下圖所示,對于N個具有D維特征的點云數據,排列方式可能有N!種,我們希望我們的網絡模型能夠對于N!排列方式點云數據能夠保持同樣的學習效果。
神將網絡本質上是一個函數,我們希望找到一個對稱函數,能夠對于點云數據具有置換不變性。如取最大值函數,無論輸入怎么變換,最后的結果都是輸入的最大值。
雖然是置換不變的,但是這種方式只計算了最遠點的邊界,損失了很多有意義的幾何信息,如何解決呢?與其說直接做對稱性可以先把每個點映射到高維空間,在高維空間中做對稱性的操作,高維空間可以是一個冗余的,在max操作中通過冗余可以避免信息的丟失,可以保留足夠的點云信息,再通過一個網絡來進一步消化信息得到點云的特征。這就是函數的組合:每個點都做h低維到高維的映射,G是對稱的那么整個結構就都是對稱的。下圖就是原始的pointnet結構。
在實際執行過程中,可以用MLP多層感知器(Multilayer perceptron) 來描述h和γ,g( max polling) 效果最好。
我們發現,pointnet 可以任意的逼近在集合上的對稱函數,只要是對稱函數是在hausdorff空間是連續的,那么就可以通過任意的增加神經網絡的寬度深度,來逼近這個函數:
視角變換
如何來應對輸入點云的幾何(視角)變換,比如一輛車在不同的角度點云的xyz都是不同的, 但代表的都是車,我們希望網絡也能應對視角的變換。
增加了一個基于數據本身的變換函數模塊, T-net?生成變換參數,之后的網絡處理變換之后的點,目標是通過整體優化變換網絡和后面的網絡使得變換函數對齊輸入,如果對齊了,不同視角的問題就可以簡化。實際中點云的變化很簡單,不像圖片做變換需要做插值,做矩陣乘法就可以。比如對于一個3*3的矩陣僅僅是一個正交變換,計算容易實現簡單。
PointNet分類網絡
將以上這些變換的網絡和pointnet結合起來,就可以得到PointNet分類網絡。
首先輸入一個n*3的矩陣,先做一個輸入的矩陣變換,T-net 變成一個n*3的矩陣,然后通過MLP把每個點投射到64高維空間,在做一個高維空間的變換,形成一個更加歸一化的64維矩陣,繼續做MLP將64維映射到1024維,在1024中可以做對稱性的操作,就maxpooling,得到globle fearue,1024維度 ,通過全連接網絡生成k (分類)。
PointNet分割網絡
分割網絡如圖:
可以定以為對每個點的分類問題,通過全局坐標是沒法對每個點進行分割的,簡單有效的做法是,將局部單個點的特征和全局的坐標結合起來,實現分割的功能。最后輸出m類相當于m個score:(將單個點和總體的特征連接到一起,判定在總體中的位置,來決定是哪個分類)
2.PointNet源碼
這里使用的是Pytorch的版本。
class STN3d(nn.Module):'''3x3 transform'''def __init__(self):super(STN3d, self).__init__()self.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 9)self.relu = nn.ReLU()self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)def forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)if x.is_cuda:iden = iden.cuda()x = x + idenx = x.view(-1, 3, 3)return xclass STNkd(nn.Module):'''64x64 transform'''def __init__(self, k=64):super(STNkd, self).__init__()self.conv1 = torch.nn.Conv1d(k, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k*k)self.relu = nn.ReLU()self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)self.k = kdef forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)if x.is_cuda:iden = iden.cuda()x = x + idenx = x.view(-1, self.k, self.k)return xclass PointNetfeat(nn.Module):'''Output: global feature / local+global feature'''def __init__(self, global_feat = True, feature_transform = False):super(PointNetfeat, self).__init__()self.stn = STN3d()self.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.global_feat = global_featself.feature_transform = feature_transformif self.feature_transform:self.fstn = STNkd(k=64)def forward(self, x):n_pts = x.size()[2]trans = self.stn(x)x = x.transpose(2, 1)x = torch.bmm(x, trans)x = x.transpose(2, 1)x = F.relu(self.bn1(self.conv1(x)))if self.feature_transform:trans_feat = self.fstn(x)x = x.transpose(2,1)x = torch.bmm(x, trans_feat)x = x.transpose(2,1)else:trans_feat = Nonepointfeat = xx = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)if self.global_feat:return x, trans, trans_feat # (B, 1024) (B, 3, 3) (B, 64, 64)else:x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)return torch.cat([x, pointfeat], 1), trans, trans_feat # (B, 1088, 2500) (B,3, 3) (B, 64, 64)class PointNetCls(nn.Module):# 分類網絡def __init__(self, k=2, feature_transform=False):super(PointNetCls, self).__init__()self.feature_transform = feature_transformself.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k)self.dropout = nn.Dropout(p=0.3)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.relu = nn.ReLU()def forward(self, x):x, trans, trans_feat = self.feat(x)x = F.relu(self.bn1(self.fc1(x)))x = F.relu(self.bn2(self.dropout(self.fc2(x))))x = self.fc3(x)return F.log_softmax(x, dim=1), trans, trans_featclass PointNetDenseCls(nn.Module):# 分割網絡def __init__(self, k = 2, feature_transform=False):super(PointNetDenseCls, self).__init__()self.k = kself.feature_transform=feature_transformself.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)self.conv1 = torch.nn.Conv1d(1088, 512, 1)self.conv2 = torch.nn.Conv1d(512, 256, 1)self.conv3 = torch.nn.Conv1d(256, 128, 1)self.conv4 = torch.nn.Conv1d(128, self.k, 1)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.bn3 = nn.BatchNorm1d(128)def forward(self, x):batchsize = x.size()[0]n_pts = x.size()[2]x, trans, trans_feat = self.feat(x)x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = self.conv4(x)x = x.transpose(2,1).contiguous()x = F.log_softmax(x.view(-1,self.k), dim=-1)x = x.view(batchsize, n_pts, self.k)return x, trans, trans_featif __name__ == '__main__':sim_data = Variable(torch.rand(32,3,2500))trans = STN3d()out = trans(sim_data)print('stn', out.size())print('loss', feature_transform_regularizer(out))sim_data_64d = Variable(torch.rand(32, 64, 2500))trans = STNkd(k=64)out = trans(sim_data_64d)print('stn64d', out.size())print('loss', feature_transform_regularizer(out))pointfeat = PointNetfeat(global_feat=True)out, _, _ = pointfeat(sim_data)print('global feat', out.size())pointfeat = PointNetfeat(global_feat=False)out, _, _ = pointfeat(sim_data)print('point feat', out.size())cls = PointNetCls(k = 5)out, _, _ = cls(sim_data)print('class', out.size())seg = PointNetDenseCls(k = 3)out, _, _ = seg(sim_data)print('seg', out.size()) stn torch.Size([32, 3, 3]) loss tensor(2.5054, grad_fn=<MeanBackward0>)stn64d torch.Size([32, 64, 64]) loss tensor(127.5234, grad_fn=<MeanBackward0>)global feat torch.Size([32, 1024])point feat torch.Size([32, 1088, 2500])class torch.Size([32, 5])seg torch.Size([32, 2500, 3])參考資料
https://www.cnblogs.com/yibeimingyue/p/12002469.html
https://github.com/fxia22/pointnet.pytorch
http://stanford.edu/~rqi/pointnet/
https://zhuanlan.zhihu.com/p/86331508
https://www.bilibili.com/video/BV1As411377S/?spm_id_from=333.788.videocard.1
總結
以上是生活随笔為你收集整理的PointNet:3D点集分类与分割深度学习模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 兴业信用卡年费多少 免年费政策帮你省钱
- 下一篇: 农行信用卡年费多少 想省钱的多刷卡