Loading

多项式全家桶:基础篇

推荐博客:洛谷 command_block 的多项式讲解

前置知识:初中数学水平

多项式加减

给定一个多项式 \(F(x)\)\(G(x)\),求它们的 和 / 差。

直接相应次数系数加减即可。

多项式乘法

给定一个多项式 \(F(x)\)\(G(x)\),求它们的积。\(n \leq 10^5\)

FFT

FFT 的实现分为三步:

  1. 把系数表达转换为点值表达(DFT)。
  2. \(F\)\(G\) 的对应点值相乘。
  3. 把点值表达还原为系数表达(IDFT)。

很明显,主要的步骤就是 DFT 和 IDFT。先来了解 DFT 的思想。

DFT 是一个分治的过程。假设当前要处理的函数为 \(F\)(最高次为 \(n\)),而我们已经知道了 \(FL\)\(FR\)\(\omega^0_\frac{n}{2}, \omega^1_\frac{n}{2}...,\omega^k_\frac{n}{2}\) 下的点值表示(其中 \(k = \frac{n}{2} - 1\)),我们可以用以下式子合并。

\[F(w_n^k) = FL(w_{n/2}^k)+w_n^k \times FR(w_{n/2}^k) \]

\[F(w_n^{k+n/2}) = FL(w_{n/2}^k)-w_n^k \times FR(w_{n/2}^k) \]

而 IDFT 只需要对上式的 \(w^k_n\) 部分取逆,最终输出的答案再除以 \(n\) 即可。

FFT 可以使用迭代或者递归来实现。以下为迭代版的核心代码。

void DFT(Point *f, bool op) {
    for(int i = 0; i < n; ++i) if(i < trans[i]) swap(f[i], f[trans[i]]);
    for(int len = 2; len <= n; len <<= 1) {
        Point base = {cos(2 * PI / len), sin(2 * PI / len)};
        if(op) base.y *= -1.0;
        for(int i = 0; i < n; i += len) {
            Point now = {1.0, 0.0};
            for(int k = i; k < i + (len >> 1); ++k) {
                Point temp = Mul(now, f[k + (len >> 1)]);
                f[k + (len >> 1)] = Minus(f[k], temp);
                f[k] = Plus(f[k], temp);
                now = Mul(now, base);
            }
        }
    }
}

for(int i = 0; i < n; ++i) trans[i] = (trans[i >> 1] >> 1) | ((i & 1) ? n >> 1 : 0);

NTT

NTT 的实现类似于 FFT。我们用原根 \(g\) 来替代单位根。

  • 原根:(理解性定义)令 \(p = k2^t+1\),若 \(g^0,g^1,g^2,...g^{p-2}\) 在模 \(p\) 意义下各不相同,则称 \(g\) 是模 \(p\) 意义下的原根。

OI 中常用的 NTT 模数为 \(998244353\),其原根为 \(3\)

NTT 的核心代码:

const LL invG = Pow(G, MOD - 2);

void NTT(LL *F, bool op) {
    for(int i = 0; i < n; ++i) if(i < trans[i]) swap(F[i], F[trans[i]]);
    for(int len = 2; len <= n; len <<= 1) {
        LL base;
        if(!op) base = Pow(G, (MOD - 1) / len);
        else base = Pow(invG, (MOD - 1) / len);
        for(int i = 0; i < n; i += len) {
            LL now = 1;
            for(int k = i; k < i + (len >> 1); ++k) {
                LL temp = now * F[k + (len >> 1)] % MOD;
                F[k + (len >> 1)] = (F[k] - temp) % MOD;
                F[k] = (F[k] + temp) % MOD;
                now = now * base % MOD;
            }
        }
    }
}

任意模数 NTT

有的时候 NTT 的模数不是那些常见的模数甚至不是质数,这时候引入三模 NTT。

基本思想就是,做三组 NTT(一共是九次),最终结果用 CRT 合并即可。

  • 三模 NTT 常用模数:\(998244353, 1004535809,469762049\),它们的原根都是 \(3\)

由于直接用 CRT 合并会爆 long long,这里需要用到一个巧妙的办法。

假设三次 NTT 的结果分别为 \(x_1,x_2,x_3\)(均为自然数),模数分别为 \(m_1,m_2,m_3\),我们有

\[\begin{array}{ll} x \equiv x_{1} & (\bmod m_1) \\ x \equiv x_{2} & (\bmod m_2) \\ x \equiv x_{3} & (\bmod m_3) \end{array} \]

先合并前两个方程,由于 \(x=k_1m_1+x_1=k_2m_2+x_2\),所以 \(k_1m_1\equiv x_2-x_1(\bmod m_2)\)

