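Learning mmdetection: anchor_generator
I won't introduce mmdetection here; a quick search on Baidu or Google will cover it.
File: mmdet/core/anchor/anchor_generator.py
This file defines the class that a detector uses to generate anchors (candidate boxes).
The code is as follows.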

    import torch


    class AnchorGenerator(object):

        def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
            self.base_size = base_size
            self.scales = torch.Tensor(scales)
            self.ratios = torch.Tensor(ratios)
            self.scale_major = scale_major
            self.ctr = ctr
            self.base_anchors = self.gen_base_anchors()

        @property
        def num_base_anchors(self):
            return self.base_anchors.size(0)

        def gen_base_anchors(self):
            w = self.base_size
            h = self.base_size
            if self.ctr is None:
                x_ctr = 0.5 * (w - 1)
                y_ctr = 0.5 * (h - 1)
            else:
                x_ctr, y_ctr = self.ctr

            h_ratios = torch.sqrt(self.ratios)
            w_ratios = 1 / h_ratios
            if self.scale_major:
                ws = (w * w_ratios[:, None] * self.scales[None, :]).view(-1)
                hs = (h * h_ratios[:, None] * self.scales[None, :]).view(-1)
            else:
                ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1)
                hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1)

            base_anchors = torch.stack(
                [
                    x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
                    x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
                ],
                dim=-1).round()

            return base_anchors

        def _meshgrid(self, x, y, row_major=True):
            xx = x.repeat(len(y))
            yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
            if row_major:
                return xx, yy
            else:
                return yy, xx

        def grid_anchors(self, featmap_size, stride=16, device='cuda'):
            base_anchors = self.base_anchors.to(device)

            feat_h, feat_w = featmap_size
            shift_x = torch.arange(0, feat_w, device=device) * stride
            shift_y = torch.arange(0, feat_h, device=device) * stride
            shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
            shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
            shifts = shifts.type_as(base_anchors)
            # first feat_w elements correspond to the first row of shifts
            # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
            # shifted anchors (K, A, 4), reshape to (K*A, 4)

            all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
            all_anchors = all_anchors.view(-1, 4)
            # first A rows correspond to A anchors of (0, 0) in feature map,
            # then (0, 1), (0, 2), ...
            return all_anchors

        def valid_flags(self, featmap_size, valid_size, device='cuda'):
            feat_h, feat_w = featmap_size
            valid_h, valid_w = valid_size
            assert valid_h <= feat_h and valid_w <= feat_w
            valid_x = torch.zeros(feat_w, dtype=torch.uint8, device=device)
            valid_y = torch.zeros(feat_h, dtype=torch.uint8, device=device)
            valid_x[:valid_w] = 1
            valid_y[:valid_h] = 1
            valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
            valid = valid_xx & valid_yy
            valid = valid[:, None].expand(
                valid.size(0), self.num_base_anchors).contiguous().view(-1)
            return valid

num_base_anchors: the number of anchors generated around a single point of the feature map. It is determined by the sizes of the two tensors scales and ratios. For example, if both tensors have size 3, the number of anchors is 3 x 3 = 9.
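The count, and the order in which the anchors are laid out, can be checked without torch. The sketch below is a plain-Python illustration rather than mmdetection code; the helper name `enumerate_base_anchor_params` and the example values are assumptions made for this post. It lists the (ratio, scale) pair behind each base anchor in the order the broadcasting in gen_base_anchors flattens them.

```python
from itertools import product


def enumerate_base_anchor_params(scales, ratios, scale_major=True):
    """Hypothetical helper: list the (ratio, scale) pair behind each base anchor."""
    if scale_major:
        # w_ratios[:, None] * scales[None, :] -> the ratio varies slowest
        return [(r, s) for r, s in product(ratios, scales)]
    # scales[:, None] * w_ratios[None, :] -> the scale varies slowest
    return [(r, s) for s, r in product(scales, ratios)]


pairs = enumerate_base_anchor_params(scales=[8, 16, 32], ratios=[0.5, 1.0, 2.0])
print(len(pairs))  # 3 * 3 = 9 anchors per feature-map point
print(pairs[:3])   # the first three anchors all use ratio 0.5
```

With scale_major=True, consecutive anchors share a ratio and step through the scales; with scale_major=False it is the other way round.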
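gen_base_anchors: performs the actual computation of the base anchors described above. For every combination of ratio and scale it derives a width and height from base_size, centers the box on (x_ctr, y_ctr), and returns the rounded (x1, y1, x2, y2) corners.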
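To make the arithmetic concrete, here is a plain-Python sketch of the scale_major=True branch of gen_base_anchors. It is an illustration under assumptions, not the library implementation: the function name `base_anchors_py` is made up, and it uses Python floats instead of torch tensors, so results may differ in the last unit for values that land exactly on a rounding boundary.

```python
import math


def base_anchors_py(base_size, scales, ratios, ctr=None):
    """Plain-Python sketch of gen_base_anchors (scale_major=True branch only)."""
    w = h = base_size
    if ctr is None:
        x_ctr, y_ctr = 0.5 * (w - 1), 0.5 * (h - 1)
    else:
        x_ctr, y_ctr = ctr
    anchors = []
    for ratio in ratios:               # ratio = height / width
        h_ratio = math.sqrt(ratio)
        w_ratio = 1 / h_ratio
        for scale in scales:
            ws = w * w_ratio * scale   # anchor width in pixels
            hs = h * h_ratio * scale   # anchor height in pixels
            anchors.append([
                round(x_ctr - 0.5 * (ws - 1)), round(y_ctr - 0.5 * (hs - 1)),
                round(x_ctr + 0.5 * (ws - 1)), round(y_ctr + 0.5 * (hs - 1)),
            ])
    return anchors


print(base_anchors_py(16, scales=[8], ratios=[1.0]))  # [[-56, -56, 71, 71]]
print(len(base_anchors_py(16, scales=[8, 16, 32], ratios=[0.5, 1.0, 2.0])))  # 9
```

Because the height is scaled by sqrt(ratio) and the width by its reciprocal, changing the ratio changes the shape of the box while keeping its area roughly constant (before rounding).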
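_meshgrid: given two 1-D tensors x and y, builds the grid they span. It returns two flattened tensors of length len(x) x len(y) that together list every (x, y) combination, with x varying fastest.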
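The flattening order is easier to see on small inputs. The following list-based sketch mimics _meshgrid with plain Python lists; the name `meshgrid_py` is an assumption for illustration, and the comments show which tensor operation each line stands in for.

```python
def meshgrid_py(x, y, row_major=True):
    """List-based sketch of _meshgrid: x is tiled, y is repeated element-wise."""
    xx = list(x) * len(y)                       # x.repeat(len(y))
    yy = [v for v in y for _ in range(len(x))]  # y.view(-1, 1).repeat(1, len(x)).view(-1)
    return (xx, yy) if row_major else (yy, xx)


xx, yy = meshgrid_py([0, 16, 32], [0, 16])
print(xx)  # [0, 16, 32, 0, 16, 32]
print(yy)  # [0, 0, 0, 16, 16, 16]
```

Reading the two lists in parallel gives the grid points row by row: (0, 0), (16, 0), (32, 0), then (0, 16), (16, 16), (32, 16).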
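grid_anchors: given a concrete feature-map size, it places the base anchors at every point of the feature map, shifting them by stride pixels per cell. For example, with a feature map of size [10, 10] and scales and ratios both of size 3, each point gets 9 different anchors, for a total of 10 x 10 x 9 = 900 anchors.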
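The 900 figure and the ordering of the result can be reproduced with a loop-based sketch. This is an assumed plain-Python equivalent of the broadcasting in grid_anchors, not the original code: `grid_anchors_py` is a made-up name and the nine base anchors are placeholder values chosen only so that the count is right.

```python
def grid_anchors_py(base_anchors, featmap_size, stride=16):
    """Plain-Python sketch of grid_anchors: shift every base anchor to every cell."""
    feat_h, feat_w = featmap_size
    all_anchors = []
    for row in range(feat_h):          # y varies slowest (row-major order)
        for col in range(feat_w):
            dx, dy = col * stride, row * stride
            for x1, y1, x2, y2 in base_anchors:
                all_anchors.append([x1 + dx, y1 + dy, x2 + dx, y2 + dy])
    return all_anchors


# Nine placeholder base anchors (made-up values; only the count matters here).
base = [[-k, -k, k, k] for k in range(1, 10)]
anchors = grid_anchors_py(base, featmap_size=(10, 10), stride=16)
print(len(anchors))  # 10 * 10 * 9 = 900
print(anchors[9])    # first base anchor moved to cell (0, 1): [15, -1, 17, 1]
```

The first 9 rows belong to cell (0, 0), the next 9 to cell (0, 1), and so on, which matches the comment in the original method.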
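valid_flags: takes two sizes, the feature-map size and the size of the valid (labeled) region expressed at the same feature-map scale. It builds a mask with one entry per feature-map cell, set to 1 inside the valid region (the top-left valid_h x valid_w cells) and 0 everywhere else. Finally, expand repeats each entry once per base anchor, so the flattened result holds feat_h x feat_w x num_base_anchors flags.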
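A small example shows the layout of the flags. The sketch below is a plain-Python stand-in for valid_flags, written under the assumption that the valid region always starts at the top-left corner, as the slicing in the original method implies; `valid_flags_py` and its explicit num_base_anchors argument are hypothetical.

```python
def valid_flags_py(featmap_size, valid_size, num_base_anchors):
    """Plain-Python sketch of valid_flags; returns a flat list of 0/1 flags."""
    feat_h, feat_w = featmap_size
    valid_h, valid_w = valid_size
    assert valid_h <= feat_h and valid_w <= feat_w
    flags = []
    for row in range(feat_h):
        for col in range(feat_w):
            inside = int(row < valid_h and col < valid_w)
            flags.extend([inside] * num_base_anchors)  # one flag per anchor
    return flags


flags = valid_flags_py(featmap_size=(2, 3), valid_size=(1, 2), num_base_anchors=2)
print(flags)  # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
```

The flags follow the same cell-by-cell, anchor-by-anchor order as the output of grid_anchors, so the two can be used together to pick out the anchors that lie in the valid region.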