如何对batch的数据求Gram矩阵
Gram矩陣概念和理解
在風格遷移中,我們要比較生成圖片和風格圖片的相似性,評判標準就是通過計算Gram矩陣得到的。關于Gram矩陣的定義,可以參考[1]。
由這個矩陣的樣子,很容易就想到協(xié)方差矩陣。如果協(xié)方差矩陣是什么忘了的化可以參考[2],可以看到Gram矩陣是沒有減去均值的協(xié)方差矩陣。協(xié)方差矩陣是一種相關性度量的矩陣,通過協(xié)方差來度量相關性,也就是度量兩個圖片風格的相似性。(如果相對協(xié)方差和相關系數(shù)有進一步了解,可以參考[3])
如何通過代碼實現(xiàn)Gram矩陣計算
了解Gram矩陣的概念和性質 ,我們就來看一看如何用代碼來實現(xiàn)Gram矩陣的計算。這里,使用PyTorch來實現(xiàn)計算過程。
PyTorch中有兩個函數(shù)torch.mm和torch.bmm前者是計算矩陣乘法,后者是計算batch數(shù)據的矩陣乘法,風格遷移中是對batch數(shù)據進行操作,所以使用bmm。
我們創(chuàng)造一個batch為2,單通道,2*2大小的數(shù)據
a = torch.arange(8, dtype=torch.int).reshape(2, 1, 2, 2) a >>> tensor([[[[0, 1],[2, 3]]],[[[4, 5],[6, 7]]]], dtype=torch.int32)之后從新reshape一下,將w和h通道的數(shù)據合起來,變成向量形式
features = a.view(2, 1, 4) features >>> tensor([[[0, 1, 2, 3]],[[4, 5, 6, 7]]], dtype=torch.int32)為了構造計算Gram矩陣的向量,對shape進行一個交換操作
features_t = features.transpose(1, 2) features_t >>> tensor([[[0],[1],[2],[3]],[[4],[5],[6],[7]]], dtype=torch.int32)之后用矩陣乘法把這兩個向量乘起來就可以了,就計算出Gram矩陣了。
gram = features.bmm(features_t) gram >>> tensor([[[ 14]],[[126]]], dtype=torch.int32)Reference
[1]Gram格拉姆矩陣在風格遷移中的應用
[2]如何直觀地理解「協(xié)方差矩陣」
[3]如何通俗易懂地解釋「協(xié)方差」與「相關系數(shù)」的概念?
總結
以上是生活随笔為你收集整理的如何对batch的数据求Gram矩阵的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 抽象基类和纯虚函数
- 下一篇: 训练生成对抗网络的过程中,训练gan的地