求出一个 自然数 \(k_1\),令 \(x'=k_1m_1+x_1\)。这个结果不可能达到 \(m_1m_2\),且一定是个 自然数,还没有爆 long long。

接下来合并第三个方程,此时我们有

\[x \equiv x'(\bmod m_1m_2) \\ x \equiv x_3(\bmod m_3) \]

同理,我们得到同余方程 \(k_2m_1m_2 \equiv x_3 - x' (\bmod m_3)\),解出 自然数 \(k_2\)

最终的合并结果即为 \(x=k_2m_1m_2+x'\),这个结果不可能达到 \(m_1m_2m_3\) 且一定是个 自然数

注意到这个时候的 \(x\) 可能会爆 long long,但它一定是取模前的正确答案。我们可以在计算 \(k_2,m_1,m_2\) 的乘积的中途便对给定模数 \(p\) 取模,整个过程便不会超出 long long 的范围。

CRT 合并代码:

void crt() {
    for(int i = 0; i <= m; ++i) {
        LL k1 = (ans2[i] - ans1[i]) % MOD2 * Pow(MOD1, MOD2 - 2, MOD2) % MOD2;
        k1 = (k1 % MOD2 + MOD2) % MOD2;
        ans2[i] = k1 * MOD1 + ans1[i];

        LL k2 = (ans3[i] - ans2[i]) % MOD3 * Pow(MOD1 * MOD2 % MOD3, MOD3 - 2, MOD3) % MOD3;
        k2 = (k2 % MOD3 + MOD3) % MOD3;
        LL x = (k2 % p * MOD1 % p * MOD2 % p + ans2[i]) % p;
        
        printf("%lld ", (x + p) % p);
    }
}

多项式乘法逆

给定多项式 \(F(x)\),求 \(G(x)\) 使得 \(F(x)G(x)=1(\bmod x^n)\),系数对 \(998244353\) 取模。\(n \leq 10^5\)

假设我们已经求得 \(G'(x)F(x) \equiv 1 (\bmod x^{\frac{n}{2}})\)\(G'(x)\),考虑如何得到 \(G(x)F(x) \equiv 1 (\bmod x^{n})\)

\[显然有 \ G(x)F(x) \equiv 1 (\bmod x^{\frac{n}{2}}) \\\ 两式相减,得 \ G(x) - G'(x) \equiv 0 (\bmod x^{\frac{n}{2}}) \\\ 可以证明,平方后得 \ G^2(x) + G'^2(x) - 2 G(x)G'(x) \equiv 0 (\bmod x^n) \\\ 同时乘以 \ F(x),得到 \ G(x) \equiv 2G'(x)-G'^2(x)F(x) (\bmod x^n) \]

显然这是一个倍增的过程,每次倍增用三次 NTT 计算即可。总复杂度为 \(O(n \log n)\)

多项式乘法逆核心部分:

void getinv() {
    int n = 1; while(n < m) n <<= 1;
    g[0] = Pow(f[0], MOD - 2);
    for(int len = 2; len <= n; len <<= 1) {
        for(int i = 0; i < (len >> 1); ++i) t1[i] = 2 * g[i] % MOD;
        for(int i = 0; i < len; ++i) t2[i] = f[i];
        NTT(g, 0, len << 1); NTT(t2, 0, len << 1);
        for(int i = 0; i < (len << 1); ++i) g[i] = t2[i] % MOD * g[i] % MOD * g[i] % MOD;
        NTT(g, 1, len << 1);
        for(int i = 0; i < (len << 1); ++i) g[i] = g[i] * Pow(len << 1, MOD - 2) % MOD;
        for(int i = 0; i < len; ++i) g[i] = (t1[i] - g[i]) % MOD;
        for(int i = len; i < (len << 1); ++i) g[i] = 0;
    }
    for(int i = 0; i < m; ++i) printf("%lld ", (g[i] % MOD + MOD) % MOD);
}

写代码时注意的地方(备忘):

  • 每次 NTT 的 \(n\) 是变化的。

  • 拷贝数组时要拷贝完全,还要记得清空数组的一部分。

    • 拷贝 \(\text{t1}\) 数组时只用循环到 \(\frac{len}{2}\)(因为上一次的 \(g\) 只到了 \(\frac{len}{2}\))。
    • 拷贝 \(\text{t2}\) 数组时要循环到 \(len\)(因为 \(F\) 的次数上界可以超过 \(len\),而 \(F\) 的第 \(len-1\) 次项会产生影响)。
    • 最后要把数组 \(\text{g}\) 大于等于 \(len\) 次项的部分全部清零,避免对后续造成影响(当前是对 \(x^{len}\) 取模,因此 \(len\) 次项及以上的系数均可以清零)。
  • 做 NTT 时,范围为 \(2 \times len\),因为 \(\text{t2}\) 的最高次项为 \(len-1\)\(\text{g}\) 的最高次项为 \(\frac{len}{2}-1\),因此 \(\text{t2}\times\text{g}^2\) 的最高次项可以到达\(2 \times len-3\)

大概就是这样。

posted @ 2020-08-07 22:52  Sqrtyz  阅读(236)  评论(0)    收藏  举报