高效的4×4矩阵乘法(C vs汇编)

我正在寻找一种更快,更棘手的方法来将C中的两个4×4矩阵相乘。我目前的研究主要集中在具有SIMD扩展的x86-64汇编上。 到目前为止,我已经创建了一个函数,比一个简单的C实现快了大约6倍,这超出了我对性能改进的期望。 不幸的是,只有在没有使用优化标志进行编译时(GCC 4.7),这种情况才会成立。 随着-O2 ,C变得更快,我的努力变得毫无意义。

我知道现代编译器利用复杂的优化技术来实现几乎完美的代码,通常比巧妙的手工assembly更快。 但在少数性能关键的情况下,人类可能会尝试使用编译器争取时钟周期。 特别是,当一些支持现代ISA的数学可以被探索时(就像我的情况一样)。

我的函数如下(AT&T语法,GNU汇编程序):

  .text .globl matrixMultiplyASM .type matrixMultiplyASM, @function matrixMultiplyASM: movaps (%rdi), %xmm0 # fetch the first matrix (use four registers) movaps 16(%rdi), %xmm1 movaps 32(%rdi), %xmm2 movaps 48(%rdi), %xmm3 xorq %rcx, %rcx # reset (forward) loop iterator .ROW: movss (%rsi), %xmm4 # Compute four values (one row) in parallel: shufps $0x0, %xmm4, %xmm4 # 4x 4FP mul's, 3x 4FP add's 6x mov's per row, mulps %xmm0, %xmm4 # expressed in four sequences of 5 instructions, movaps %xmm4, %xmm5 # executed 4 times for 1 matrix multiplication. addq $0x4, %rsi movss (%rsi), %xmm4 # movss + shufps comprise _mm_set1_ps intrinsic shufps $0x0, %xmm4, %xmm4 # mulps %xmm1, %xmm4 addps %xmm4, %xmm5 addq $0x4, %rsi # manual pointer arithmetic simplifies addressing movss (%rsi), %xmm4 shufps $0x0, %xmm4, %xmm4 mulps %xmm2, %xmm4 # actual computation happens here addps %xmm4, %xmm5 # addq $0x4, %rsi movss (%rsi), %xmm4 # one mulps operand fetched per sequence shufps $0x0, %xmm4, %xmm4 # | mulps %xmm3, %xmm4 # the other is already waiting in %xmm[0-3] addps %xmm4, %xmm5 addq $0x4, %rsi # 5 preceding comments stride among the 4 blocks movaps %xmm5, (%rdx,%rcx) # store the resulting row, actually, a column addq $0x10, %rcx # (matrices are stored in column-major order) cmpq $0x40, %rcx jne .ROW ret .size matrixMultiplyASM, .-matrixMultiplyASM 

它通过处理在128位SSE寄存器中打包的四个浮点数来计算每次迭代的结果矩阵的整列。 通过一些数学运算(操作重新排序和聚合)和用于4xfloat包的并行乘法/加法的mullps / addps指令,可以实现完全矢量化。 代码重用用于传递参数的寄存器( %rdi%rsi%rdx :GNU / Linux ABI),从(内部)循环展开中获益,并在XMM寄存器中完全保存一个矩阵以减少内存读取。 你可以看到,我已经研究了这个话题,并且花了很多时间来尽我所能地实现它。

