python矩阵运算dot_矩阵、张量乘法(numpy.tensordot)的时间复杂度分析
兩個大小都是\(N \times N\)的矩陣相乘,如果使用naive的算法,時間復(fù)雜度應(yīng)該是\(\mathcal{O}(N^3)\),如果使用一些高級的算法,可以使冪指數(shù)降到3以下。對于一般情況的矩陣乘法,特別是張量乘法(numpy中的tensordot函數(shù)),時間復(fù)雜度又如何呢?
二維矩陣乘法
首先規(guī)定一下記號:\(\mathbf{A}_{MN}\),表示一個有兩個指標,大小是\(M\times N\)的矩陣\(\mathbf{A}\)。那么\(\mathbf{A}_{MN}\mathbf{B}_{NL}\)的時間復(fù)雜度是\(\mathcal{O}(MNL)\)。如果我們把乘法的過程用計算機語言表示出來,這一結(jié)論就會非常清晰:
1
2
3
4
5C = np.zeros((M, L))
for m in range(M):
for l in range(L):
for n in range(N):
C[m][l] += A[m][n] * B[n][l]
我們也可以簡單地驗證一下numpy.dot函數(shù)是否滿足這樣的時間復(fù)雜度,首先變化\(M\)。為了節(jié)省篇幅,一次將其擴大到四倍:
1
2
3
4
5
6
7
8M = 71
N = 513
L = 4097
for i in range(5):
m1 = np.random.random((M, N))
m2 = np.random.random((N, L))
%timeit m1.dot(m2)
M *= 4
輸出是:
1
2
3
4
5100loops, best of 3: 6.82 ms per loop
10loops, best of 3: 22.5 ms per loop
10loops, best of 3: 77.5 ms per loop
1loop, best of 3: 304 ms per loop
1loop, best of 3: 1.38 s per loop
可見基本是線性的(耗時一次擴大到四倍)。然后變化\(N\),代碼和上面的一段只變了一個字母,輸出是:
1
2
3
4
5100loops, best of 3: 6.79 ms per loop
10loops, best of 3: 22.1 ms per loop
10loops, best of 3: 84.4 ms per loop
1loop, best of 3: 329 ms per loop
1loop, best of 3: 1.31 s per loop
仍然基本是線性的。最后變化\(L\),輸出是:
1
2
3
4
5100loops, best of 3: 8.42 ms per loop
10loops, best of 3: 43.5 ms per loop
10loops, best of 3: 115 ms per loop
1loop, best of 3: 408 ms per loop
1loop, best of 3: 1.88 s per loop
耗時是三組實驗中最長的。結(jié)果匯總起來如下圖
不難發(fā)現(xiàn),時間與矩陣維度的關(guān)系是線性的且斜率為1,所以\(\mathbf{A}_{MN}\mathbf{B}_{NL}\)的時間復(fù)雜度是\(\mathcal{O}(MNL)\)。
高維矩陣(張量)乘法-只對一個軸求和
在numpy中dot,einsum,tensordot等函數(shù)都可以做高維矩陣乘法,這里只研究最常見的tensordot。我們從\(\mathbf{A}_{MNL}\mathbf{B}_{LPQ}\)這樣一個例子入手。從理論上分析,\(\mathbf{A}_{MNL}\mathbf{B}_{LPQ}\)的時間復(fù)雜度是\(\mathcal{O}(MNLPQ)\),感興趣的讀者可以自己寫寫代碼分析,或者看一看我之前寫的一篇博文。這里簡單做一下實驗,變化\(M\):
1
2
3
4
5
6
7
8
9
10M = 63
N = 17
L = 255
P = 127
Q = 31
for i in range(5):
m1 = np.random.random((M, N, L))
m2 = np.random.random((L, P, Q))
%timeit np.tensordot(m1, m2, 1)
M *= 4
輸出是:
1
2
3
4
510loops, best of 3: 47.6 ms per loop
1loop, best of 3: 166 ms per loop
1loop, best of 3: 700 ms per loop
1loop, best of 3: 2.7 s per loop
1loop, best of 3: 11.5 s per loop
而變化\(L\)輸出是:
1
2
3
4
510loops, best of 3: 46.3 ms per loop
10loops, best of 3: 116 ms per loop
1loop, best of 3: 368 ms per loop
1loop, best of 3: 1.52 s per loop
1loop, best of 3: 6 s per loop
如圖所示:
類似地,耗時與\(M\)和\(L\)都是線性關(guān)系,后者速度貌似比前者略快。
高維矩陣(張量)乘法-對多個軸求和
下面我們再考慮對多個軸求和的情況,這種情況下“數(shù)學(xué)語言”已經(jīng)不好給出清晰的描述了。如果想舉個例子,也只能啰嗦地說:\(\mathbf{A}_{MNL}\)和\(\mathbf{B}_{NLP}\)之間進行雙點積contract掉維數(shù)為\(N\)和\(L\)的兩個指標。倒是計算機語言還算游刃有余:
1
2
3
4
5
6C = np.zeros((M, P))
for m in range(M):
for p in range(P):
for n in range(N):
for l in range(L):
C[m][p] += A[m][n][l] * B[n][l][p]
也容易據(jù)此估計出時間復(fù)雜度為\(\mathcal{O}(MNLP)\)。實驗一下的話,首先試試\(M\):
1
2
3
4
5
6
7
8
9M = 63
N = 31
L = 255
P = 127
for i in range(5):
m1 = np.random.random((M, N, L))
m2 = np.random.random((N, L, P))
%timeit np.tensordot(m1, m2, 2)
M *= 4
輸出為:
1
2
3
4
5100loops, best of 3: 2.41 ms per loop
100loops, best of 3: 5.8 ms per loop
10loops, best of 3: 23.2 ms per loop
10loops, best of 3: 171 ms per loop
1loop, best of 3: 817 ms per loop
然后\(N\)和\(L\)分別為:
1
2
3
4
5100loops, best of 3: 2.43 ms per loop
100loops, best of 3: 8.69 ms per loop
10loops, best of 3: 33.7 ms per loop
10loops, best of 3: 138 ms per loop
1loop, best of 3: 560 ms per loop
和
1
2
3
4
5100loops, best of 3: 2.69 ms per loop
100loops, best of 3: 9.01 ms per loop
10loops, best of 3: 36.2 ms per loop
10loops, best of 3: 140 ms per loop
1loop, best of 3: 563 ms per loop
總結(jié)起來如圖所示:
結(jié)語
總結(jié)規(guī)律的話,要想知道矩陣、張量乘法的時間復(fù)雜度,就把兩個矩陣、張量所有沒contract掉的維度乘起來,再把contract掉的維度兩個取一個乘起來即可。舉個例子:\(\mathbf{A}_{MNL}\mathbf{B}_{LPQ}\),沒有contract掉的維度乘起來即\(NMPQ\),contract掉的維度有兩個\(L\),只取一個,最后合起來就是\(\mathcal{O}(MNLPQ)\)。
這一規(guī)律其實很好理解。np.tensordot在實現(xiàn)時實際上是對普通的np.dot的一個包裝,進行了一些前處理和后處理。所謂前處理,基本上就是通過轉(zhuǎn)置和合并(np.reshape)把兩個參與運算的高階張量分別變成矩陣,其中一個指標是原張量所有沒contract掉的指標組成的,維度自然就是這些指標的維度的積,而另一個指標是原張量要進行contract的指標組成的,維度也是這些指標的維度的積。而后處理,就是將np.dot之后的結(jié)果再通過np.reshape變回原來的形狀。np.tensordot的代碼位于numpy/core/numeric.py中,核心部分如下圖所示(NumPy 1.15):
1
2
3
4at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)
res = dot(at, bt)
return res.reshape(olda + oldb)
其中a和b是調(diào)用者傳入的要進行tensordot的矩陣,newaxes_a等參數(shù)是根據(jù)調(diào)用者指定的contract規(guī)則確定的用于將a或者b變形為適合進行np.dot的參數(shù)。得到變形后的at和bt后直接進行dot,再將中間結(jié)果reshape回去就得到了最終的結(jié)果。所以張量乘法的時間復(fù)雜度與矩陣乘法的時間復(fù)雜度其實是一回事。
總結(jié)
以上是生活随笔為你收集整理的python矩阵运算dot_矩阵、张量乘法(numpy.tensordot)的时间复杂度分析的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python打包的程序很大_Pyinst
- 下一篇: sap 获取计划订单bapi_sapba