pytorch中的squeeze和unsqueeze
生活随笔
收集整理的這篇文章主要介紹了
pytorch中的squeeze和unsqueeze
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
squeeze:壓縮,要減少維度。
unsqueeze:解壓縮,要增加維度。
torch.squeeze(input),那么會把input中所有維度長度為1的維度去掉。
torch.squeeze(input,dim=1),那么在給定dim的情況下,就只去掉dim這個維度,其他維度還保留。
tensor([[0.0621, 0.2074, 0.5420],
[0.5897, 0.3664, 0.4387],
[0.0115, 0.3464, 0.0702],
[0.7800, 0.4727, 0.1952],
[0.6879, 0.8595, 0.3933]])
這時候x的形狀還是5行3列。因為沒有哪個維度的長度為1。
那么x的形狀是(5,1,3),有5個塊,每個塊是1行3列。
對于unsquueze來講,維度可以比原有維度高1。例如最開始x的形狀是(5,3)。可以如下操作。
import torch x = torch.rand(5,3) x = x.unsqueeze(2) tensor([[[0.3757],[0.8054],[0.0250]],[[0.9423],[0.5109],[0.2437]],[[0.6276],[0.4251],[0.3276]],[[0.6699],[0.0768],[0.3541]],[[0.6123],[0.0268],[0.4193]]])那么得到的tensor形狀是(5,3,1)。
還是看你想要什么樣的形狀。
總結
以上是生活随笔為你收集整理的pytorch中的squeeze和unsqueeze的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: maven-settings.xml的那
- 下一篇: [密码学][困难问题][常见规约]密码学