基于部分卷积Pconv的图片修复
論文:Image Inpainting for Irregular Holes Using Partial Convolutions?
Github:
https://github.com/MathiasGruber/PConv-Keras
https://github.com/deeppomf/DeepCreamPy#dependencies-for-running-the-code-yourself
https://github.com/deeppomf/DeepCreamPy/releases/tag/v1.2.1-beta
?
英偉達(dá)的論文,非常值得閱讀,PConv和loss func都很有特點。
論文貢獻(xiàn):
?
網(wǎng)絡(luò)結(jié)構(gòu):
網(wǎng)絡(luò)采用U-Net結(jié)構(gòu),分為編碼模塊(PConv1-PConv8)和解碼模塊(PConv9-PConv16)兩部分。
?
Partial Convolutional(PConv):
部分卷積將卷積分為了輸入圖片的卷積和輸入掩碼mask的卷積。之前的論文都是只在第一層使用mask,mask也不會得到跟新,本文的partial convolutions,每次都使用跟新后的mask,隨著網(wǎng)絡(luò)層數(shù)的增加,mask輸出m’中為0的像素越來越少,輸出的結(jié)果x’中有效區(qū)域的面積越來越大,mask對整體loss的影響會越來越小(如上圖所示,表示了不同層的mask輸出)。
如上式所示,W表示卷積層濾波器的weights,b表示卷積層濾波器的bias,X表示輸入的圖片,M表示掩碼mask,⊙ 表示element-wise點乘運算,x'表示輸入圖片經(jīng)過卷積后的輸出,m’表示輸入掩碼經(jīng)過卷積后的輸出。
Keras實現(xiàn):
def call(self, inputs, mask=None): # Both image and mask must be suppliedif type(inputs) is not list or len(inputs) != 2:raise Exception('PartialConvolution2D must be called on a list of two tensors [img, mask]. Instead got: ' + str(inputs))# Create normalization. Slight change here compared to paper, using mean mask value instead of sumnormalization = K.mean(inputs[1], axis=[1,2], keepdims=True)normalization = K.repeat_elements(normalization, inputs[1].shape[1], axis=1)normalization = K.repeat_elements(normalization, inputs[1].shape[2], axis=2)# Apply convolutions to imageimg_output = K.conv2d((inputs[0]*inputs[1]) / normalization, self.kernel, strides=self.strides,padding=self.padding,data_format=self.data_format,dilation_rate=self.dilation_rate)# Apply convolutions to maskmask_output = K.conv2d(inputs[1], self.kernel_mask, strides=self.strides,padding=self.padding, data_format=self.data_format,dilation_rate=self.dilation_rate)# Where something happened, set 1, otherwise 0 mask_output = K.cast(K.greater(mask_output, 0), 'float32')# Apply bias only to the image (if chosen to do so)if self.use_bias:img_output = K.bias_add(img_output,self.bias,data_format=self.data_format)# Apply activations on the imageif self.activation is not None:img_output = self.activation(img_output)return [img_output, mask_output]Loss:
Iin:輸入的圖片
Iout:網(wǎng)絡(luò)的預(yù)測輸出
M :掩碼,孔洞為0,有效像素為1
Igt:label,即ground truth
Icomp :孔洞像素的輸出
Ψn :第n層激活后的特征圖,本文取pool1, pool2, pool3
?
孔洞的損失:
1-M表示孔洞區(qū)域,整體表示了孔洞區(qū)域的輸出和ground truth的L1 loss。
Keras實現(xiàn):
def loss_hole(self, mask, y_true, y_pred):"""Pixel L1 loss within the hole / mask"""return self.l1((1-mask) * y_true, (1-mask) * y_pred)非孔洞的有效像素的損失:
M表示非孔洞區(qū)域,整體表示非孔洞區(qū)域的網(wǎng)絡(luò)預(yù)測輸出和ground truth的L1 loss。
Keras實現(xiàn):
def loss_valid(self, mask, y_true, y_pred):"""Pixel L1 loss outside the hole / mask"""return self.l1(mask * y_true, mask * y_pred)感知的損失:
感知的損失,或者內(nèi)容的損失,表示了pool1, pool2, pool3層的輸出和ground truth的L1 損失。表示了width,height,channel三個方面的差異。
Keras實現(xiàn):
def loss_perceptual(self, vgg_out, vgg_gt, vgg_comp): """Perceptual loss based on VGG16, see. eq. 3 in paper""" loss = 0for o, c, g in zip(vgg_out, vgg_comp, vgg_gt):loss += self.l1(o, g) + self.l1(c, g)return loss風(fēng)格的損失:
Kn :歸一化參數(shù),表示為1/CnHnWn
Ψn 的形狀為(HnWn) × Cn ,因此Ψn 的轉(zhuǎn)置和Ψn 的矩陣乘積后輸出的矩陣大小為Cn × Cn 。
整體公式表示了pool1, pool2, pool3層的輸出和輸出的轉(zhuǎn)置與ground truth和ground truth的轉(zhuǎn)置的差異。表示了channel方面的差異。
Keras實現(xiàn):
def loss_style(self, output, vgg_gt):"""Style loss based on output/computation, used for both eq. 4 & 5 in paper"""loss = 0for o, g in zip(output, vgg_gt):loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))return loss平滑性的損失:
P表示經(jīng)過1個像素的膨脹后的孔洞區(qū)域。
平滑性損失total variation (TV) 表示為孔洞區(qū)域內(nèi)一個像素和該像素的右側(cè)像素和下面像素的L1 loss。總體來看衡量了2個孔洞區(qū)域(一個為原始孔洞區(qū)域,另一個為在水平方向右移一個像素的區(qū)域,或者在垂直方向下移一個像素的區(qū)域)在水平方向和垂直方向的差異。
Keras實現(xiàn):
def loss_tv(self, mask, y_comp):"""Total variation loss, used for smoothing the hole region, see. eq. 6"""# Create dilated hole region using a 3x3 kernel of all 1s.kernel = K.ones(shape=(3, 3, mask.shape[3], mask.shape[3]))dilated_mask = K.conv2d(1-mask, kernel, data_format='channels_last', padding='same')# Cast values to be [0., 1.], and compute dilated hole region of y_compdilated_mask = K.cast(K.greater(dilated_mask, 0), 'float32')P = dilated_mask * y_comp# Calculate total variation lossa = self.l1(P[:,1:,:,:], P[:,:-1,:,:])b = self.l1(P[:,:,1:,:], P[:,:,:-1,:]) return a+b總的loss:
每個loss前面的權(quán)重大小是在100個驗證圖片上使用參數(shù)搜索得到的。
?
實驗結(jié)果:
本文的Pconv方法優(yōu)于PM(PatchMatch),GL,GntIpt 等方法。
?
總結(jié):
總結(jié)
以上是生活随笔為你收集整理的基于部分卷积Pconv的图片修复的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 亚马逊电商数据自动化管理接口平台体系设计
- 下一篇: 用信息化为科研加速