DeepLabV3+ Semantic Segmentation in Practice
Semantic segmentation is an important task in computer vision. This article implements the DeepLabV3+ semantic segmentation model with the Jittor framework.
DeepLabV3+ paper: https://arxiv.org/pdf/1802.02611.pdf
Full code: https://github.com/Jittor/deeplab-jittor
1. Dataset
1.1 Data Preparation
VOC2012 is one of the most widely used datasets for object detection, semantic segmentation, and related tasks. This article uses the 2012 trainaug split (train + SBD set) as the training set and the 2012 val set as the test set.
The VOC dataset contains 20 foreground classes: 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor', plus one background class.
The final dataset files are organized as follows:
root
|----voc_aug
| |----datalist
| | |----train.txt
| | |----val.txt
| |----images
| |----annotations
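Each line of train.txt / val.txt is a bare image ID with no extension; the loader below appends .jpg and .png to build the image and annotation paths. A hypothetical example of the first lines of train.txt (the IDs shown are illustrative):

2007_000032
2007_000039
2007_000063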
1.2 Data Loading
You can build your own dataset from the Dataset base class in jittor.dataset.dataset; you need to implement the __init__ and __getitem__ functions.
- __init__: defines the data paths. Here data_root should be set to the voc_aug directory prepared above, and split is one of train / val / test, selecting the training, validation, or test set. You also need to call self.set_attr to specify the parameters the loader needs: batch_size, total_len, and shuffle.
- __getitem__: returns the data for a single item.
import os
import os.path as osp
import random
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from PIL import Image, ImageOps, ImageFilter
import jittor as jt
from jittor.dataset.dataset import Dataset, dataset_root
def fetch(image_path, label_path):
    with open(image_path, 'rb') as fp:
        image = Image.open(fp).convert('RGB')
    with open(label_path, 'rb') as fp:
        label = Image.open(fp).convert('P')
    return image, label
def scale(image, label):
    SCALES = (0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)
    ratio = np.random.choice(SCALES)
    w, h = image.size
    nw = int(w * ratio)
    nh = int(h * ratio)
    image = image.resize((nw, nh), Image.BILINEAR)
    label = label.resize((nw, nh), Image.NEAREST)
    return image, label
def pad(image, label):
    w, h = image.size
    crop_size = 513
    pad_h = max(crop_size - h, 0)
    pad_w = max(crop_size - w, 0)
    image = ImageOps.expand(image, border=(0, 0, pad_w, pad_h), fill=0)
    # padded label pixels are filled with 255, matching the
    # ignore_index used by the cross-entropy loss during training
    label = ImageOps.expand(label, border=(0, 0, pad_w, pad_h), fill=255)
    return image, label
def crop(image, label):
    w, h = image.size
    crop_size = 513
    x1 = random.randint(0, w - crop_size)
    y1 = random.randint(0, h - crop_size)
    image = image.crop((x1, y1, x1 + crop_size, y1 + crop_size))
    label = label.crop((x1, y1, x1 + crop_size, y1 + crop_size))
    return image, label
def normalize(image, label):
    # ImageNet mean / std
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    image = np.array(image).astype(np.float32)
    label = np.array(label).astype(np.float32)
    image /= 255.0
    image -= mean
    image /= std
    return image, label
def flip(image, label):
    if random.random() < 0.5:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
        label = label.transpose(Image.FLIP_LEFT_RIGHT)
    return image, label
class BaseDataset(Dataset):
    def __init__(self, data_root='/voc/', split='train', batch_size=1, shuffle=False):
        super().__init__()
        ''' total_len, batch_size and shuffle must be set '''
        self.data_root = data_root
        self.split = split
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.image_root = os.path.join(data_root, 'images')
        self.label_root = os.path.join(data_root, 'annotations')
        self.data_list_path = os.path.join(self.data_root, 'datalist', self.split + '.txt')
        self.image_path = []
        self.label_path = []
        with open(self.data_list_path, "r") as f:
            lines = f.read().splitlines()
        for idx, line in enumerate(lines):
            _img_path = os.path.join(self.image_root, line + '.jpg')
            _label_path = os.path.join(self.label_root, line + '.png')
            assert os.path.isfile(_img_path)
            assert os.path.isfile(_label_path)
            self.image_path.append(_img_path)
            self.label_path.append(_label_path)
        self.total_len = len(self.image_path)
        # set_attr must be called to set batch_size, total_len and shuffle,
        # similar to implementing __len__ in PyTorch
        self.set_attr(batch_size=self.batch_size, total_len=self.total_len, shuffle=self.shuffle)

    def __getitem__(self, image_id):
        raise NotImplementedError
class TrainDataset(BaseDataset):
    def __init__(self, data_root='/voc/', split='train', batch_size=1, shuffle=False):
        super(TrainDataset, self).__init__(data_root, split, batch_size, shuffle)

    def __getitem__(self, image_id):
        image_path = self.image_path[image_id]
        label_path = self.label_path[image_id]
        image, label = fetch(image_path, label_path)
        image, label = scale(image, label)
        image, label = pad(image, label)
        image, label = crop(image, label)
        image, label = flip(image, label)
        image, label = normalize(image, label)
        image = np.array(image).astype(np.float32).transpose(2, 0, 1)
        image = jt.array(image)
        label = jt.array(np.array(label).astype(np.int32))
        return image, label
class ValDataset(BaseDataset):
    def __init__(self, data_root='/voc/', split='train', batch_size=1, shuffle=False):
        super(ValDataset, self).__init__(data_root, split, batch_size, shuffle)

    def __getitem__(self, image_id):
        image_path = self.image_path[image_id]
        label_path = self.label_path[image_id]
        image, label = fetch(image_path, label_path)
        image, label = normalize(image, label)
        image = np.array(image).astype(np.float32).transpose(2, 0, 1)
        image = jt.array(image)
        label = jt.array(np.array(label).astype(np.int32))
        return image, label
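With the datasets defined, a minimal usage sketch (assuming the voc_aug data is at /voc/ as above; the shapes follow from the 513×513 crop):

train_loader = TrainDataset(data_root='/voc/', split='train', batch_size=8, shuffle=True)
for image, label in train_loader:
    # Jittor datasets batch automatically according to set_attr
    print(image.shape)  # expected [8, 3, 513, 513]
    print(label.shape)  # expected [8, 513, 513]
    break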
2. Model Definition
This article follows the network architecture given in the DeepLabV3+ paper (see the architecture figure there), uses ResNet as the backbone, and takes input images of size 513×513.
The whole network splits into three parts: the backbone, ASPP, and the decoder.
2.1 Backbone. We use the most common choice, ResNet, as the backbone, and apply dilated (atrous) convolutions in its last two stages to enlarge the receptive field. The full definition is as follows:
import jittor as jt
from jittor import nn
from jittor import Module
from jittor import init
from jittor.contrib import concat, argmax_pool
import time
class Bottleneck(Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm(planes)
        self.conv2 = nn.Conv(planes, planes, kernel_size=3, stride=stride,
                             dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm(planes)
        self.conv3 = nn.Conv(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm(planes * 4)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def execute(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class ResNet(Module):
    def __init__(self, block, layers, output_stride):
        super(ResNet, self).__init__()
        self.inplanes = 64
        blocks = [1, 2, 4]
        if output_stride == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
        elif output_stride == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 4]
        else:
            raise NotImplementedError

        # Modules
        self.conv1 = nn.Conv(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm(64)
        self.relu = nn.ReLU()
        # self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2])
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3])

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv(self.inplanes, planes * block.expansion,
                        kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1):
        # multi-grid unit: the blocks of the last stage use
        # dilation rates blocks[i] * dilation, i.e. [1, 2, 4] * dilation
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv(self.inplanes, planes * block.expansion,
                        kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0] * dilation,
                            downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1,
                                dilation=blocks[i] * dilation))
        return nn.Sequential(*layers)

    def execute(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = argmax_pool(x, 2, 2)
        x = self.layer1(x)
        low_level_feat = x
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x, low_level_feat
def resnet50(output_stride):
    model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride)
    return model

def resnet101(output_stride):
    model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride)
    return model
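A hedged sanity check of the backbone outputs (the dummy input is illustrative; channel counts follow from Bottleneck.expansion = 4):

backbone = resnet101(output_stride=16)
dummy = jt.random((1, 3, 513, 513))
x, low_level_feat = backbone(dummy)
print(x.shape)               # high-level feature: 2048 channels, spatial size ~ input / 16
print(low_level_feat.shape)  # low-level feature from layer1: 256 channels, ~ input / 4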
2.2 ASPP
ASPP applies dilated convolutions with different rates to the feature map produced by the backbone, then concatenates the branch outputs and fuses them into a new feature.
import jittor as jt
from jittor import nn
from jittor import Module
from jittor import init
from jittor.contrib import concat
class Single_ASPPModule(Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super(Single_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv(inplanes, planes, kernel_size=kernel_size,
                                   stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm(planes)
        self.relu = nn.ReLU()

    def execute(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
class ASPP(Module):
    def __init__(self, output_stride):
        super(ASPP, self).__init__()
        inplanes = 2048
        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError
        self.aspp1 = Single_ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0])
        self.aspp2 = Single_ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1])
        self.aspp3 = Single_ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2])
        self.aspp4 = Single_ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3])
        self.global_avg_pool = nn.Sequential(
            GlobalPooling(),
            nn.Conv(inplanes, 256, 1, stride=1, bias=False),
            nn.BatchNorm(256),
            nn.ReLU())
        # 5 branches x 256 channels = 1280 input channels
        self.conv1 = nn.Conv(1280, 256, 1, bias=False)
        self.bn1 = nn.BatchNorm(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def execute(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # broadcast the 1x1 pooled feature back to the spatial size of x4
        x5 = x5.broadcast((x5.shape[0], x5.shape[1], x4.shape[2], x4.shape[3]))
        x = concat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x
class GlobalPooling(Module):
    def __init__(self):
        super(GlobalPooling, self).__init__()

    def execute(self, x):
        return jt.mean(x, dims=[2, 3], keepdims=1)
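Why each 3×3 branch sets padding equal to its dilation rate: a dilated 3×3 kernel spans 2d + 1 pixels, so padding d preserves the spatial size. A minimal hedged check with an illustrative feature shape:

branch = Single_ASPPModule(2048, 256, kernel_size=3, padding=18, dilation=18)
feat = jt.random((1, 2048, 33, 33))  # hypothetical backbone output at output_stride=16
print(branch(feat).shape)            # expected [1, 256, 33, 33]: size preserved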
2.3 Decoder:
The decoder upsamples the ASPP feature and concatenates it with the intermediate (low-level) feature from ResNet to obtain the feature used for the final segmentation.
import jittor as jt
from jittor import nn
from jittor import Module
from jittor import init
from jittor.contrib import concat
import time
class Decoder(nn.Module):
    def __init__(self, num_classes):
        super(Decoder, self).__init__()
        low_level_inplanes = 256
        self.conv1 = nn.Conv(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = nn.BatchNorm(48)
        self.relu = nn.ReLU()
        # 304 = 256 (ASPP output) + 48 (reduced low-level feature)
        self.last_conv = nn.Sequential(
            nn.Conv(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv(256, num_classes, kernel_size=1, stride=1, bias=True))

    def execute(self, x, low_level_feat):
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        low_level_feat = self.relu(low_level_feat)
        x_inter = nn.resize(x, size=(low_level_feat.shape[2], low_level_feat.shape[3]), mode='bilinear')
        x_concat = concat((x_inter, low_level_feat), dim=1)
        x = self.last_conv(x_concat)
        return x
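A hedged shape check of the decoder in isolation (both inputs are hypothetical stand-ins for the ASPP output and the backbone's low-level feature):

decoder = Decoder(num_classes=21)
x = jt.random((1, 256, 33, 33))      # hypothetical ASPP output
low = jt.random((1, 256, 129, 129))  # hypothetical low-level feature
print(decoder(x, low).shape)         # expected [1, 21, 129, 129]: logits at low-level resolution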
2.4 The complete model ties the three parts above together in a single class:
import jittor as jt
from jittor import nn
from jittor import Module
from jittor import init
from jittor.contrib import concat
from decoder import Decoder
from aspp import ASPP
from backbone import resnet50, resnet101
class DeepLab(Module):
    def __init__(self, output_stride=16, num_classes=21):
        super(DeepLab, self).__init__()
        self.backbone = resnet101(output_stride=output_stride)
        self.aspp = ASPP(output_stride)
        self.decoder = Decoder(num_classes)

    def execute(self, input):
        x, low_level_feat = self.backbone(input)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        # upsample the logits back to the input resolution
        x = nn.resize(x, size=(input.shape[2], input.shape[3]), mode='bilinear')
        return x
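An end-to-end smoke test of the assembled model, as a sketch (random input; only the output shape is checked):

model = DeepLab(output_stride=16, num_classes=21)
dummy = jt.random((1, 3, 513, 513))
pred = model(dummy)
print(pred.shape)  # expected [1, 21, 513, 513]: per-pixel logits over the 21 classes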
3. Model Training
3.1 The training parameters are set as follows:
batch_size = 8
learning_rate = 0.005
momentum = 0.9
weight_decay = 1e-4
epochs = 50
3.2 Define the model, optimizer, and data loaders.
model = DeepLab(output_stride=16, num_classes=21)
optimizer = nn.SGD(model.parameters(),
                   learning_rate,
                   momentum=momentum,
                   weight_decay=weight_decay)
train_loader = TrainDataset(data_root='/vocdata/',
                            split='train',
                            batch_size=batch_size,
                            shuffle=True)
val_loader = ValDataset(data_root='/vocdata/',
                        split='val',
                        batch_size=1,
                        shuffle=False)
3.3 Training and Validation
The poly learning-rate scheduler:
def poly_lr_scheduler(opt, init_lr, iter, epoch, max_iter, max_epoch):
    new_lr = init_lr * (1 - float(epoch * max_iter + iter) / (max_epoch * max_iter)) ** 0.9
    opt.lr = new_lr
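For intuition, with init_lr = 0.005 and max_epoch = 50 the schedule decays smoothly toward zero; a quick hedged check at a few points (_FakeOpt is a hypothetical stand-in so the scheduler can be exercised alone):

class _FakeOpt:
    lr = 0.0

opt = _FakeOpt()
for epoch in (0, 25, 49):
    poly_lr_scheduler(opt, init_lr=0.005, iter=0, epoch=epoch, max_iter=100, max_epoch=50)
    print(epoch, opt.lr)  # ~0.005, ~0.00268, ~0.00015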
The train function:
def train(model, train_loader, optimizer, epoch, init_lr):
    model.train()
    max_iter = len(train_loader)
    for idx, (image, target) in enumerate(train_loader):
        poly_lr_scheduler(optimizer, init_lr, idx, epoch, max_iter, 50)
        image = image.float32()
        pred = model(image)
        loss = nn.cross_entropy_loss(pred, target, ignore_index=255)
        # in Jittor, optimizer.step(loss) runs backward and the update in one call
        optimizer.step(loss)
        print('Training in epoch {} iteration {} loss = {}'.format(epoch, idx, loss.data[0]))
The val function (the Evaluator used here is defined in Section 3.4 below):
best_miou = 0.0  # best mIoU tracked at module level across epochs

def val(model, val_loader, epoch, evaluator):
    global best_miou
    model.eval()
    evaluator.reset()
    for idx, (image, target) in enumerate(val_loader):
        image = image.float32()
        output = model(image)
        pred = output.data
        target = target.data
        pred = np.argmax(pred, axis=1)
        evaluator.add_batch(target, pred)
        print('Test in epoch {} iteration {}'.format(epoch, idx))
    Acc = evaluator.Pixel_Accuracy()
    Acc_class = evaluator.Pixel_Accuracy_Class()
    mIoU = evaluator.Mean_Intersection_over_Union()
    FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()
    if mIoU > best_miou:
        best_miou = mIoU
    print('Testing result of epoch {} miou = {} Acc = {} Acc_class = {} '
          'FWIoU = {} Best Miou = {}'.format(epoch, mIoU, Acc, Acc_class, FWIoU, best_miou))
3.4 The Evaluator: pixel accuracy and mIoU are computed from a confusion matrix.
class Evaluator(object):
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def Pixel_Accuracy(self):
        Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return Acc

    def Pixel_Accuracy_Class(self):
        Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        Acc = np.nanmean(Acc)
        return Acc

    def Mean_Intersection_over_Union(self):
        MIoU = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) +
            np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        MIoU = np.nanmean(MIoU)
        return MIoU

    def Frequency_Weighted_Intersection_over_Union(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) +
            np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU

    def _generate_matrix(self, gt_image, pre_image):
        # rows: ground-truth class, columns: predicted class
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class ** 2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def add_batch(self, gt_image, pre_image):
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
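A tiny hedged hand-check of the metric logic with two classes: for gt = [0, 0, 1, 1] and pred = [0, 1, 1, 1], IoU(class 0) = 1/2 and IoU(class 1) = 2/3, so mIoU ≈ 0.583:

ev = Evaluator(num_class=2)
gt = np.array([[0, 0, 1, 1]])
pred = np.array([[0, 1, 1, 1]])
ev.add_batch(gt, pred)
print(ev.Pixel_Accuracy())                # 0.75
print(ev.Mean_Intersection_over_Union())  # ~0.583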
3.5 Training entry point
epochs = 50
evaluator = Evaluator(21)
train_loader = TrainDataset(data_root='/voc/data/path/', split='train', batch_size=8, shuffle=True)
val_loader = ValDataset(data_root='/voc/data/path/', split='val', batch_size=1, shuffle=False)
learning_rate = 0.005
momentum = 0.9
weight_decay = 1e-4
model = DeepLab(output_stride=16, num_classes=21)
optimizer = nn.SGD(model.parameters(), learning_rate,
                   momentum=momentum, weight_decay=weight_decay)
for epoch in range(epochs):
    train(model, train_loader, optimizer, epoch, learning_rate)
    val(model, val_loader, epoch, evaluator)
4. References
- pytorch-deeplab-xception
- Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation