PyTorch基础(15)-- torch.flatten()方法
生活随笔
收集整理的這篇文章主要介紹了
PyTorch基础(15)-- torch.flatten()方法
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
前言
最近在復現論文中一個塊的時候需要使用到torch.flatten()這個方法,這個方法其實很簡單,但其中有一些細節可能需要注意,且有個關鍵點很容易忘記,故在此記錄以備查閱。
方法解析
flatten的中文含義為“扁平化”,具體怎么理解呢?我們可以嘗試這么理解,假設你的數據為1維數據,那么這個數據天然就已經扁平化了,如果是2維數據,那么扁平化就是將2維數據變為1維數據,如果是3維數據,那么就要根據你自己所選擇的“扁平化程度”來進行操作,假設需要全部扁平化,那么就直接將3維數據變為1維數據,如果只需要部分扁平化,那么有一維的數據不會進行扁平操作,具體看下面的案例分析。
可以看到,torch.flatten()方法有三個參數,分別:
- input tensor:該方法的輸入
- start_dim:開始flatten的維度
- end_dim:結束flatten的維度
案例解析
- 導包
- 案例1 – 全部扁平化
- 案例2 – 部分扁平化
- 案例3 – 部分扁平化
總結
以上是生活随笔為你收集整理的PyTorch基础(15)-- torch.flatten()方法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 互联网晚报 | 10月23日 星期六 |
- 下一篇: 百天搞懂Java(一) - JDK环境配