NTT的实现及其加速

NTT的实现及其加速

NTT快速数论变换原理

多项式相乘的困难

假设有多项式A(x),一共有n项,最高次项为\(x^{n-1}\),设其系数为\(a = (a[0],a[1],\cdots,a[n-1])\),类似的假设另一个多项式B(x),设其系数为\(b = (b[0],b[1],\cdots,b[n-1])\),如果想要将两个多项式相乘,很明显一共要乘\(n^2\)次,能不能减少乘法的次数,使其复杂度比\(O(n^2)\)更少呢?

基于FFT(快速傅里叶变换)的多项式相乘

\(x=x_0\)代入A(x)可以得到多项式在点x0处的值,类似的,我们代入n个这样的点$ (x_0,x_1,,x_{n-1})\(,可以得到n个多项式取值\) (A(x_0),A(x_1),,A(x_{n-1}))$,可以证明,仅通过这n个多项式取值,我们可以还原出原本的多项式。

证明:

假设还原出的多项式不唯一,分别设为f(x)和g(x),这个两个多项式的最高次为n-1,否则还原失败

令h(x) = f(x)-g(x) ,这个多项式的最高次至少为n-1

那么代入$ (x_0,x_1,,x_{n-1})$个点,h(x)均等于0,即这个多项式有n个解

而即使在复数域上n-1次方程也至多有n-1个解,矛盾

故只能还原出唯一的一个多项式

类似的,代入多项式B(x)得到$ (B(x_0),B(x_1),,B(x_{n-1}))$,我们在进行下面的点乘操作得到 \[ (A(x_0)*B(x_0),A(x_1)*B(x_1),\cdots,A(x_{n-1})*B(x_{n-1})) \] 这是一个n维向量,按照这个向量还原多项式,我们就可以得到两个多项式的乘积 \[ C(x) = A(x)*B(X) \] 注意这里的乘法为多项式乘法,这样多项式相乘就巧妙的转化为点乘,复杂度只有O(n)

但有个疑问,把点代入多项式计算不是也有计算量吗?这个算法快就是因为,我们可以设置点$ (x_0,x_1,,x_{n-1})$,这些点经过精心设置,可以在计算例如A(x0)的时候很快。

基于CT蝴蝶变换的NTT算法

注意这里的参数需要满足一些条件

  • \(q \equiv 1 \mod{2n}\)

  • \(n = 2^k,k\in Z_+\)

  • \(\psi\)为模数q的2N次单位根,即满足\(\psi^{2n} \equiv 1 \mod{q}\)

  • 数组$= (1,1,2,,^{n-1}) $

  • 数组\(\psi_{rev}\)是把数组\(\phi\)按照bit-reversed顺序重新排列出来的,见下面的例子(多项式次数n = 8)

    X(0) = X(0,0,0) --> X(0,0,0) = X(0)

    X(1) = X(0,0,1) --> X(1,0,0) = X(4)

    X(2) = X(0,1,0) --> X(0,1,0) = X(2)

    X(3) = X(0,1,1) --> X(1,1,0) = X(6)

    以此类推

下面是蝴蝶变换算法

image-20221215122722792

基于GS蝴蝶变换的NTT逆变换

image-20221215122822035

Barrett Reduction 乘法取模加速

参考博客

Barrett reduction是一种求模运算的优化方法,它可以将求模运算的时间复杂度从O(n)降低到O(log n)。

原理简述

一般来说,32位整数加法操作比乘法操作快得多,大概快3到8倍。而移位操作又比加法操作快10倍以上,核心思想就是把除法尽可能迁移到移位操作上。

我们(人工)计算取模,用的是 \[ r= a\bmod p=a-\left\lfloor \dfrac{a}{p} \right\rfloor *p \]

这个计算中有除法,在计算机组成原理中,两个32bit的数相除需要32次移位和32次加减法操作,开销比较大。

而两个32bit的数相乘只需要32次移位操作。我们希望能用乘法替换除法,计算出 \[ q=\left\lfloor \dfrac{a}{p} \right\rfloor \] 我们可以钦定一个整数 k,再弄出一个整数 m,使得 \[ \dfrac{m}{2^k}\approx\dfrac{1}{p} \] 那么 q不就约等于$ $了吗?这样除法运算就被拆成了一次乘法和k次位移,速度大大加快。

