【深度学习】图像去模糊算法代码实践!
作者:陳信達,上海科技大學,Datawhale成員
1.起源:GAN
結構與原理
在介紹DeblurGANv2之前,我們需要大概了解一下GAN,GAN最初的應用是圖片生成,即根據訓練集生成圖片,如生成手寫數字圖像、人臉圖像、動物圖像等等,其主要結構如下:
我們先由上圖的左下方開始,假設現在只有一個樣本,即batch size為1,則Random noise是一個由服從標準正態分布的隨機數組成的向量。首先,我們將Random noise輸入Generator,最原始GAN的Generator是一個多層感知機,其輸入是一個向量,輸出也是一個向量,然后我們將輸出的向量reshape成一個矩陣,這個矩陣就是一張圖片(一個矩陣是因為MNIST手寫數據集中的圖片是單通道的灰度圖,如果想生成彩色圖像就reshape成三個矩陣),即與上圖的“8”對應。我們稱Generator生成的圖像為fake image,訓練集中的圖片為real image。
上圖中的Distriminator為判別器,它是一個二分類的多層感知機,輸出只有一個數,由于多層感知機只接受向量為其輸入,我們將一張圖片由矩陣展開為向量后再輸入Discriminator,經過一系列運算后輸出一個0~1之間的數,這個數越接近于0,代表著判別器認為這張圖片是fake image;反之,假如輸出的數越接近于1,則判別器認為這張圖片是real image。為了方便,我們將Generator簡稱為G,Distriminator簡稱為D。
總而言之,G的目的是讓自己生成的fake image盡可能欺騙D,而D的任務是盡可能辨別出fake image和real image,二者不停博弈。最終理想情況下,G生成的數據與真實數據非常接近,而D無論輸入fake image還是real image都輸出0.5。
損失函數
GAN的損失函數是Binary cross entropy loss,簡稱為BCELoss,其主要利用了極大似然的思想,實際上就是二分類對應的交叉熵損失函數。公式如下:
其中是樣本數,是第個樣本的真實值,是第個樣本的預測值。對于第個樣本來說,由于取值只能是0或1,此時只看第個樣本,所以。當時,,而的取值范圍為0~1,故當時,=0,當時,,我們的目標是使的值越小越好,即當越接近0時,的值越小。反之,當時,,越接近1時,的值越小。總之,當越接近于時,的值越小。
那么BCELoss和GAN有什么關系呢?
我們將GAN的Loss分為和,即生成器的損失和判別器的損失。
對于生成器來說,它希望自己生成的圖片能騙過判別器,即希望D(fake)越接近1越好,D(fake)就是G生成的圖片輸入D后的輸出值,D(fake)接近于1意味著G生成的圖片可以以假亂真來欺騙判別器,所以GLoss的公式如下所示:
當越接近1,越小,意味著生成器騙過了判別器;
對于判別器來說,它的損失分為兩部分,首先,它不希望自己被fake image欺騙,即與相反,這里用表示:
當越接近0,越小,意味著判別器分辨出了fake image;
其次,判別器做出判斷必須有依據,所以它需要知道真實圖片是什么樣的才能正確地辨別假圖片,這里用表示:
當越接近1,越小,意味著判別器辨別出了real image。
其實就是這兩個損失值的平均值:
優化器
介紹完GAN的損失函數后,我們還剩下最后一個問題:怎么使損失函數的值越來越小?
這里就需要說一下優化器(Optimizer),優化器就是使損失函數值越來越小的工具,常用的優化器有SGD、NAG、RMSProp、Adagrad、Adam和Adam的一些變種,其中最常用的是Adam。
最終結果
由上圖我們可以清楚地看出來,隨著訓練輪數增加,G生成的fake image越來越接近手寫數字。
目前GAN有很多應用,每個應用對應的論文和Pytorch代碼可以參考下面的鏈接,其中也有GAN的代碼,大家可以根據代碼進一步理解GAN:https://github.com/eriklindernoren/PyTorch-GAN
2.圖像去模糊算法:DeblurGANv2
數據集
圖像去模糊的數據集通常由許多組圖像組成,每組圖像就是一張清晰圖像和與之對應的模糊圖像。然而,其數據集的制作并不容易,目前常用的方法有兩種,第一種是用高幀數的攝像機拍攝視頻,從視頻中找到連續幀中的模糊圖片和清晰圖片作為一組數據;第二種方法是用已知或隨機生成的運動模糊核對清晰圖片進行模糊操作,生成對應的一組數據。albumentations是Python中常用的數據擴增庫,可以對圖片進行旋轉、縮放、裁剪等操作,我們也可以使用albumentations給圖像增加運動模糊,具體操作如下:
首先安裝albumentations庫,在cmd或虛擬環境中輸入:
python?-m?pip?install?albumentations為了給圖像添加運動模糊,我們需要用matplotlib庫來讀取、顯示和保存圖片。
import?albumentations?as?A from?matplotlib?import?pyplot?as?plt#?讀取和顯示原圖 img?=?plt.imread('./images/ywxd.jpg') plt.imshow(img) plt.axis('off') plt.show()albumentations添加運動模糊操作如下,其中blur_limit是卷積核大小的范圍,這里卷積核大小在150到180之間,卷積核越大,模糊效果越明顯;p是進行運動模糊操作概率。
如果想查看對應的模糊核,我們可以對aug這個實例調用get_params方法,這里為了大家觀看方便,我使用的是3*3的卷積核。
我使用的數據集是DeblurGANv1的數據集,鏈接:https://gas.graviti.cn/dataset/datawhale/BlurredSharp
模糊圖片:
清晰圖片:
網絡結構
DeblurGANv2的思路與GAN大致相同,區別之處在于其對GAN做了大量優化,我們先來看Generator的結構:
觀察上圖可以發現,G主要有兩個改變:
輸入用模糊的圖片替代了GAN中的隨機向量
網絡結構引入了目標檢測中的FPN結構,融合了多尺度的特征
另外,在特征提取部分作者提供了三種網絡主干:MobileNetv2、inceptionresnetv2和densenet121,經過作者實驗得出,inceptionresnetv2的效果最好,但模型較大,而MobilNetv2在不降低太大效果的基礎上大大減少了網絡參數,網絡主干在上圖中對應部分如下所示:
最后,將fpn的輸出與原圖進行按元素相加操作得到最終輸出。
DeblurGANv2的判別器由全局和局部兩部分組成,全局判別器輸入的是整張圖片,局部判別器輸入的是隨機裁剪后的圖片,將輸入圖片經過一系列卷積操作后輸出一個數,這個數代表判別器認為其為real image的概率,判別器的結構如下所示:
損失函數
DeblurGANv2與GAN差別最大的部分就是它的損失函數,我們首先看看D的loss:
D的目的是為了辨別圖片的真假,所以D(fake)越小,D(real)越大時,代表D能很好地判斷圖片的真假,故對于D來說,越小越好
為了防止過擬合,后面還會加上一個L2懲罰項:
G的loss較D復雜很多,它由和組成,其實就是一個perceptual loss,它其實就是將real image和fake image分別輸入vgg19,將輸出的特征圖做MSELoss(均方誤差),而作者在perceptual loss的基礎上又做了一些改變,公式可以總結為下式:
由公式可以很容易推斷,的作用就是讓G生成的圖片和原圖盡可能相似來達到去模糊的目的。
對于來說,其可以總結為下面公式:
由于G的目的是盡可能以假亂真騙過D,所以和越接近于1越好,即越小越好。
最后,G的loss如下所示:
作者給出的lambda為0.001,可以看出作者更注重生成圖像與原圖的相似性。
3.代碼實踐
訓練自己的數據集
(目前僅支持gpu訓練!)
github項目地址:https://github.com/VITA-Group/DeblurGANv2
數據地址:https://gas.graviti.cn/dataset/datawhale/BlurredSharp
首先將數據文件夾和項目文件夾按照下面結構放置:
安裝python環境,在cmd中輸入:
conda?create?-n?deblur?python=3.9 conda?activate?deblur python?-m?pip?install?-r?requirements.txt修改config文件夾中的配置文件config.yaml:
project:?deblur_gan experiment_desc:?fpntrain:files_a:?&FILES_A?./dataset/train/blurred/*.png??files_b:?&FILES_B?./dataset/train/sharp/*.png??size:?&SIZE?256?crop:?random??preload:?&PRELOAD?falsepreload_size:?&PRELOAD_SIZE?0bounds:?[0,?.9]scope:?geometriccorrupt:?&CORRUPT-?name:?cutoutprob:?0.5num_holes:?3max_h_size:?25max_w_size:?25-?name:?jpegquality_lower:?70quality_upper:?90-?name:?motion_blur-?name:?median_blur-?name:?gamma-?name:?rgb_shift-?name:?hsv_shift-?name:?sharpenval:files_a:?&FILE_A?./dataset/val/blurred/*.pngfiles_b:?&FILE_B?./dataset/val/sharp/*.pngsize:?*SIZEscope:?geometriccrop:?centerpreload:?*PRELOADpreload_size:?*PRELOAD_SIZEbounds:?[.9,?1]corrupt:?*CORRUPTphase:?train warmup_num:?3 model:g_name:?resnetblocks:?9d_name:?double_gan?#?may?be?no_gan,?patch_gan,?double_gan,?multi_scaled_layers:?3content_loss:?perceptualadv_lambda:?0.001disc_loss:?wgan-gplearn_residual:?Truenorm_layer:?instancedropout:?Truenum_epochs:?200 train_batches_per_epoch:?1000 val_batches_per_epoch:?100 batch_size:?1 image_size:?[256,?256]optimizer:name:?adamlr:?0.0001 scheduler:name:?linearstart_epoch:?50min_lr:?0.0000001如果是windows系統需要刪除train.py第180行
然后在cmd中cd到項目路徑并輸入:
python?train.py訓練結果可以在tensorboard中可視化出來:
驗證集ssim(結構相似性):
驗證集GLoss:
驗證集PSNR(峰值信噪比):
測試(CPU、GPU均可)
GPU
將測試圖片以test.png保存到DeblurGANv2-master文件夾下,在CMD中輸入:
python?predict.py?test.png運行成功后結果submit文件夾中,predict.py中的模型文件默認為best_fpn.h5,大家也可以在DeblurGANv2的github中下載作者訓練好的模型文件,保存在項目文件夾后將predict.py文件中的第93行改為想要用的模型文件即可,如將'best_fpn.h5'改為'fpn_inception.h5',但是需要將config.yaml中model對應的g_name改為相應模型,如想使用'fpn_mobilenet.h5',就將'fpn_inception'改為'fpn_mobilenet'
CPU
將predict.py文件中第21行、22和65行改為下面代碼即可
model.load_state_dict(torch.load(weights_path,?map_location=torch.device('cpu'))['model']) self.model?=?model inputs?=?[img]運行后就可以得到下面效果:
DeblurGAN的應用:優化YOLOv5性能
由上圖可以看出,圖片去模糊不僅可以提高YOLOv5的檢測置信度,還可以使檢測更準確。以Mobilenetv2為backbone的DeblurGANv2能達到圖片實時去模糊的要求,進而可以使用到視頻質量增強等方向。
線上訓練
如果我們不想把數據集下載到本地的話可以考慮格物鈦(Graviti)的線上訓練功能,在原項目的基礎上改幾行代碼即可。
首先我們打開項目文件夾中的dataset.py文件,在第一行導入tensorbay和PIL(如果沒有安裝tensorbay需要先pip install):
from?tensorbay?import?GAS from?tensorbay.dataset?import?Dataset?as?TensorBayDataset from?PIL?import?Image我們主要修改的是PairedDatasetOnline類還有_read_img函數,為了保留原來的類,我們新建一個類,將下面代碼復制粘貼到dataset.py文件中即可(記得將ACCESS_KEY改為自己空間的 Graviti AccessKey):
class?PairedDatasetOnline(Dataset):def?__init__(self,files_a:?Tuple[str],files_b:?Tuple[str],transform_fn:?Callable,normalize_fn:?Callable,corrupt_fn:?Optional[Callable]?=?None,preload:?bool?=?True,preload_size:?Optional[int]?=?0,verbose=True):assert?len(files_a)?==?len(files_b)self.preload?=?preloadself.data_a?=?files_aself.data_b?=?files_bself.verbose?=?verboseself.corrupt_fn?=?corrupt_fnself.transform_fn?=?transform_fnself.normalize_fn?=?normalize_fnlogger.info(f'Dataset?has?been?created?with?{len(self.data_a)}?samples')if?preload:preload_fn?=?partial(self._bulk_preload,?preload_size=preload_size)if?files_a?==?files_b:self.data_a?=?self.data_b?=?preload_fn(self.data_a)else:self.data_a,?self.data_b?=?map(preload_fn,?(self.data_a,?self.data_b))self.preload?=?Truedef?_bulk_preload(self,?data:?Iterable[str],?preload_size:?int):jobs?=?[delayed(self._preload)(x,?preload_size=preload_size)?for?x?in?data]jobs?=?tqdm(jobs,?desc='preloading?images',?disable=not?self.verbose)return?Parallel(n_jobs=cpu_count(),?backend='threading')(jobs)@staticmethoddef?_preload(x:?str,?preload_size:?int):img?=?_read_img(x)if?preload_size:h,?w,?*_?=?img.shapeh_scale?=?preload_size?/?hw_scale?=?preload_size?/?wscale?=?max(h_scale,?w_scale)img?=?cv2.resize(img,?fx=scale,?fy=scale,?dsize=None)assert?min(img.shape[:2])?>=?preload_size,?f'weird?img?shape:?{img.shape}'return?imgdef?_preprocess(self,?img,?res):def?transpose(x):return?np.transpose(x,?(2,?0,?1))return?map(transpose,?self.normalize_fn(img,?res))def?__len__(self):return?len(self.data_a)def?__getitem__(self,?idx):a,?b?=?self.data_a[idx],?self.data_b[idx]if?not?self.preload:a,?b?=?map(_read_img,?(a,?b))a,?b?=?self.transform_fn(a,?b)if?self.corrupt_fn?is?not?None:a?=?self.corrupt_fn(a)a,?b?=?self._preprocess(a,?b)return?{'a':?a,?'b':?b}@staticmethoddef?from_config(config):config?=?deepcopy(config)#?files_a,?files_b?=?map(lambda?x:?sorted(glob(config[x],?recursive=True)),?('files_a',?'files_b'))segment_name?=?'train'?if?'train'?in?config['files_a']?else?'val'ACCESS_KEY?=?"yours"gas?=?GAS(ACCESS_KEY)dataset?=?TensorBayDataset("BlurredSharp",?gas)segment?=?dataset[segment_name]files_a?=?[i?for?i?in?segment?if?'blurred'?==?i.path.split('/')[2]]files_b?=?[i?for?i?in?segment?if?'sharp'?==?i.path.split('/')[2]]transform_fn?=?aug.get_transforms(size=config['size'],?scope=config['scope'],?crop=config['crop'])normalize_fn?=?aug.get_normalize()corrupt_fn?=?aug.get_corrupt_function(config['corrupt'])#?ToDo:?add?more?hash?functionsverbose?=?config.get('verbose',?True)return?PairedDatasetOnline(files_a=files_a,files_b=files_b,preload=config['preload'],preload_size=config['preload_size'],corrupt_fn=corrupt_fn,normalize_fn=normalize_fn,transform_fn=transform_fn,verbose=verbose)再將_read_img改為:
def?_read_img(x):with?x.open()?as?fp:img?=?cv2.cvtColor(np.asarray(Image.open(fp)),?cv2.COLOR_RGB2BGR)if?img?is?None:logger.warning(f'Can?not?read?image?{x}?with?OpenCV,?switching?to?scikit-image')img?=?imread(x)[:,?:,?::-1]return?img最后一步將train.py第184行的datasets = map(PairedDataset.from_config, datasets)改為datasets = map(PairedDatasetOnline.from_config, datasets)即可。
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載黃海廣老師《機器學習課程》視頻課黃海廣老師《機器學習課程》711頁完整版課件本站qq群554839127,加入微信群請掃碼:
↓↓↓“閱讀原文”獲取數據集
總結
以上是生活随笔為你收集整理的【深度学习】图像去模糊算法代码实践!的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 如何进行系统还原
- 下一篇: 【深度学习】一文概览神经网络模型