在x86-64平台上计算C(++)中64位无符号参数的(a * b)%m FAST?

我正在寻找一种快速方法来有效地计算uint64_t类型的abn的模数n (在数学意义上)。 我可以忍受前提条件,例如n!=0 ,甚至a<n && b<n

请注意,C表达式(a*b)%n不会删除它,因为产品被截断为64位。 我正在寻找(uint64_t)(((uint128_t)a*b)%n)除了我没有uint128_t (我知道,在Visual C ++中)。

我正在使用Visual C ++(最好)或GCC / clang内部,充分利用x86-64平台上可用的底层硬件; 或者如果不能用于便携式inlinefunction。

好的,这个怎么样(未经测试)

 modmul: ; rcx = a ; rdx = b ; r8 = n mov rax, rdx mul rcx div r8 mov rax, rdx ret 

前提条件是a * b / n <= ~0ULL ,否则会出现除法误差。 这是一个比a < n && m < n稍微不那么严格的条件,只要另一个足够小,其中一个可以大于n

不幸的是,它必须单独组装和链接,因为MSVC不支持64位目标的内联asm。

它仍然很慢,真正的问题是64位div ,这可能需要近百个周期(严重的是,例如Nehalem上最多90个周期)。

您可以通过移位/加/减来采用传统方式。 以下代码假设< n
n <2 63 (所以事情不会溢出):

 uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) { uint64_t rv = 0; while (b) { if (b&1) if ((rv += a) >= n) rv -= n; if ((a += a) >= n) a -= n; b >>= 1; } return rv; } 

你可以使用while (a && b)作为循环而不是短路的东西,如果它可能是n的因子。 如果a不是n的因子,则会稍微更慢(更多的比较和可能正确预测的分支)。

如果你真的,绝对需要最后一位(允许n高达2 64 -1),你可以使用:

 uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) { uint64_t rv = 0; while (b) { if (b&1) { rv += a; if (rv < a || rv >= n) rv -= n; } uint64_t t = a; a += a; if (a < t || a >= n) a -= n; b >>= 1; } return rv; } 

或者,只需使用GCC instrinsics访问底层x64指令:

 inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) { uint64_t rv; asm ("mul %3" : "=d"(rv), "=a"(a) : "1"(a), "r"(b)); asm ("div %4" : "=d"(rv), "=a"(a) : "0"(rv), "1"(a), "r"(n)); return rv; } 

但是,64位div指令非常慢,因此循环实际上可能更快。 你需要描述一下以确定。

此内在函数名为__mul128

 typedef unsigned long long BIG; // handles only the "hard" case when high bit of n is set BIG shl_mod( BIG v, BIG n, int by ) { if (v > n) v -= n; while (by--) { if (v > (nv)) v -= nv; else v <<= 1; } return v; } 

现在你可以使用shl_mod(B, n, 64)

没有内联组件有点糟糕。 无论如何,函数调用开销实际上非常小。 参数在易失性寄存器中传递,不需要清理。

我没有汇编程序,x64目标不支持__asm,所以我别无选择,只能自己从操作码“汇编”我的函数。

显然这取决于。 我使用mpir(gmp)作为参考来显示函数产生正确的结果。

 #include "stdafx.h" // mulmod64(a, b, m) == (a * b) % m typedef uint64_t(__cdecl *mulmod64_fnptr_t)(uint64_t a, uint64_t b, uint64_t m); uint8_t mulmod64_opcodes[] = { 0x48, 0x89, 0xC8, // mov rax, rcx 0x48, 0xF7, 0xE2, // mul rdx 0x4C, 0x89, 0xC1, // mov rcx, r8 0x48, 0xF7, 0xF1, // div rcx 0x48, 0x89, 0xD0, // mov rax,rdx 0xC3 // ret }; mulmod64_fnptr_t mulmod64_fnptr; void init() { DWORD dwOldProtect; VirtualProtect( &mulmod64_opcodes, sizeof(mulmod64_opcodes), PAGE_EXECUTE_READWRITE, &dwOldProtect); // NOTE: reinterpret byte array as a function pointer mulmod64_fnptr = (mulmod64_fnptr_t)(void*)mulmod64_opcodes; } int main() { init(); uint64_t a64 = 2139018971924123ull; uint64_t b64 = 1239485798578921ull; uint64_t m64 = 8975489368910167ull; // reference code mpz_t a, b, c, m, r; mpz_inits(a, b, c, m, r, NULL); mpz_set_ui(a, a64); mpz_set_ui(b, b64); mpz_set_ui(m, m64); mpz_mul(c, a, b); mpz_mod(r, c, m); gmp_printf("(%Zd * %Zd) mod %Zd = %Zd\n", a, b, m, r); // using mulmod64 uint64_t r64 = mulmod64_fnptr(a64, b64, m64); printf("(%llu * %llu) mod %llu = %llu\n", a64, b64, m64, r64); return 0; }