PyTorch-模型可视化工具TorchSummary
簡介
不同于TensorboardX對Tensorboard的支持以方便了PyTorch的訓練可視化,PyTorch并沒有很好的模型可視化工具,TorchSummary對此做出了補足,極大降低了模型可視化難度,也方便模型參數等數據的統計。本文介紹TorchSummary這個小工具的使用。
安裝
使用pip安裝即可。
pip install torchsummary開發緣由
首先,我們知道,PyTorch其實自帶模型可視化的功能,其基礎調用格式如下。
print(model) import torch import torch.nn as nn import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)例如,下述簡單模型通過print可視化,結果如下,顯然,這只是對模型含有的modules做了一個對象及其參數打印,我們更希望輸出每一層的layer類型、參數量以及輸出feature map尺寸等。
Net((conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))(conv2_drop): Dropout2d(p=0.5, inplace=False)(fc1): Linear(in_features=320, out_features=50, bias=True)(fc2): Linear(in_features=50, out_features=10, bias=True) )對此,TorchSummary提供了更詳細的信息分析,包括模塊信息(每一層的類型、輸出shape和參數量)、模型整體的參數量、模型大小、一次前向或者反向傳播需要的內存大小等。
使用教程
TorchSummary的使用基于下述核心API,只要提供給summary函數模型以及輸入的size就可以了。
from torchsummary import summary summary(model, input_size=(channels, H, W))如在一個簡單CNN上進行模型可視化,代碼和結果如下(測試均使用PyTorch1.6.0),可視化輸出包括我上一節文末提到的我們需要的常用信息,非常豐富。
import torch import torch.nn as nn import torch.nn.functional as F from torchsummary import summaryclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Net().to(device)summary(model, (1, 28, 28)) ----------------------------------------------------------------Layer (type) Output Shape Param # ================================================================Conv2d-1 [-1, 10, 24, 24] 260Conv2d-2 [-1, 20, 8, 8] 5,020Dropout2d-3 [-1, 20, 8, 8] 0Linear-4 [-1, 50] 16,050Linear-5 [-1, 10] 510 ================================================================ Total params: 21,840 Trainable params: 21,840 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.00 Forward/backward pass size (MB): 0.06 Params size (MB): 0.08 Estimated Total Size (MB): 0.15 ----------------------------------------------------------------對于多輸入的情況,只要傳入的input_size改為一個安裝輸入所需size組成的列表就行,示例如下。
import torch import torch.nn as nn from torchsummary import summaryclass SimpleConv(nn.Module):def __init__(self):super(SimpleConv, self).__init__()self.features = nn.Sequential(nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1),nn.ReLU(),)def forward(self, x, y):x1 = self.features(x)x2 = self.features(y)return x1, x2device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleConv().to(device)summary(model, input_size=[(1, 16, 16), (1, 28, 28)])結果如下。
----------------------------------------------------------------Layer (type) Output Shape Param # ================================================================Conv2d-1 [-1, 1, 16, 16] 10ReLU-2 [-1, 1, 16, 16] 0Conv2d-3 [-1, 1, 28, 28] 10ReLU-4 [-1, 1, 28, 28] 0 ================================================================ Total params: 20 Trainable params: 20 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.77 Forward/backward pass size (MB): 0.02 Params size (MB): 0.00 Estimated Total Size (MB): 0.78 ----------------------------------------------------------------補充說明
本文簡單介紹了我比較喜歡的PyTorch模型可視化工具,文中示例代碼參考官網,如果對你有所幫助,麻煩點贊支持一下。
總結
以上是生活随笔為你收集整理的PyTorch-模型可视化工具TorchSummary的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: FairMOT实时多目标跟踪
- 下一篇: SFTP连接问题