Neon 指令
1.向量相乘
?
可以看出上面的運算就是向量的每一維相乘然后相加,相乘之間具有良好的并行性,所以可以通過ARM NEON intrinsic指令進行加速。下面是代碼實現:
inline float dot(const float* A, const float* B, int K){float sum = 0;float32x4_t sum_vec = vdupq_n_f32(0);//, left_vec, right_vec;for (int k = 0; k<K; k += 8){sum_vec = vmlap_f32(sum_vec, vld1q_f32(A + k), vld1q_f32(B + k));sum_vec = vmlap_f32(sum_vec, vld1q_f32(A + k + 4), vld1q_f32(B + k + 4));// sum_vec = vmlap_f32(sum_vec, vld1q_f32(A + k+8), vld1q_f32(B+k+8));// sum_vec = vmlap_f32(sum_vec, vld1q_f32(A + k+12), vld1q_f32(B+k+12));}float32x2_t r = vadd_f32(vget_high_f32(sum_vec), vget_low_f32(sum_vec));sum += vget_lane_f32(vpadd_f32(r, r), 0);return sum;}?
代碼比較簡單,核心代碼就是先將兩個數組每次4個存入ARM NEON intrinsic下的128位變量中,然后利用一個乘加指令計算4個乘積的累加和。最后將4個sum再相加就得到最終的結果。相比于串行代碼,上面的代碼有接近4倍的加速比。當數據類型是short或者char時,可以取得更高的加速比,下面以char舉例:
int dot(char* A,char* B,int K) {int sum=0;int16x8_t sum_vec=vdupq_n_s16(0);int8x8_t left_vec, right_vec;int32x4_t part_sum4;int32x2_t part_sum2;//有溢出的風險for(k=0;k<K;k+=8){left_vec=vld1_s8(A+A_pos+k);right_vec=vld1_s8(B+B_pos+k);sum_vec=vmlal_s8(sum_vec,left_vec,right_vec);}part_sum4=vaddl_s16(vget_high_s16(sum_vec),vget_low_s16(sum_vec));part_sum2=vadd_s32(vget_high_s32(part_sum4),vget_low_s32(part_sum4));sum+=vget_lane_s32(vpadd_s32(part_sum2,part_sum2),0);return sum; }基于char類型的點積代碼和float類型的類似,不過由于char乘法存在溢出的可能性,所以相乘之后我們需要將數據類型升級成short。上面的代碼也特別注釋了一句:可能存在溢出,這是因為單個乘法不會溢出,但是乘法的結果相加可能會存在溢出。如果合理設計兩個向量的值溢出的概率就會很低,更重要的一點是上面代碼的加速比是float類型的2倍還要多,所以在速度要求非常嚴格的程序中,上面代碼會帶來非常明顯的速度提升。
?
2.exp加速
?
算法的基本原理是考慮了float數據類型在內存中的布局而精巧設計的,想了解更多細節可以參考原博客,本文只介紹如何將其用ARM NEON intrinsic指令進行加速(相比原始博客,代碼中第二個常量有點變化,該新常量是我試驗出來的,誤差更小)。?
ARM NEON intrinsic指令的優勢是并行計算,所以我們對一個數組的每一個元素進行exp并相加,然后將其加速:
在原始算法中是先計算(1<<23),然后將其和另外一部分相乘,我們將其簡化成一個乘加操作:12102203.1616540672*x+1064807160.56887296。算法和點積很相似,先加載4個變量,然后執行乘加操作。之后的操作首先是將float類型的變量轉成int型變量,之后再通過地址強轉獲取float值并累加。相比原始的exp累加,速度能有5、6倍左右的提升。?
?
sse_neon_search.h
#ifndef _SSE_NEON_SEARCH__ #define _SSE_NEON_SEARCH__#ifdef __GNUC__ #include <arm_neon.h> #elif _WIN32 #include <immintrin.h> #define zq_mm256_fmadd_ps _mm256_fmadd_ps #endifnamespace sse_neon_search {/**windows dim shoud be [128, 256, 512] , linux dim should be aligned to 8.*/ #ifdef __GNUC__inline float _cal_similarity_avx_neon(float* pt1, float* pt2, int dim){float sum = 0;float32x4_t sum_vec = vdupq_n_f32(0);//, left_vec, right_vec;float *A1 = pt1;float *B1 = pt2;for (int k = 0; k < dim; k += 8){sum_vec = vmlaq_f32(sum_vec, vld1q_f32(A1), vld1q_f32(B1));A1 += 4;B1 += 4;sum_vec = vmlaq_f32(sum_vec, vld1q_f32(A1), vld1q_f32(B1));A1 += 4;B1 += 4;}float32x2_t r = vadd_f32(vget_high_f32(sum_vec), vget_low_f32(sum_vec));sum += vget_lane_f32(vpadd_f32(r, r), 0);return sum;}#elif _WIN32inline float _cal_similarity_avx_neon(float* pt1, float* pt2, int dim){if (dim == 128) {_declspec(align(32)) float q[8];__m256 sum_vec = _mm256_mul_ps(_mm256_load_ps(pt1), _mm256_load_ps(pt2));sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 8), _mm256_load_ps(pt2 + 8), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 16), _mm256_load_ps(pt2 + 16), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 24), _mm256_load_ps(pt2 + 24), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 32), _mm256_load_ps(pt2 + 32), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 40), _mm256_load_ps(pt2 + 40), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 48), _mm256_load_ps(pt2 + 48), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 56), _mm256_load_ps(pt2 + 56), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 64), _mm256_load_ps(pt2 + 64), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 72), _mm256_load_ps(pt2 + 72), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 80), _mm256_load_ps(pt2 + 80), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 88), _mm256_load_ps(pt2 + 88), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 96), _mm256_load_ps(pt2 + 96), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 104), _mm256_load_ps(pt2 + 104), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 112), _mm256_load_ps(pt2 + 112), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 120), _mm256_load_ps(pt2 + 120), sum_vec);_mm256_store_ps(q, sum_vec);float score = q[0] + q[1] + q[2] + q[3] + q[4] + q[5] + q[6] + q[7];return score;}else if (dim == 256) {_declspec(align(32)) float q[8];__m256 sum_vec = _mm256_mul_ps(_mm256_load_ps(pt1), _mm256_load_ps(pt2));sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 8), _mm256_load_ps(pt2 + 8), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 16), _mm256_load_ps(pt2 + 16), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 24), _mm256_load_ps(pt2 + 24), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 32), _mm256_load_ps(pt2 + 32), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 40), _mm256_load_ps(pt2 + 40), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 48), _mm256_load_ps(pt2 + 48), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 56), _mm256_load_ps(pt2 + 56), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 64), _mm256_load_ps(pt2 + 64), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 72), _mm256_load_ps(pt2 + 72), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 80), _mm256_load_ps(pt2 + 80), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 88), _mm256_load_ps(pt2 + 88), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 96), _mm256_load_ps(pt2 + 96), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 104), _mm256_load_ps(pt2 + 104), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 112), _mm256_load_ps(pt2 + 112), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 120), _mm256_load_ps(pt2 + 120), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 128), _mm256_load_ps(pt2 + 128), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 136), _mm256_load_ps(pt2 + 136), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 144), _mm256_load_ps(pt2 + 144), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 152), _mm256_load_ps(pt2 + 152), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 160), _mm256_load_ps(pt2 + 160), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 168), _mm256_load_ps(pt2 + 168), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 176), _mm256_load_ps(pt2 + 176), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 184), _mm256_load_ps(pt2 + 184), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 192), _mm256_load_ps(pt2 + 192), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 200), _mm256_load_ps(pt2 + 200), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 208), _mm256_load_ps(pt2 + 208), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 216), _mm256_load_ps(pt2 + 216), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 224), _mm256_load_ps(pt2 + 224), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 232), _mm256_load_ps(pt2 + 232), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 240), _mm256_load_ps(pt2 + 240), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 248), _mm256_load_ps(pt2 + 248), sum_vec);_mm256_store_ps(q, sum_vec);float score = q[0] + q[1] + q[2] + q[3] + q[4] + q[5] + q[6] + q[7];return score;}else if (dim == 512) {_declspec(align(32)) float q[8];__m256 sum_vec = _mm256_mul_ps(_mm256_load_ps(pt1), _mm256_load_ps(pt2));sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 8), _mm256_load_ps(pt2 + 8), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 16), _mm256_load_ps(pt2 + 16), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 24), _mm256_load_ps(pt2 + 24), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 32), _mm256_load_ps(pt2 + 32), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 40), _mm256_load_ps(pt2 + 40), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 48), _mm256_load_ps(pt2 + 48), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 56), _mm256_load_ps(pt2 + 56), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 64), _mm256_load_ps(pt2 + 64), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 72), _mm256_load_ps(pt2 + 72), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 80), _mm256_load_ps(pt2 + 80), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 88), _mm256_load_ps(pt2 + 88), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 96), _mm256_load_ps(pt2 + 96), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 104), _mm256_load_ps(pt2 + 104), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 112), _mm256_load_ps(pt2 + 112), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 120), _mm256_load_ps(pt2 + 120), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 128), _mm256_load_ps(pt2 + 128), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 136), _mm256_load_ps(pt2 + 136), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 144), _mm256_load_ps(pt2 + 144), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 152), _mm256_load_ps(pt2 + 152), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 160), _mm256_load_ps(pt2 + 160), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 168), _mm256_load_ps(pt2 + 168), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 176), _mm256_load_ps(pt2 + 176), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 184), _mm256_load_ps(pt2 + 184), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 192), _mm256_load_ps(pt2 + 192), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 200), _mm256_load_ps(pt2 + 200), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 208), _mm256_load_ps(pt2 + 208), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 216), _mm256_load_ps(pt2 + 216), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 224), _mm256_load_ps(pt2 + 224), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 232), _mm256_load_ps(pt2 + 232), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 240), _mm256_load_ps(pt2 + 240), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 248), _mm256_load_ps(pt2 + 248), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 256), _mm256_load_ps(pt2 + 256), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 264), _mm256_load_ps(pt2 + 264), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 272), _mm256_load_ps(pt2 + 272), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 280), _mm256_load_ps(pt2 + 280), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 288), _mm256_load_ps(pt2 + 288), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 296), _mm256_load_ps(pt2 + 296), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 304), _mm256_load_ps(pt2 + 304), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 312), _mm256_load_ps(pt2 + 312), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 320), _mm256_load_ps(pt2 + 320), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 328), _mm256_load_ps(pt2 + 328), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 336), _mm256_load_ps(pt2 + 336), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 344), _mm256_load_ps(pt2 + 344), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 352), _mm256_load_ps(pt2 + 352), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 360), _mm256_load_ps(pt2 + 360), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 368), _mm256_load_ps(pt2 + 368), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 376), _mm256_load_ps(pt2 + 376), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 384), _mm256_load_ps(pt2 + 384), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 392), _mm256_load_ps(pt2 + 392), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 400), _mm256_load_ps(pt2 + 400), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 408), _mm256_load_ps(pt2 + 408), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 416), _mm256_load_ps(pt2 + 416), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 424), _mm256_load_ps(pt2 + 424), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 432), _mm256_load_ps(pt2 + 432), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 440), _mm256_load_ps(pt2 + 440), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 448), _mm256_load_ps(pt2 + 448), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 456), _mm256_load_ps(pt2 + 456), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 464), _mm256_load_ps(pt2 + 464), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 472), _mm256_load_ps(pt2 + 472), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 480), _mm256_load_ps(pt2 + 480), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 488), _mm256_load_ps(pt2 + 488), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 496), _mm256_load_ps(pt2 + 496), sum_vec);sum_vec = zq_mm256_fmadd_ps(_mm256_load_ps(pt1 + 504), _mm256_load_ps(pt2 + 504), sum_vec);_mm256_store_ps(q, sum_vec);float score = q[0] + q[1] + q[2] + q[3] + q[4] + q[5] + q[6] + q[7];return score;}else {return -1;}} #endif}#endif?
?
指令集介紹:https://www.jianshu.com/p/53c94628abc9
https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics
寄存器介紹:https://blog.csdn.net/ceasar11/article/details/19481375
Neon優化:https://www.jianshu.com/p/16d60ac56249
總結
- 上一篇: ARM NEON寄存器
- 下一篇: NEON快速入门