征服我的代码的天真C计算如下所示:

 void matrixMultiplyNormal(mat4_t *mat_a, mat4_t *mat_b, mat4_t *mat_r) { for (unsigned int i = 0; i < 16; i += 4) for (unsigned int j = 0; j m[i + j] = (mat_b->m[i + 0] * mat_a->m[j + 0]) + (mat_b->m[i + 1] * mat_a->m[j + 4]) + (mat_b->m[i + 2] * mat_a->m[j + 8]) + (mat_b->m[i + 3] * mat_a->m[j + 12]); } 

我已经研究了上面C代码的优化汇编输出,它在XMM寄存器中存储浮点数时, 不涉及任何并行操作 – 只是标量计算,指针运算和条件跳转。 编译器的代码似乎不那么刻意,但它仍然比我的矢量化版本稍微更有效,预计会快4倍。 我确信一般的想法是正确的 – 程序员做同样的事情并获得有益的结果。 但这里有什么问题? 是否有任何我不知道的寄存器分配或指令调度问题? 你知道任何支持我与机器作战的x86-64assembly工具或技巧吗?

4×4矩阵乘法是64次乘法和48次加法。 使用SSE,这可以减少到16次乘法和12次加法(和16次广播)。 以下代码将为您执行此操作。 它只需要SSE( #include )。 arraysABC需要16字节对齐。 使用诸如hadd (SSE3)和dpps (SSE4.1)之类的水平指令效率较低 (尤其是dpps )。 我不知道循环展开是否有帮助。

 void M4x4_SSE(float *A, float *B, float *C) { __m128 row1 = _mm_load_ps(&B[0]); __m128 row2 = _mm_load_ps(&B[4]); __m128 row3 = _mm_load_ps(&B[8]); __m128 row4 = _mm_load_ps(&B[12]); for(int i=0; i<4; i++) { __m128 brod1 = _mm_set1_ps(A[4*i + 0]); __m128 brod2 = _mm_set1_ps(A[4*i + 1]); __m128 brod3 = _mm_set1_ps(A[4*i + 2]); __m128 brod4 = _mm_set1_ps(A[4*i + 3]); __m128 row = _mm_add_ps( _mm_add_ps( _mm_mul_ps(brod1, row1), _mm_mul_ps(brod2, row2)), _mm_add_ps( _mm_mul_ps(brod3, row3), _mm_mul_ps(brod4, row4))); _mm_store_ps(&C[4*i], row); } } 

有一种方法可以加速代码并超越编译器。 它不涉及任何复杂的流水线分析或深度代码微优化(这并不意味着它不能从这些中进一步受益)。 优化使用三个简单的技巧:

  1. 该function现在是32字节对齐的(这显着提高了性能),

  2. 主循环反向,这减少了与零测试的比较(基于EFLAGS),

  3. 事实certificate,指令级地址算法比“外部”指针计算更快(即使在3/4情况下它需要两倍的加法)。 它通过四条指令缩短了循环体,并减少了其执行路径中的数据依赖性。 见相关问题 。

此外,代码使用相对跳转语法来抑制符号重定义错误,这种错误发生在GCC尝试内联它时(在放入asm语句并使用-O3编译之后)。

  .text .align 32 # 1. function entry alignment .globl matrixMultiplyASM # (for a faster call) .type matrixMultiplyASM, @function matrixMultiplyASM: movaps (%rdi), %xmm0 movaps 16(%rdi), %xmm1 movaps 32(%rdi), %xmm2 movaps 48(%rdi), %xmm3 movq $48, %rcx # 2. loop reversal 1: # (for simpler exit condition) movss (%rsi, %rcx), %xmm4 # 3. extended address operands shufps $0, %xmm4, %xmm4 # (faster than pointer calculation) mulps %xmm0, %xmm4 movaps %xmm4, %xmm5 movss 4(%rsi, %rcx), %xmm4 shufps $0, %xmm4, %xmm4 mulps %xmm1, %xmm4 addps %xmm4, %xmm5 movss 8(%rsi, %rcx), %xmm4 shufps $0, %xmm4, %xmm4 mulps %xmm2, %xmm4 addps %xmm4, %xmm5 movss 12(%rsi, %rcx), %xmm4 shufps $0, %xmm4, %xmm4 mulps %xmm3, %xmm4 addps %xmm4, %xmm5 movaps %xmm5, (%rdx, %rcx) subq $16, %rcx # one 'sub' (vs 'add' & 'cmp') jge 1b # SF=OF, idiom: jump if positive ret 

这是迄今为止我见过的最快的x86-64实现。 我会很感激,投票并接受任何答案,为此目的提供更快的assembly!

我想知道转置其中一个矩阵是否有益。

考虑我们如何乘以以下两个矩阵……

 A1 A2 A3 A4 W1 W2 W3 W4 B1 B2 B3 B4 X1 X2 X3 X4 C1 C2 C3 C4 * Y1 Y2 Y3 Y4 D1 D2 D3 D4 Z1 Z2 Z3 Z4 

这会导致……

 dot(A,?1) dot(A,?2) dot(A,?3) dot(A,?4) dot(B,?1) dot(B,?2) dot(B,?3) dot(B,?4) dot(C,?1) dot(C,?2) dot(C,?3) dot(C,?4) dot(D,?1) dot(D,?2) dot(D,?3) dot(D,?4) 

做一行和一列的点积是一件痛苦的事。

如果我们在乘以之前转换第二个矩阵怎么办?

 A1 A2 A3 A4 W1 X1 Y1 Z1 B1 B2 B3 B4 W2 X2 Y2 Z2 C1 C2 C3 C4 * W3 X3 Y3 Z3 D1 D2 D3 D4 W4 X4 Y4 Z4 

现在改为做行和列的点积,我们正在做两行的点积。 这可以使自己更好地使用SIMD指令。

希望这可以帮助。

上面的Sandy Bridge扩展了指令集以支持8元素向量算法。 考虑这个实现。

 struct MATRIX { union { float f[4][4]; __m128 m[4]; __m256 n[2]; }; }; MATRIX myMultiply(MATRIX M1, MATRIX M2) { // Perform a 4x4 matrix multiply by a 4x4 matrix // Be sure to run in 64 bit mode and set right flags // Properties, C/C++, Enable Enhanced Instruction, /arch:AVX // Having MATRIX on a 32 byte bundry does help performance MATRIX mResult; __m256 a0, a1, b0, b1; __m256 c0, c1, c2, c3, c4, c5, c6, c7; __m256 t0, t1, u0, u1; t0 = M1.n[0]; // t0 = a00, a01, a02, a03, a10, a11, a12, a13 t1 = M1.n[1]; // t1 = a20, a21, a22, a23, a30, a31, a32, a33 u0 = M2.n[0]; // u0 = b00, b01, b02, b03, b10, b11, b12, b13 u1 = M2.n[1]; // u1 = b20, b21, b22, b23, b30, b31, b32, b33 a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(0, 0, 0, 0)); // a0 = a00, a00, a00, a00, a10, a10, a10, a10 a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(0, 0, 0, 0)); // a1 = a20, a20, a20, a20, a30, a30, a30, a30 b0 = _mm256_permute2f128_ps(u0, u0, 0x00); // b0 = b00, b01, b02, b03, b00, b01, b02, b03 c0 = _mm256_mul_ps(a0, b0); // c0 = a00*b00 a00*b01 a00*b02 a00*b03 a10*b00 a10*b01 a10*b02 a10*b03 c1 = _mm256_mul_ps(a1, b0); // c1 = a20*b00 a20*b01 a20*b02 a20*b03 a30*b00 a30*b01 a30*b02 a30*b03 a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(1, 1, 1, 1)); // a0 = a01, a01, a01, a01, a11, a11, a11, a11 a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(1, 1, 1, 1)); // a1 = a21, a21, a21, a21, a31, a31, a31, a31 b0 = _mm256_permute2f128_ps(u0, u0, 0x11); // b0 = b10, b11, b12, b13, b10, b11, b12, b13 c2 = _mm256_mul_ps(a0, b0); // c2 = a01*b10 a01*b11 a01*b12 a01*b13 a11*b10 a11*b11 a11*b12 a11*b13 c3 = _mm256_mul_ps(a1, b0); // c3 = a21*b10 a21*b11 a21*b12 a21*b13 a31*b10 a31*b11 a31*b12 a31*b13 a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(2, 2, 2, 2)); // a0 = a02, a02, a02, a02, a12, a12, a12, a12 a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(2, 2, 2, 2)); // a1 = a22, a22, a22, a22, a32, a32, a32, a32 b1 = _mm256_permute2f128_ps(u1, u1, 0x00); // b0 = b20, b21, b22, b23, b20, b21, b22, b23 c4 = _mm256_mul_ps(a0, b1); // c4 = a02*b20 a02*b21 a02*b22 a02*b23 a12*b20 a12*b21 a12*b22 a12*b23 c5 = _mm256_mul_ps(a1, b1); // c5 = a22*b20 a22*b21 a22*b22 a22*b23 a32*b20 a32*b21 a32*b22 a32*b23 a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(3, 3, 3, 3)); // a0 = a03, a03, a03, a03, a13, a13, a13, a13 a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(3, 3, 3, 3)); // a1 = a23, a23, a23, a23, a33, a33, a33, a33 b1 = _mm256_permute2f128_ps(u1, u1, 0x11); // b0 = b30, b31, b32, b33, b30, b31, b32, b33 c6 = _mm256_mul_ps(a0, b1); // c6 = a03*b30 a03*b31 a03*b32 a03*b33 a13*b30 a13*b31 a13*b32 a13*b33 c7 = _mm256_mul_ps(a1, b1); // c7 = a23*b30 a23*b31 a23*b32 a23*b33 a33*b30 a33*b31 a33*b32 a33*b33 c0 = _mm256_add_ps(c0, c2); // c0 = c0 + c2 (two terms, first two rows) c4 = _mm256_add_ps(c4, c6); // c4 = c4 + c6 (the other two terms, first two rows) c1 = _mm256_add_ps(c1, c3); // c1 = c1 + c3 (two terms, second two rows) c5 = _mm256_add_ps(c5, c7); // c5 = c5 + c7 (the other two terms, second two rose) // Finally complete addition of all four terms and return the results mResult.n[0] = _mm256_add_ps(c0, c4); // n0 = a00*b00+a01*b10+a02*b20+a03*b30 a00*b01+a01*b11+a02*b21+a03*b31 a00*b02+a01*b12+a02*b22+a03*b32 a00*b03+a01*b13+a02*b23+a03*b33 // a10*b00+a11*b10+a12*b20+a13*b30 a10*b01+a11*b11+a12*b21+a13*b31 a10*b02+a11*b12+a12*b22+a13*b32 a10*b03+a11*b13+a12*b23+a13*b33 mResult.n[1] = _mm256_add_ps(c1, c5); // n1 = a20*b00+a21*b10+a22*b20+a23*b30 a20*b01+a21*b11+a22*b21+a23*b31 a20*b02+a21*b12+a22*b22+a23*b32 a20*b03+a21*b13+a22*b23+a23*b33 // a30*b00+a31*b10+a32*b20+a33*b30 a30*b01+a31*b11+a32*b21+a33*b31 a30*b02+a31*b12+a32*b22+a33*b32 a30*b03+a31*b13+a32*b23+a33*b33 return mResult; } 

显然,您可以一次从四个矩阵中获取项,并使用相同的算法同时乘以四个矩阵。