python 笔记:爱因斯坦求和 einsum
1 einsum簡(jiǎn)介
????????使用愛(ài)因斯坦求和約定,可以以簡(jiǎn)單的方式表示許多常見(jiàn)的多維線性代數(shù)數(shù)組運(yùn)算。
????????給定兩個(gè)矩陣A和B,我們想對(duì)它們做一些操作,比如 multiply、sum或者transpose等。雖然numpy里面有可以直接使用的接口,能夠?qū)崿F(xiàn)這些功能,但是使用enisum可以做的更快、更節(jié)省空間。
????????舉例說(shuō)明,我們現(xiàn)在有兩個(gè)矩陣A和B。我們想計(jì)算A和B的哈達(dá)瑪乘積(即逐元素乘積),然后按行求和。
import numpy as np A = np.array([0, 1, 2]) B = np.array([[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]])? ? ? ? 如果我們不適用einsum的話,也不是不能計(jì)算,就是需要寫(xiě)幾步來(lái)完成:
A.reshape(-1,1) ''' array([[0],[1],[2]]) '''A.reshape(-1,1)*B ''' array([[ 0, 0, 0, 0],[ 4, 5, 6, 7],[16, 18, 20, 22]]) '''(A.reshape(-1,1)*B).sum(axis=1) #array([ 0, 22, 76])? ? ? ? 如果我們使用einsum的話,一行就可以實(shí)現(xiàn):
np.einsum('i,ij->i', A, B) #array([ 0, 22, 76])? 2 einsum原理
????????????????使用einsum的關(guān)鍵是,正確地labelling(標(biāo)記)輸入數(shù)組和輸出數(shù)組的axes(軸)。
????????????????我們可以使用字符串(比如:ijk,這種表示方式更常用)或者一個(gè)整數(shù)列表(比如:[0,1])來(lái)標(biāo)記axes。
? ? ? ? ? ? ? ? 比如,為了實(shí)現(xiàn)矩陣乘法,我們可以用einsum這么寫(xiě)(至于為什么這個(gè)是矩陣乘法,我們?cè)诤竺鏁?huì)說(shuō)明)
np.einsum('ij,jk->ik', A, B)????????????????字符串'ij,jk->ik'可以根據(jù)'->'的位置來(lái)切分,左邊的部分('ij,jk')標(biāo)記了輸入的axes,右邊的('ik')標(biāo)記了輸出的axes。
????????????????輸入標(biāo)記又根據(jù)','的位置進(jìn)行切分,'ij'標(biāo)記了第一個(gè)輸入A的axes,'jk'標(biāo)記了第二個(gè)輸入B的axes。
????????????????'ij'、'jk'的字符長(zhǎng)度都是2,對(duì)應(yīng)著A和B為2D數(shù)組,'ik'的長(zhǎng)度也為2,因此輸出也是2D數(shù)組。
? ? ? ? ? ? ? ? 給定輸入
A = np.array([[1, 3, 5],[7, 9, -7],[-5, -3, -1]]) B = np.array([[0, 2,4],[6, 8, 6],[4, 2, 0]])?????????np.einsum('ij,jk->ik', A, B)可以看作是:
- 在輸入數(shù)組的標(biāo)記之間,重復(fù)字母表示沿這些軸的值將相乘,這些乘積構(gòu)成輸出數(shù)組的值。比如圖中沿著j軸做乘積。
- 從輸出標(biāo)記中省略的字母表示沿該軸的值將被求和。比如圖中的輸出沒(méi)有包含j軸,因此沿著j軸求和得到了輸出數(shù)組中的每一項(xiàng)。
-
如果輸出的標(biāo)記是'ijk',那么會(huì)得到一個(gè) 3x3x3 的矩陣。?
-
?輸出標(biāo)記是'ik'的時(shí)候,并不會(huì)創(chuàng)建中間的 3x3x3 的矩陣,而是直接將總和累加到2D數(shù)組中。
A = np.array([[1, 3, 5],[7, 9, -7],[-5, -3, -1]]) B = np.array([[0, 2,4],[6, 8, 6],[4, 2, 0]]) np.einsum('ij,jk->ik', A, B) ''' array([[ 38, 36, 22],[ 26, 72, 82],[-22, -36, -38]]) '''如果輸出的標(biāo)記是空,那么輸出整個(gè)矩陣的和
A = np.array([[1, 3, 5],[7, 9, -7],[-5, -3, -1]]) B = np.array([[0, 2,4],[6, 8, 6],[4, 2, 0]]) np.einsum('ij,jk->', A, B) #180?我們可以按任意順序排序不求和的軸。
A = np.array([[1, 3, 5],[7, 9, -7],[-5, -3, -1]]) B = np.array([[0, 2,4],[6, 8, 6],[4, 2, 0]]) np.einsum('ij,jk->kji', A, B)''' array([[[ 0, 0, 0],[ 18, 54, -18],[ 20, -28, -4]],[[ 2, 14, -10],[ 24, 72, -24],[ 10, -14, -2]],[[ 4, 28, -20],[ 18, 54, -18],[ 0, 0, 0]]]) '''3 einsum分析
3.1?'ij,jk->ijk'? 與 'ij,jk->kji'
?我們一個(gè)一個(gè)分析一下
A = np.array([[1, 3, 5],[7, 9, -7],[-5, -3, -1]]) B = np.array([[0, 2,4],[6, 8, 6],[4, 2, 0]]) np.einsum('ij,jk->ijk', A, B)''' array([[[ 0, 2, 4],[ 18, 24, 18],[ 20, 10, 0]],[[ 0, 14, 28],[ 54, 72, 54],[-28, -14, 0]],[[ 0, -10, -20],[-18, -24, -18],[ -4, -2, 0]]]) '''首先,這幾個(gè)數(shù)字是怎么得到的?
| 0=1*0 | 2=1*2 | 4=1*4 |
| 18=3*6 | 24=3*8 | 18=3*6 |
| 20=5*4 | 10=5*2 | 0=5*0 |
| 0=7*0 | 14=7*2 | 28=7*4 |
| 54=9*6 | 72=9*8 | 54=9*6 |
| -28=-7*4 | -14=-7*2 | 0=-7*0 |
| 0=-5*0 | -10=-5*2 | -20=-5*4 |
| -18=-3*6 | -24=-3*8 | -18=-3*6 |
| -4=-1*4 | -2=-1*2 | 0=-1*0 |
轉(zhuǎn)換成坐標(biāo),有:
| [0,0]*[0,0] | [0,0]*[0,1] | [0,0]*[0,2] |
| [0,1]*[1,0] | [0,1]*[1,1] | [0,1]*[1,2] |
| [0,2]*[2,0] | [0,2]*[2,1] | [0,2]*[2,2] |
| [1,0]*[0,0] | [1,0]*[0,2] | [1,0]*[0,4] |
| [1,1]*[1,0] | [1,1]*[1,1] | [1,1]*[1,2] |
| [1,2]*[2,0] | [1,2]*[2,1] | [1,2]*[2,2] |
| [2,0]*[0,0] | [2,0]*[0,1] | [2,0]*[0,2] |
| [2,1]*[1,0] | [2,1]*[1,1] | [2,1]*[1,2] |
| [2,2]*[2,0] | [2,2]*[2,1] | [2,2]*[2,2] |
?與上面類似,我們就看第一個(gè)3*3的矩陣吧
| 0=1*0 | 0=7*0 | 0=-5*0 |
| 18=3*6 | 54=9*6 | -18=-3*6 |
| 20=5*4 | -28=-7*4 | -4=-1*4 |
| [0,0]*[0,0] | [1,0]*[0,0] | [2,0]*[0,0] |
| [0,1]*[1,0] | [1,1]*[1,0] | [2,1]*[1,0] |
| [0.2]*[2,0] | [1,2]*[2,0] | [2,2]*[2,0] |
可以這么考慮 對(duì)于 結(jié)果矩陣(比如ijk),第【i,j,k】元素的結(jié)果等于【i,j】乘以【j,k】
3.2?'ij,jk->ik'??
A = np.array([[1, 3, 5],[7, 9, -7],[-5, -3, -1]]) B = np.array([[0, 2,4],[6, 8, 6],[4, 2, 0]]) np.einsum('ij,jk->ik', A, B) ''' array([[ 38, 36, 22],[ 26, 72, 82],[-22, -36, -38]]) '''我們前面?'ij,jk->ijk'的結(jié)果是?
| 0=1*0 | 2=1*2 | 4=1*4 |
| 18=3*6 | 24=3*8 | 18=3*6 |
| 20=5*4 | 10=5*2 | 0=5*0 |
| 0=7*0 | 14=7*2 | 28=7*4 |
| 54=9*6 | 72=9*8 | 54=9*6 |
| -28=-7*4 | -14=-7*2 | 0=-7*0 |
| 0=-5*0 | -10=-5*2 | -20=-5*4 |
| -18=-3*6 | -24=-3*8 | -18=-3*6 |
| -4=-1*4 | -2=-1*2 | 0=-1*0 |
這邊相當(dāng)于
????????
| 38=0+18+20 | 36=2+24+10 | 22=4+18 |
| 26=54-28 | 72=14+72-14 | 82=54+28 |
| -22=-18-4 | -36=-10-24-2 | -38=-20-18 |
可以這么考慮 對(duì)于 結(jié)果矩陣(比如ik),第【i,k】元素的結(jié)果等于:對(duì)所有的j,【i,j】乘以【j,k】的結(jié)果的和
4 常用的Einsum
4.1 向量篇
| ('i',A) | 向量A的一個(gè)視圖 可以看成'i->i',即結(jié)果的第i位,是A的第i位 |
| ('i->', A) | sum(A) 可以看成'i->0',即結(jié)果的第0位,是A的第i位的和 |
| ('i,i->i', A,B) | 向量A,B對(duì)應(yīng)位置相乘 'i,i->i':結(jié)果的第i位,是A和B的第i位的積? |
| ('i,i->', A,B) | 向量A,B的內(nèi)積 ?'i,i->':可以看成'i,i->0' 結(jié)果的第0位,是A和B的第i位的積 再求和? |
| ('i,j->ij', A,B) | 向量A,B的外積 ?'i,j->ij':結(jié)果的第i行第j列,是A的第i個(gè)元素和B的第j個(gè)元素的乘積 |
?4.2 矩陣篇
| ('ij', A) | 返回矩陣A 看成'ij->ij' ,結(jié)果的第i行第j列,是A的第i行第j列 |
| ('ji->ij', A) | 返回矩陣A的轉(zhuǎn)置 結(jié)果的第i行第j列的元素是A的第j行第i列的元素? |
| ('ii->i', A) | 矩陣A的對(duì)角線元素 結(jié)果的第i個(gè)元素是A的第i行第i列的元素? |
| ('ij->', A) | 矩陣A的元素之和 可以看成'ij->0' 結(jié)果的第0位是A的第i行第j列的元素,再求和? |
| ('ij->j', A) | A縱向求和 結(jié)果的第j個(gè)元素是,對(duì)所有的i,A的第(i,j)個(gè)元素的和? |
| ('ij->i', A) | A橫向求和 結(jié)果的第i個(gè)元素是,對(duì)所有的j,A的第(i,j)個(gè)元素的和? |
| ('ij,ij->ij', A,B) | 矩陣A,B相應(yīng)位置的乘積 結(jié)果第i,j個(gè)元素,是A的第(i,j)個(gè)元素和B的第(i,j)個(gè)元素的乘積? |
| ('ij,ji->ij', A,B) | 矩陣A和? 矩陣B的轉(zhuǎn)置? ?相應(yīng)位置的乘積 結(jié)果第i,j個(gè)元素,是A的第(i,j)個(gè)元素和B的第(j,i)個(gè)元素的乘積? |
| ('ij,jk->ik', A,B) | A,B的矩陣乘積 結(jié)果的第(i,k)個(gè)元素等于A的第(i,j)個(gè)元素乘以B的第(j,k)個(gè)元素? |
| ('ij,kj->ik', A,B) | 矩陣A和矩陣B的內(nèi)積 |
| ('ij,kl->jikl', A,B) | A的每個(gè)元素乘以矩陣B 結(jié)果的 第(j,i,k,l)個(gè)元素是A的(i,j)和B的(k,l)的乘積? |
| ('dn,nd->',A,B) | 相當(dāng)于tr(AB) ? |
?5 顯示表明和隱式表明
我們將指定'->'和輸出標(biāo)記稱為 explicit mode。
如果不指定'->'和輸出標(biāo)記,numpy會(huì)將輸入標(biāo)記中只出現(xiàn)一次的標(biāo)記按照字母表順序,作為輸出標(biāo)記(也就是 implicit mode)。
'ij,jk->ik' 等價(jià)于 'ij,jk'
參考文章:einsum初探 - 知乎 (zhihu.com)
總結(jié)
以上是生活随笔為你收集整理的python 笔记:爱因斯坦求和 einsum的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 文巾解题 113. 路径总和 II
- 下一篇: pytorch 学习: STGCN