为了防止算出的商超过实际的商,我们一般取 \[ m=\left\lfloor \dfrac{2^k}{p} \right\rfloor \] 这里,我们取 \[ k\ge \lceil 2\log_2 p \rceil\ \] 也就是使得$ 2kp2$。下面我们证明,这样取 k时,\(0\le a-pq<p\),也就是我们稍后在计算余数\(a-pq\) 时,得到的答案至多需要再做一次减法 不需要再调整。

下面是这样设置参数的合理性证明,证明\(0\le a-pq<p\)

由于 \(q=\dfrac{am}{2^k}\)

因此 \(pq=\dfrac{apm}{2^k},a-pq=\dfrac{a}{2^k}\cdot (2^k-pm)\)

第一点,由于\(2^k\approx p^2\),而a是模p意义下两个数的乘积,所以\(a<p^2\)

于是有\(0<\dfrac{a}{2^k}<1\)

第二点,由于\(m=\left\lfloor \dfrac{2^k}{p} \right\rfloor\)

所以有$ -1 < m $,进一步推出\(0\leq(2^k-pm)<p\)

综上可以证明\(0\le a-pq<p\)

总结这个算法的流程如下:

  • 根据 pp 的规模选取合适的 k,一般要求 \(k\ge \lceil 2\log_2 p \rceil\)
  • 根据 k,p 预处理出 \(m=\left\lfloor \dfrac{2^k}{p} \right\rfloor\)
  • 实际计算时,用 \(q=\dfrac{a\cdot m}{2^k}\)计算出商,再用$ r=a-pq$ 得出余数

c++例子

下面是一个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void RingMultiplier::mulModBarrett(uint64_t& r, uint64_t a, uint64_t b, uint64_t p, uint64_t pr) 
{
unsigned __int128 mul = static_cast<unsigned __int128>(a) * b;
uint64_t abot = static_cast<uint64_t>(mul);//只会返回a*b的低64位
uint64_t atop = static_cast<uint64_t>(mul >> 64);//得到a*b的高64位
unsigned __int128 tmp = static_cast<unsigned __int128>(abot) * pr;
tmp >>= 64;
tmp += static_cast<unsigned __int128>(atop) * pr;
tmp >>= kbar2 - 64;
tmp *= p;
tmp = mul - tmp;
r = static_cast<uint64_t>(tmp);
if(r >= p) r -= p;
}

这个函数主要用来计算r=(a*b) mod p的结果,采用Barrett乘法算法,其中\(pr=2^{kbar2} / p\),kbar2是一个预设的常量,模数p的大约满足\(log_2p = 60\),所以这里的kbar2 = 120,于是有\(pr \approx p\)

算法原理: \[ r = mul - \left\lfloor \dfrac{mul}{p}\right\rfloor*p \\ =mul - mul*\dfrac{pr}{2^k}*p \quad where \quad r=2^k \quad and \quad pr = \dfrac{2^k}{p}\\ 这里把mul*pr/2^k变成如下操作,其中abot存放mul的低64位,atop存放高64位\\ [(abot*pr)>>64+atop*pr]>>(kbar2-64) \\ =[(abot*pr)+atop*pr*2^{64}]/(2^{kbar2})\\ 这样可以完成barrett的快速求余操作 \] 算法步骤:

  • 1)先把a*b的结果存到mul变量中,abot存放mul的低64位,atop存放高64位;
  • 2)然后把abot乘以预设的pr,得到tmp,并右移64位;
  • 3)再把atop乘以pr,再加到tmp上;
  • 4)把tmp右移kbar2-64位,再乘以p;
  • 5)最后求出mul-tmp,得到r;
  • 6)最后判断r是否大于p,如果大于,则减去p,得到最终结果。

NTT的实现及其加速
http://example.com/2022/12/14/NTT的实现及其加速/
作者
harper
发布于
2022年12月14日
许可协议