PyTorch 实现 GradCAM
Grad-CAM 概述:給定圖像和感興趣的類別作為輸入,我們通過模型的 CNN 部分前向傳播圖像,然后通過特定于任務的計算獲得該類別的原始分數。 除了期望的類別(虎),所有類別的梯度都設置為零,該類別設置為 1。然后將該信號反向傳播到卷積特征圖,我們將其結合起來計算粗略的 Grad-CAM 定位( 藍色熱圖)它表示模型在做出特定決策時必須查看的位置。 最后,我們將熱圖與反向傳播逐點相乘,以獲得高分辨率和特定于概念的引導式 Grad-CAM 可視化。
在本文中,我們將學習如何在 PyTorch 中繪制 GradCam [1]。
為了獲得 GradCam 輸出,我們需要激活圖和這些激活圖的梯度。
讓我們直接跳到代碼中!!
引入相應的包
import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt import torch import torch.nn as nn from torchvision import models from skimage.io import imread from skimage.transform import resize我們將使用鉤子函數從所需的層和張量獲得激活映射和梯度。在本教程中,我們將從ResNet50的layer4中獲取激活映射,并對相同的輸出張量進行梯度。
class GradCamModel(nn.Module):def __init__(self):super().__init__()self.gradients = Noneself.tensorhook = []self.layerhook = []self.selected_out = None#PRETRAINED MODELself.pretrained = models.resnet50(pretrained=True)self.layerhook.append(self.pretrained.layer4.register_forward_hook(self.forward_hook()))for p in self.pretrained.parameters():p.requires_grad = Truedef activations_hook(self,grad):self.gradients = graddef get_act_grads(self):return self.gradientsdef forward_hook(self):def hook(module, inp, out):self.selected_out = outself.tensorhook.append(out.register_hook(self.activations_hook))return hookdef forward(self,x):out = self.pretrained(x)return out, self.selected_out我們向ResNet50模型的層添加一個前向鉤子。前向鉤子接受該層的輸入和該層的輸出作為參數。對于輸出張量,我們使用register_hook方法注冊一個鉤子。這個方法注冊一個向后掛鉤到一個張量,并且每次計算梯度時調用這個張量。它的輸入參數是相對于輸出張量的梯度。
以下是聲明模型實例
gcmodel = GradCamModel().to(‘cuda:0’)讀取圖片
計算類梯度激活映射
out, acts = gcmodel(inpimg) acts = acts.detach().cpu()loss = nn.CrossEntropyLoss()(out,torch.from_numpy(np.array([600])).to(‘cuda:0’)) loss.backward() grads = gcmodel.get_act_grads().detach().cpu() pooled_grads = torch.mean(grads, dim=[0,2,3]).detach().cpu() for i in range(acts.shape[1]):acts[:,i,:,:] += pooled_grads[i]heatmap_j = torch.mean(acts, dim = 1).squeeze() heatmap_j_max = heatmap_j.max(axis = 0)[0] heatmap_j /= heatmap_j_max現在,需要調整熱圖的大小和顏色。
調整大小
heatmap_j = resize(heatmap_j,(224,224),preserve_range=True)顏色映射
cmap = mpl.cm.get_cmap(‘jet’,256) heatmap_j2 = cmap(heatmap_j,alpha = 0.2)可視化
fig, axs = plt.subplots(1,1,figsize = (5,5)) axs.imshow((img*std+mean)[0].transpose(1,2,0)) axs.imshow(heatmap_j2) plt.show()結果如下
我們換一種更清晰的方式查看熱圖
heatmap_j3 = (heatmap_j > 0.75)可視化
fig, axs = plt.subplots(1,1,figsize = (5,5)) axs.imshow(((img*std+mean)[0].transpose(1,2,0))*heatmap_j3) plt.show()結果
最后我們移除剛才設置的鉤子
for h in gcmodel.layerhook:h.remove() for h in gcmodel.tensorhook:h.remove()引用
[1] R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh and D. Batra, “Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization,” 2017 IEEE International Conference on Computer Vision (ICCV), 2017, pp. 618–626, doi: 10.1109/ICCV.2017.74.本文作者: the owl
總結
以上是生活随笔為你收集整理的PyTorch 实现 GradCAM的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 营业执照在线生成_平罗县实现个体户营业执
- 下一篇: 华为手机android怎么解锁,怎么查看