Error:output with shape [1, 224, 224] doesn‘t match the broadcast shape [3, 224, 224]
生活随笔
收集整理的這篇文章主要介紹了
Error:output with shape [1, 224, 224] doesn‘t match the broadcast shape [3, 224, 224]
小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
Error:output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
原模型輸入的圖片為RGB三通道,輸入的為單通道灰度圖片。
解決如下:
from torch import nn from torchvision import datasets from torchvision import transforms as T from torch.utils.data import DataLoader from torchvision.utils import make_grid, save_image import numpy as np import matplotlib.pyplot as plttransform = T.Compose([T.ToTensor(), #這會(huì)將介于0到255之間的numpy數(shù)組轉(zhuǎn)換為介于0到1之間的浮點(diǎn)張量T.Normalize((0.5, ), (0.5, )), #在normalize()方法中, 我們指定了用來(lái)標(biāo)準(zhǔn)化張量圖像所有通道的均值, 并且還指定了中心偏差。 ]) dataset = datasets.MNIST('data/', download=True, train=False, transform=transform) dataloader = DataLoader(dataset, shuffle=True, batch_size=100)print(type(dataset[0][0]),dataset[0][0].size()) # print(dataset[0][0]) # 要繪制張量圖像, 我們必須將其更改回numpy array。 # 我們將在函數(shù)def im_convert()中完成此工作, 該函數(shù)包含一個(gè)將成為張量圖像的參數(shù)。 def im_convert(tensor):image=tensor.clone().detach().numpy()# 使用torch.clone()獲得的新tensor和原來(lái)的數(shù)據(jù)不再共享內(nèi)存,但仍保留在計(jì)算圖中,# clone操作在不共享數(shù)據(jù)內(nèi)存的同時(shí)支持梯度梯度傳遞與疊加,所以常用在神經(jīng)網(wǎng)絡(luò)中某個(gè)單元需要重復(fù)使用的場(chǎng)景下。# 通常如果原tensor的requires_grad=True,則:# clone()操作后的tensor requires_grad=True# detach()操作后的tensor requires_grad=False。image=image.transpose(1, 2, 0)# 將轉(zhuǎn)換為numpy數(shù)組的張量具有第一, 第二和第三維的形狀。第一維表示顏色通道, 第二維和第三維表示圖像和像素的寬度和高度。# 我們知道MNIST數(shù)據(jù)集中的每個(gè)圖像都是對(duì)應(yīng)于單個(gè)彩色通道的灰度, 其寬度和高度為28 * 28像素。因此, 形狀將為(1、28、28)。# 為了繪制圖像, 要求圖像的形狀為(28, 28, 1)。因此, 通過(guò)將軸零, 一和二交換print(image.shape)image=image*(np.array((0.5, 0.5, 0.5))+np.array((0.5, 0.5, 0.5)))print(image.shape)# 我們對(duì)圖像進(jìn)行歸一化, 而之前我們必須對(duì)其進(jìn)行歸一化。通過(guò)減去平均值并除以標(biāo)準(zhǔn)偏差來(lái)完成歸一化。# 我們將乘以標(biāo)準(zhǔn)偏差, 然后將平均值相加image=image.clip(0, 1)print(image.shape,type(image))return image# 為了確保介于0和1之間的范圍, 我們使用了clip()# 函數(shù)并傳遞了零和一作為參數(shù)。我們將clip函數(shù)應(yīng)用到最小值0和最大值1并返回圖像。# 它將創(chuàng)建一個(gè)對(duì)象, 該對(duì)象使我們可以一次通過(guò)一個(gè)可變的訓(xùn)練加載器。 # 我們通過(guò)在dataiter上調(diào)用next來(lái)一次訪問(wèn)一個(gè)元素。 # next()函數(shù)將獲取我們的第一批訓(xùn)練數(shù)據(jù), 并且該訓(xùn)練數(shù)據(jù)將被分為以下圖像和標(biāo)簽 dataiter=iter(dataloader) images, labels=dataiter.next()fig=plt.figure(figsize=(25, 6)) #fig=plt.figure(figsize=(25, 4)) #圖片輸出寬度較上面小 for idx in np.arange(20):ax=fig.add_subplot(2, 10, idx+1)plt.imshow(im_convert(images[idx]))ax.set_title([labels[idx].item()]) plt.show()最終結(jié)果如下:
總結(jié)
以上是生活随笔為你收集整理的Error:output with shape [1, 224, 224] doesn‘t match the broadcast shape [3, 224, 224]的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: Netgear Readyshare:U
- 下一篇: 《完美应用Ubuntu》第3版 何晓龙