GEMM算法及优化流程详解
目錄
前言
im2col+GEMM算法簡介
GEMM算法優(yōu)化
optimize1
optimize2
optimize3
前言
神經(jīng)網(wǎng)絡(luò)前向耗時主要由卷積的耗時決定,參考賈楊青畢業(yè)論文,那么如何對卷積加速便成了重要的一個點,主流的加速方法有
以下幾種:
im2col+GEMM:目前幾乎所有的主流計算框架包括 Caffe, MXNet 等都實現(xiàn)了該方法. 該方法把整個卷積過程轉(zhuǎn)化成了GEMM過程,而GEMM在各種 BLAS 庫中都是被極致優(yōu)化的,一般來說,速度較快。
Winograd: Winograd 是存在已久最近被重新發(fā)現(xiàn)的方法,在大部分場景中, Winograd方法都顯示和較大的優(yōu)勢,目前cudnn中計算卷積就使用了該方法。
Strassen:1969年,Volker Strassen提出了第一個時間復(fù)雜度低于O(N^3)的算法,其復(fù)雜度為O(N^(2^(log2(7)))),但這種方法只在大卷積核情況下優(yōu)勢才比較明顯,目前還沒有在開源框架中見到這種方法。
FFT:傅里葉變換和快速傅里葉變化是在經(jīng)典圖像處理里面經(jīng)常使用的計算方法,但是,在 ConvNet中通常不采用,主要是因為在 ConvNet 中的卷積模板通常都比較小,例如?3×3?等,這種情況下,FFT 的時間開銷反而更大,所以很少在CNN中利用FFT實現(xiàn)卷積。
很高興你看完前言:最近發(fā)現(xiàn)這篇文章寫的很好,阿里那邊的,《支付寶如何優(yōu)化移動端深度學(xué)習(xí)引擎》推薦給大家~
?
im2col+GEMM算法簡介
GEMM在深度學(xué)習(xí)中是十分重要的,全連接層以及卷積層基本上都是通過GEMM來實現(xiàn)的,而網(wǎng)絡(luò)中大約90%的運算都是在這兩層中。而一個良好的GEMM的實現(xiàn)可以充分利用系統(tǒng)的多級存儲結(jié)構(gòu)和程序執(zhí)行的局部性來充分加速運算。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?常規(guī)的卷積操作為:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ???? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 3維卷積運算執(zhí)行完畢,得一個2維的平面:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
將卷積操作的3維立體變?yōu)槎S矩陣乘法,可以調(diào)用BLAS中的GEMM庫,按 [kernel_height, kernel_width, kernel_depth] ? 將輸入分成 3 維的 patch,并將其展成一維向量:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
此時的卷積操作就可轉(zhuǎn)化為矩陣乘法:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
下面我們將以M=K=N=600為例說明GEMM算法的優(yōu)化過程:
?
直接暴力卷積:
for (int m = 0; m < M; m++) {for (int n = 0; n < N; n++) {for (int k = 0; k < K; k++) {C[m][n]+= A[m][k] * B[k][n];}} }上述公式總計算量為2MNK FLOPs(其中 𝑀、𝑁、𝐾 分別指代三層循環(huán)執(zhí)行的次數(shù),2 指代循環(huán)最內(nèi)層的一次乘法和加法) ,內(nèi)存訪問操作總數(shù)為 4MNK(其中 2MNK 指代對 𝐶 的內(nèi)存訪問,𝐶 需要先讀取內(nèi)存、累和再存儲)。GEMM 的優(yōu)化均以此為基點。
耗時分析:上述暴力gemm代碼耗時約為872ms
?
GEMM算法優(yōu)化
optimize1
首先能想到的就是減少C矩陣的訪存次數(shù),將C[m][n]放到外面,全部累和之后再賦值即可:
for (int m = 0; m < M; m++) {for (int n = 0; n < N; n++) {float temp = C[m][n];for (int k = 0; k < K; k++) {temp += A[m][k] * B[k][n];}C[m][n] = temp;} }上述公式總計算量依然為2MNK FLOPs,內(nèi)存訪問操作總數(shù)為 2MNK+2MN(其中 2MN?指代對 𝐶 的內(nèi)存訪問,𝐶 需要先讀取內(nèi)存、累加完畢在存儲)。
耗時分析:上述代碼耗時約為791ms,耗時變少的原因是減少了部分C的訪存
?
optimize2
將輸出的計算拆分為 1×4 的小塊,即將 𝑁 維度拆分為兩部分。計算該塊輸出時,需要使用 𝐴 矩陣的 1 行,和 𝐵 矩陣的 4 列。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?圖一:矩陣乘計算?1×4輸出
下面是該計算的偽代碼表示,這里已經(jīng)將 1×4 中 N 維度的內(nèi)部拆分進行了展開。這里的計算量仍然是 2𝑀𝑁𝐾 ,這一點在本文中不會有變化。
for (int m = 0; m < M; m++) {for (int n = 0; n < N; n += 4) {float temp_m0n0 = C[m][n + 0];float temp_m0n1 = C[m][n + 1];float temp_m0n2 = C[m][n + 2];float temp_m0n3 = C[m][n + 3];for (int k = 0; k < K; k++) {float temp = A[m][k];temp_m0n0 += temp * B[k][n + 0];temp_m0n1 += temp * B[k][n + 1];temp_m0n2 += temp * B[k][n + 2];temp_m0n3 += temp * B[k][n + 3];}C[m][n + 0] = temp_m0n0;C[m][n + 1] = temp_m0n1;C[m][n + 2] = temp_m0n2;C[m][n + 3] = temp_m0n3;} }簡單的觀察即可發(fā)現(xiàn),上述偽代碼的最內(nèi)側(cè)計算使用的矩陣 𝐴 的元素是一致的。因此可以將 𝐴[𝑚][𝑘] 讀取到寄存器中,從而實現(xiàn) 4 次數(shù)據(jù)復(fù)用(這里不再給出示例)。一般將最內(nèi)側(cè)循環(huán)稱作計算核(micro kernel)。進行這樣的優(yōu)化后,內(nèi)存訪問操作數(shù)量變?yōu)?2MN+5/4MNK,訪存約為上面的5/8。
耗時分析:本優(yōu)化耗時約為473ms,相比暴力耗時減少300ms左右,可能的兩個原因:1、由于B是行優(yōu)先排列,1x4方法能夠減少數(shù)據(jù)從內(nèi)存到cache的加載次數(shù);2、合理利用寄存器,減少對𝐴矩陣訪存次數(shù)
?
optimize3
類似地,我們可以繼續(xù)拆分輸出的 𝑀 維度,從而在內(nèi)側(cè)循環(huán)中計算 4×4 輸出,如圖二。
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?圖二:矩陣乘計算?4×4輸出
同樣地,將計算核心展開,可以得到下面的偽代碼。由于乘數(shù)效應(yīng),4×4 的拆分可以將對輸入數(shù)據(jù)的訪存縮減到 MN/16*(16*2+8K)=2MN+1/2*MNK。這相對于最開始的 4MNK 已經(jīng)得到了 8X 的改進,這些改進都是通過展開循環(huán)后利用寄存器存儲數(shù)據(jù)減少訪存得到的。
for (int m = 0; m < M; m += 4) {for (int n = 0; n < N; n += 4) {float temp_m0n0 = C[m + 0][n + 0];float temp_m0n1 = C[m + 0][n + 1];float temp_m0n2 = C[m + 0][n + 2];float temp_m0n3 = C[m + 0][n + 3];float temp_m1n0 = C[m + 1][n + 0];float temp_m1n1 = C[m + 1][n + 1];float temp_m1n2 = C[m + 1][n + 2];float temp_m1n3 = C[m + 1][n + 3];float temp_m2n0 = C[m + 2][n + 0];float temp_m2n1 = C[m + 2][n + 1];float temp_m2n2 = C[m + 2][n + 2];float temp_m2n3 = C[m + 2][n + 3];float temp_m3n0 = C[m + 3][n + 0];float temp_m3n1 = C[m + 3][n + 1];float temp_m3n2 = C[m + 3][n + 2];float temp_m3n3 = C[m + 3][n + 3];for (int k = 0; k < K; k++) {float temp_m0 = A[m + 0][k];float temp_m1 = A[m + 1][k];float temp_m2 = A[m + 2][k];float temp_m3 = A[m + 3][k];float temp_n0 = B[k][n + 0];float temp_n1 = B[k][n + 1];float temp_n2 = B[k][n + 2];float temp_n3 = B[k][n + 3];temp_m0n0 += temp_m0 * temp_n0;temp_m0n1 += temp_m0 * temp_n1;temp_m0n2 += temp_m0 * temp_n2;temp_m0n3 += temp_m0 * temp_n3;temp_m1n0 += temp_m1 * temp_n0;temp_m1n1 += temp_m1 * temp_n1;temp_m1n2 += temp_m1 * temp_n2;temp_m1n3 += temp_m1 * temp_n3;temp_m2n0 += temp_m2 * temp_n0;temp_m2n1 += temp_m2 * temp_n1;temp_m2n2 += temp_m2 * temp_n2;temp_m2n3 += temp_m2 * temp_n3;temp_m3n0 += temp_m3 * temp_n0;temp_m3n1 += temp_m3 * temp_n1;temp_m3n2 += temp_m3 * temp_n2;temp_m3n3 += temp_m3 * temp_n3;}C[m + 0][n + 0] = temp_m0n0;C[m + 0][n + 1] = temp_m0n1;C[m + 0][n + 2] = temp_m0n2;C[m + 0][n + 3] = temp_m0n3;C[m + 1][n + 0] = temp_m1n0;C[m + 1][n + 1] = temp_m1n1;C[m + 1][n + 2] = temp_m1n2;C[m + 1][n + 3] = temp_m1n3;C[m + 2][n + 0] = temp_m2n0;C[m + 2][n + 1] = temp_m2n1;C[m + 2][n + 2] = temp_m2n2;C[m + 2][n + 3] = temp_m2n3;C[m + 3][n + 0] = temp_m3n0;C[m + 3][n + 1] = temp_m3n1;C[m + 3][n + 2] = temp_m3n2;C[m + 3][n + 3] = temp_m3n3;} }耗時分析:本優(yōu)化耗時約為354ms,相比1x4耗時減少120ms左右
總結(jié)
以上是生活随笔為你收集整理的GEMM算法及优化流程详解的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: tensorflow量化策略详解
- 下一篇: linux静态库的打包及链接使用