如何让AVX256加速的矩阵乘法更快?

摘要:最近打PKU的HPCGAME留下的代码,速度不是很快。const int BLOCK_SIZE = 1024; const int BLOCK_SIZE2 = 256; inline static void block_avx256_16x……
/* Leftover code from PKU's HPCGAME; it is not very fast. */
const int BLOCK_SIZE = 1024;
const int BLOCK_SIZE2 = 256;

/*
 * 16x2 micro-kernel using 256-bit AVX: accumulates a 16-row by 2-column tile
 * of C with the product of a packed A panel and a packed B panel.
 *
 *   n  distance, in doubles, between column 0 and column 1 of the C tile
 *      (the original note describes it as the square-matrix size)
 *   K  number of rank-1 updates to apply (loop trip count)
 *   A  packed panel: 16 consecutive doubles are consumed per step, read with
 *      aligned loads, so A must be 32-byte aligned
 *   B  packed panel: 2 consecutive doubles are consumed per step
 *   C  tile to update in place; rows 0..15 of column 0 start at C and rows
 *      0..15 of column 1 start at C + n
 *
 * The C tile is loaded with unaligned loads to match the unaligned stores at
 * the end: C + n is only 32-byte aligned when both C and n happen to be, and
 * an aligned load on any other address faults.
 */
inline static void block_avx256_16x2( /* this machine does not support AVX-512 */
    int n, int K, /* square matrix size */
    double* A, double* B, double* C)
{
    /* Accumulators are named cRRCC_RRCC: first and last (row, column) held. */
    __m256d c0000_0300, c0400_0700, c0800_1100, c1200_1500,
            c0001_0301, c0401_0701, c0801_1101, c1201_1501;
    __m256d a0x_3x, a4x_7x, a8x_11x, a12x_15x, bx0, bx1;

    double* c0001_0301_ptr = C + n; /* start of the second C column */

    /* Bring the current 16x2 tile of C into registers. */
    c0000_0300 = _mm256_loadu_pd(C);
    c0400_0700 = _mm256_loadu_pd(C + 4);
    c0800_1100 = _mm256_loadu_pd(C + 8);
    c1200_1500 = _mm256_loadu_pd(C + 12);
    c0001_0301 = _mm256_loadu_pd(c0001_0301_ptr);
    c0401_0701 = _mm256_loadu_pd(c0001_0301_ptr + 4);
    c0801_1101 = _mm256_loadu_pd(c0001_0301_ptr + 8);
    c1201_1501 = _mm256_loadu_pd(c0001_0301_ptr + 12);

    for (int x = 0; x < K; ++x) {
        /* 16 values of A for step x. */
        a0x_3x = _mm256_load_pd(A);
        a4x_7x = _mm256_load_pd(A + 4);
        a8x_11x = _mm256_load_pd(A + 8);
        a12x_15x = _mm256_load_pd(A + 12);
        A += 16;

        /* The two B values for step x, each replicated across all lanes. */
        bx0 = _mm256_broadcast_sd(B++);
        bx1 = _mm256_broadcast_sd(B++);

        /* Separate multiply and add (no fused multiply-add is used). */
        c0000_0300 = _mm256_add_pd(_mm256_mul_pd(a0x_3x, bx0), c0000_0300);
        c0400_0700 = _mm256_add_pd(_mm256_mul_pd(a4x_7x, bx0), c0400_0700);
        c0800_1100 = _mm256_add_pd(_mm256_mul_pd(a8x_11x, bx0), c0800_1100);
        c1200_1500 = _mm256_add_pd(_mm256_mul_pd(a12x_15x, bx0), c1200_1500);
        c0001_0301 = _mm256_add_pd(_mm256_mul_pd(a0x_3x, bx1), c0001_0301);
        c0401_0701 = _mm256_add_pd(_mm256_mul_pd(a4x_7x, bx1), c0401_0701);
        c0801_1101 = _mm256_add_pd(_mm256_mul_pd(a8x_11x, bx1), c0801_1101);
        c1201_1501 = _mm256_add_pd(_mm256_mul_pd(a12x_15x, bx1), c1201_1501);
    }

    /* Write the updated tile back. */
    _mm256_storeu_pd(C, c0000_0300);
    _mm256_storeu_pd(C + 4, c0400_0700);
    _mm256_storeu_pd(C + 8, c0800_1100);
    _mm256_storeu_pd(C + 12, c1200_1500);
    _mm256_storeu_pd(c0001_0301_ptr, c0001_0301);
    _mm256_storeu_pd(c0001_0301_ptr + 4, c0401_0701);
    _mm256_storeu_pd(c0001_0301_ptr + 8, c0801_1101);
    _mm256_storeu_pd(c0001_0301_ptr + 12, c1201_1501);
}

/*
 * The excerpt is cut off here, at the start of the next definition:
 *   static inline void copy_b(int lda, const int K, double* b_src, double* b_dest)
 */
阅读全文