多项式“全”家桶(转载自洛谷)

多项式/生成函数全家桶

1. 生成函数

定义 \(\lang a_0, a_1, \dots \rang\) 的普通生成函数为:

\[F(x) = \sum_{i \ge 0} a_ix^i \]

指数生成函数为:

\[F(x) = \sum_{i \ge 0}a_i\frac{x_i}{i!}\\ \]

不同生成函数在不同题目中有不同用途。

本文不着重于如何使用生成函数,而在于有限项多项式的运算。

2. 基本操作

  1. 多项式乘法
  2. 多项式求逆
  3. 多项式 \(\ln\)
  4. 多项式 \(\exp\)
  5. 多项式快速幂
  6. 多项式除法
  7. 多项式开方
  8. 分治乘法

3. FFT, NTT

3.1 多项式乘法,变为 \(n\) 个值相乘

多项式最重要的操作就是乘法:

\[H(x) = F(x) \times G(x) = \sum_{k}x^k\sum_{i + j = k}f_ig_j \]

直接算当然是 \(O(n^2)\) 的,我们要寻求更优秀的算法。

事实上,对于任意一个 \(x\),由定义有 \(H(x) = F(x) \times G(x)\)。假设我们得知 \(\deg H = n - 1\),则我们只需要 \(n\)\(x\) 的值,就可以通过一些方法得到 \(H\) 的所有系数了。

在这里,我们一般会选择单位根 \(\omega_1 = e^{\frac{2i\pi}{n}}\)。FFT 解决的,就是在 \(O(n\log n)\) 的复杂度内算出 \(F(\omega_1), F(\omega_2), \dots, F(\omega_n)\)

3.2 分治与蝶形算法

在实践中,我们往往会认为 \(n\) 是一个 \(2\) 的幂次,因为 \(2\) 的幂次有很多优良的性质。

我们尝试计算一对 \(F(\omega_i)\)\(F(\omega_{i + n/2})\),注意,其中有 \(-\omega_i = \omega_{i + n/2}\)

\[\begin{aligned} F(\omega_i)&= f_0 + f_1\omega_i + f_2\omega_i^2 + \dots + f_{n - 1}\omega_i^{n - 1}\\ F(\omega_{i + n/2}) &= f_0 + f_1\omega_{i+n/2} + f_2\omega_{i+n/2}^2 + \dots + f_{n - 1}\omega_{i + n / 2}^{n - 1}\\ &= f_0 - f_1\omega_i + f_2\omega_i^2 - \dots -f_{n - 1}\omega_i^{n - 1}\\ F(\omega_i) + F(\omega_{i + n/2}) &= 2\left(f_0 + f_2\omega_i^2 + \dots +f_{n - 2}\omega_i^{n - 2}\right)\\ F(\omega_i) - F(\omega_{i + n / 2}) &= 2\omega_i (f_1 + f_3\omega_i^2 + \dots + f_{n - 1}\omega_i^{n - 2}) \end{aligned} \]

这个形式就非常分治。在这个分治中,我们是根据 \(f\) 下标的奇偶性进行分治的,这与我们常见的分治算法略有不同。在常规的分治中,每次分治的内容是一个区间,在 \(n\)\(2\) 的幂次的情况下,即拥有相同的若干位 \(\operatorname{highbit}\)。但是,在这里的分治中,拥有相同 \(\operatorname{lowbit}\)\(f_i\) 被分到了一个组。对此有一个解决办法:翻转数位。

例如,原本给定的数列是 \(\lang f_0, f_1, f_2, f_3, f_4, f_5, f_6, f_7\rang\),进行位翻转后是 \(\lang f_0, f_4, f_2, f_6, f_1, f_5, f_3, f_7\rang\)。这保证了每次分治是一个区间。这个翻转的操作也称为蝶形算法。

我们很容易写出蝶形算法的代码。对于分治过程,由于分治区间的明确性,我们可以自底向上地进行分治,枚举一个 \(2\) 的幂 \(2^h = 1, 2, \dots, \frac{n}{2}\),然后对所有 \(F(\omega_i)\)\(F(\omega_{i + 2^h})\) 进行变换。

3.3 逆运算

事实上我们可以发现,刚才的过程可逆性非常的强。若简记 \(F_i = F(\omega_i)\),则可以发现,刚才的分治过程本质上是这样:

\[[F_i, F_{i + n/2}] = [2(F_i + \omega_i F_{i + n/2}), 2(F_i - \omega_iF_{i + n/2})] \]

它的逆运算就是:

\[[F_i, F_{i + n/2}] = [\frac{F_i + F_{i + n/2}}{4}, \frac{F_i - F_{i + n/2}}{4\omega_i}] \]

这两个步骤简直一模一样!观察到里面的 \(2, \frac{1}{4}\) 这些系数,其实可以将它们全部提出,等到变换结束后,再给所有数都乘上 \(\frac{1}{n}\)

其实可以注意到,很多时候,我们写代码的时候,在逆运算里面没有出现 \(\frac{1}{\omega_i}\) 这个操作,反而是 \(\times \omega_i\)。这是为什么呢?这其实利用了 \(\omega_i = \bar{\omega}_{n - 1}\),也就是,我们这样算出来的 \(F\) 实际上是 \(\lang f_n, f_{n - 1}, \dots, f_1\rang\)。因为 \(\omega_n = \omega_0\),所以我们只需要将后面的 \(n - 1\) 项翻转一下即可。

3.4 模意义下的变换, NTT

我们更常见的范围是在 \(\Z_{998244353}\) 等模意义下的乘法。为了贴合 FFT,我们希望找到这么一个数 \(g\) 来代替 \(\omega\)。这样的数需要满足 \(g^n \equiv 1\),且 \(g^0, g^1, \dots, g^{n - 1}\) 互不相同。

对于第二个性质,我们发现原根就十分符合。那第一个性质呢?我们知道,对于任意一个原根 \(g_0\),有 \(g_0^{\varphi(p)} \equiv 1\)。只要我们能够使得 \(n | \varphi(p)\),那我们取 \(g = g_0^{\frac{\varphi(p)}{n}} = g_0^{\frac{p - 1}{n}}\) 就是合理的。注意到,常见的多项式模数都满足其 \(-1\)\(2\) 的系数非常高。例如:

\[998244353 = 119 \times 2^{23} + 1, g_0 = 3\\ 1004535809 = 479 \times 2^{21} + 1, g_0 = 3 \]

因而,我们可以写出集 NTT 与逆变换 INTT 的代码:

void fft(ll *f, int len, int on)
{
    for (int i = 1, j = len >> 1; i < len - 1; i++)
    {
        if (i < j)
        {
            swap(f[i], f[j]);
        }
        int k = len >> 1;
        while (j >= k)
        {
            j -= k;
            k >>= 1;
        }
        j += k;
    }
    for (int h = 2; h <= len; h <<= 1)
    {
        ll w0 = pow(3, (MOD - 1) / h, MOD);
        for (int j = 0; j < len; j += h)
        {
            ll w = 1;
            for (int i = j; i < j + (h >> 1); i++)
            {
                ll u = f[i], t = f[i + (h >> 1)] * w % MOD;
                f[i] = (u + t) % MOD;
                f[i + (h >> 1)] = (u + MOD - t) % MOD;
                w = w * w0 % MOD;
            }
        }
    }
    if (on == -1) // 实际上,在这种写法中,on 只会出现在这里,也就是说,它可以被换成任意值,比如一个 bool。在别的写法中,on 还会出现在第 19 行。
    {
        reverse(f + 1, f + len);
        ll inl = pow(len, MOD - 2, MOD);
        for (int i = 0; i < len; i++)
        {
            (f[i] *= inl) %= MOD;
        }
    }
}

4. 递归与分治

在很多情况下,我们无法一次性得到我们要的所有系数,这时候就要使用递归或分治 FFT 了。

4.1 递归 FFT

为了求 \(F_n(x)\)(意为 \(F\) 的低 \(n\) 项系数组成的多项式),我们一般先会求 \(F_{\frac{n}{2}}(x)\),然后再通过一些较为简便的方法从 \(F_{\frac{n}{2}}(x)\) 推到 \(F_n(x)\)。这个过程也可以被称为倍增。

在更多情况下,我们会写作 \(F_n(x) \to F_{2n}(x)\)

在一般情况中,这种递推操作的复杂度为 \(O(n\log n)\),此时整个算法的复杂度 \(T(n) = T(n / 2) + O(n\log n) = O(n\log n)\)。当然常数会大一点。

很多多项式初等函数都利用了递归 FFT。例如,多项式求逆、多项式 \(\exp\)。这些会在下节说明。

4.2 分治 FFT

这里的分治 FFT 不是指将 \(n\) 个多项式乘一起的 \(O(n\log^2n)\) 算法。

有时候,我们会遇到一种“自卷积”的现象:

\[F_i = \sum_{j = 1}^iG_jF_{i - j} \]

这个时候,一次卷积是不够用的。观察一下,对于 \(F_j \times G_{i - j} \to F_i\),我们可以进行 cdq 分治。我们可以先分治计算 \(F\) 的低 \(\frac{n}{2}\) 项系数 \(F_{[0, \frac{n}{2})}\),再通过 \(F_{[0, \frac{n}{2})} \times G_{[0, n)}\) 来得到 \(F_{[\frac{n}{2}, n)}\) 部分得到的贡献,再对高 \(\frac{n}{2}\) 项进行分治。

这个算法的复杂度为 \(T(n) = 2T(n/2) + O(n\log n) = O(n\log^2 n)\),和三维偏序等问题复杂度相同。

事实上,根据上面的式子,\(F(x) = G(x)F(x)\),可以推导出 \(F(x) = \frac{1}{1 - G(x)}\),用多项式求逆可以做到 \(O(n\log n)\)

下面是分治乘法的代码。

/**
 * calculate the [l, r) power of f(x) = f(x) * g(x) (self convolution), f(0) = 1
 * it should be ensured that (r - l) is a power of 2
 */
void self_conv(ll __restrict__ *f, ll __restrict__ *g, int l, int r)
{
    static ll buffer[MAXN];
    if (l == r - 1)
    {
        if (!l)
        {
            f[l] = 1;
        }
        return;
    }
    int mid = (l + r) >> 1;
    self_conv(f, g, l, mid);
    for (int i = l; i < 2 * r - l; i++)
    {
        if (i < mid)
        {
            buffer[i - l] = f[i];
        }
        else
        {
            buffer[i - l] = 0;
        }
    }
    conv(buffer, g, (r - l) << 1);
    for (int i = mid; i < r; i++)
    {
        (f[i] += buffer[i - l]) %= MOD;
    }
    self_conv(f, g, mid, r);
}

5. 初等函数

5.1 多项式求逆

运用递归 FFT。

递归的终止条件是:\([x^0]f^{-1}(x) = ([x^0]f(x))^{-1}\)

假设我们已经知道了在模 \(x^n\) 意义下的逆元 \(f_0^{-1}\),我们希望得到模 \(x^{2n}\) 意义下的逆元 \(f^{-1}\)

由于两者在前 \(n\) 项相同,因此有 \(f^{-1}(x) - f_0^{-1}(x) \equiv 0 \pmod{x^n}\)。平方一下,则最低的 \(2n\) 项均为 \(0\)

\[(f^{-1}(x))^2 - 2f^{-1}(x)f^{-1}_0(x)+(f^{-1}_0(x))^2 \equiv 0 \pmod {x^{2n}}\\ f^{-1}(x) - 2f_0^{-1}(x) + (f^{-1}_0(x))^2 \times f(x) \equiv 0 \pmod {x^{2n}}\\ f^{-1}(x) \equiv f_{0}^{-1}(x)(2 - f(x)f^{-1}_0(x)) \pmod {x^{2n}} \]

因此可以递归计算。

/**
 * calculate f(x) = 1 / g(x) with length of len
 * it should be ensured that len is a power of 2
 */
void poly_inv(ll __restrict__ *f, ll __restrict__ *g, int len)
{
    static ll buffer[MAXN];
    f[0] = pow(g[0], MOD - 2, MOD);
    for (int i = 1; i < (len << 1); i++)
    {
        f[i] = 0;
    }
    for (int t = 2; t <= len; t <<= 1)
    {
        int t2 = 2 * t;
        for (int i = 0; i < t; i++)
        {
            buffer[i] = g[i];
        }
        for (int i = t; i < t2; i++)
        {
            buffer[i] = 0;
        }
        fft(f, t2, 1); // 注意我们出现了三个多项式相乘,所以要开 2 倍长度
        fft(buffer, t2, 1);
        for (int i = 0; i < t2; i++)
        {
            f[i] = f[i] * (MOD + 2 - f[i] * buffer[i] % MOD) % MOD;
        }
        fft(f, t2, -1);
        for (int i = t; i < t2; i++)
        {
            f[i] = 0;
        }
    }
}

5.2 多项式 \(\ln\)

我们可以照搬 \(\ln\) 在数学上的定义:

\[\frac{\operatorname{d}\ln f(x)}{\operatorname{d}x} = \frac{f'(x)}{f(x)}\\ \ln f(x) = \int \frac{f'(x)}{f(x)} \operatorname{d}x\\ \]

对于等式右边的式子,我们已经定义了它在 \(\bmod x^n\) 意义下的值,因此,我们便可以通过求导、求逆、求积分来得到 \(\ln f(x)\)。复杂度 \(O(n\log n)\)

注意,在这里我们必须有 \([x^0]f(x) = 1\),否则 \(\ln\) 之后常数项不收敛。因为最后是一个积分的形式,这说明 \(\ln f(x)\) 对应的多项式常数为 \(0\)

/**
 * calculate f = g' with length of len
 * it should be ensured that len is a power of 2
 * this algorithm will work if f = g
 */
void poly_der(ll *f, ll *g, int len)
{
    for (int i = 1; i < len; i++)
    {
        f[i - 1] = g[i] * i % MOD;
    }
    f[len - 1] = 0;
}

/**
 * calculate f = ∫g with length of len
 * it should be ensured that len is a power of 2
 * this algorithm will work if f = g
 *
 * it's name shouldn't be int because int is a c++ keyword
 */
void poly_int(ll *f, ll *g, int len)
{
    for (int i = len - 2; ~i; i--)
    {
        f[i + 1] = g[i] * pow(i + 1, MOD - 2, MOD) % MOD;
    }
    f[0] = 0;
}

/**
 * calculate f(x) = ln(g(x)) with length of len
 * it should be ensured that len is a power of 2
 *
 * ln(g(x)) = ∫(g'(x) / g(x))
 */
void poly_ln(ll __restrict__ *f, ll __restrict__ *g, int len)
{
    static ll buffer[MAXN];
    for (int i = len; i < (len << 1); i++)
    {
        g[i] = 0;
    }
    poly_inv(buffer, g, len); // this function clears buffer[len, 2len)
    poly_der(g, g, len);
    conv(buffer, g, len << 1, f);
    poly_int(f, f, len);
    poly_int(g, g, len);
    g[0] = pow(buffer[0], MOD - 2, MOD);
}

5.3 多项式 \(\exp\)

\(\ln\) 类似,所求的 \(f(x)\) 必须满足常数项是 \(0\),否则不收敛。否则,可以得到 \(\exp f(x)\)\(0\) 次项为 \(1\)

\(\exp f(x)\) 求导:

\[\frac{\operatorname{d} \exp f(x)}{\operatorname{d} x} \equiv \exp f(x) \times f'(x) \]

观察其中的 \(n-1\) 项系数:

\[[x^{n - 1}]\frac{\operatorname{d} \exp f(x)}{\operatorname{d} x} \equiv \sum_{i=0}^{n - 1}([x^i]\exp f(x))([x^{n - i - 1}]f'(x)) \]

因为此时 \(\exp f(x)\) 是一个多项式,所以 \([x^{n - 1}]\frac{\operatorname{d} \exp f(x)}{\operatorname{d} x}\) 等于 \(n[x^n]\exp f(x)\)(多项式求导)。右侧处理同理。因此得到:

\[n[x^n]\exp f(x) \equiv \sum_{i=0}^{n - 1}([x^i]\exp f(x))((n - i)[x^{n - i}]f(x)) \]

可以使用分治 FFT 解决,复杂度 \(O(n\log^2n)\)。注意,这个式子并不能化成类似 \(\frac{1}{1 - f(x)}\) 的形式。

事实上,更优的做法是使用牛顿公式,可以做到 \(O(n\log n)\),参见下节。

5.4 多个多项式乘与快速幂

由数学知识,\(a\times b = e^{\ln a + \ln b}\)。左侧有乘法,而右侧只有 \(\exp, \ln\) 和加法。

当我们拓展到 \(n\) 个多项式时同样成立。

有时,左侧运算需要保留很高的次项,而右侧会因为可能出现多个相同多项式乘在一起的情况,而大大减小 \(\ln\) 的计算次数。

一个经典问题就是多项式快速幂。我们试图计算 \(F^k(x)\)。根据上述规则,我们 \(\ln\) 之后,给每一项系数都乘上 \(k\),再 \(\exp\) 回去,这样的复杂度为 \(O(n\log n)\),并且与 \(k\) 的大小无关。他的另一个优点就是存的系数数量较少(只存我们想要知道的,额外空间是 \(O(n)\) 的)。

当然它也有一个缺陷。当 \(F(x)\)\(0\) 次项不为 \(1\) 时,无法进行 \(\ln\) 操作。这个时候,我们可以将 \(F(x)\) 写成 \(x^a \times b \times G(x)\) 的形式,其中 \(G(x)\)\(0\) 次项为 \(1\)。这样快速幂的结果就是 \(x^{ka} \times b^k \times G^k(x)\)。注意,\(k\) 在乘进 \(\ln G(x)\) 的时候的模数为 \(p\),而计算 \(b^k\) 时的模数为 \(p - 1\)

/**
 * calculate f(x) = (g(x))^p with length of len
 * it should be ensured that g[0] = 1
 *
 * (g(x))^p = exp(p * ln(g(x)))
 */
void poly_pow_1(ll __restrict__ *f, ll __restrict__ *g, ll p, int len)
{
    static ll buffer[MAXN];
    poly_ln(buffer, g, len);
    for (int i = 0; i < len; i++)
    {
        (buffer[i] *= p) %= MOD;
    }
    poly_exp(f, buffer, len);
}

/**
 * "i won't write poly_pow because the mod of the polynomial and the mod of the coefficients are different"
 */

5.5 多项式开方

完全可以注意到,其实这就是 \(F^{\frac{1}{2}}(x)\),使用快速幂即可(因为 \(\frac{1}{2} \equiv \frac{p + 1}{2}\),可以用它作为幂次)。

另外一种是使用牛顿公式,可以达到同样的复杂度。

6. 牛顿公式

又称牛顿迭代法。

6.1 基本操作

我们为了计算 \(f(x)\),我们构造一个二元函数 \(G(x, y)\),满足 \(G(x, f(x)) \equiv 0 \pmod {x^n}\)。该函数还需要满足存在一个数值 \(f_1\),使得:

  • \(G(0, f_1) = 0\)
  • \(\frac{\partial G}{\partial y}G(0, f_1) \not= 0\)

一般来说,这个 \(f_1\) 是好求的。我们也可以将 \(f_1\) 认为是一个常数多项式,我们的目标就是求出 \(f_n(x)\)。对此,我们参考递归的形式,要从 \(f_n(x)\) 推向 \(f_{2n}(x)\)

\(G(x, f_{2n}(x))\)\(f_{2n}(x) = f_n(x)\) 处进行泰勒展开:

\[0 \equiv G(x, f_{2n}(x)) \equiv \sum_{i=0}^{+\infty}\frac{\frac{\partial^iG}{\partial y^i}(x, f_n(x))}{i!}(f_{2n}(x) - f_n(x))^i \pmod {x^{2n}}\\ \]

观察最右边的括号,\(f_{2n}(x) - f_n(x)\) 在前 \(n\) 项都是 \(0\),因此只要 \(i \ge 2\),这个部分就是 \(0\)。因此只用保留两项:

\[\begin{aligned} 0 &\equiv \sum_{i=0}^{+\infty}\frac{\frac{\partial^iG}{\partial y^i}(x, f_n(x))}{i!}(f_{2n}(x) - f_n(x))^i\\ &\equiv G(x, f_{n}(x)) + \frac{\partial G}{\partial y}(x, f_n(x))(f_{2n}(x) - f_n(x)) \pmod {x^{2n}} \end{aligned} \]

移项,得到:

\[f_{2n}(x) \equiv f_n(x) - \frac{G(x, f_n(x))}{\frac{\partial G}{\partial y}(x, f_n(x))} \pmod {x^{2n}} \]

在实际情况中,右侧的式子往往是好求的(一般为 \(O(n\log n)\))。

6.2 多项式求逆

通过牛顿公式,我们能得到同样的式子。

假设我们要求 \(F(x)\) 的逆,取 \(G(x, y) = \frac{1}{y} - F(x)\)。很容易得到 \(f_1 = ([x^0]F)^{-1}\)。代入牛顿公式:

\[\begin{aligned} f_{2n}(x) &\equiv f_n(x) - \frac{G(x, f_n(x))}{\frac{\partial G}{\partial y}(x, f_n(x))}\\ &\equiv f_n(x) - \frac{1 / f_n(x) - F(x)}{-1 / f_n^2(x)}\\ &\equiv 2f_n(x) - f_n^2(x)F(x) &\pmod {x^{2n}}\\ \end{aligned} \]

6.3 多项式 \(\exp\)

通过牛顿公式,我们能得到 \(O(n\log n)\) 的解法。

假设我们要求 \(f(x) = \exp F(x)\)。根据定义,显然有 \(f_1 = 1\)。令 \(G(x, y) = \ln y - F(x)\),然后代入牛顿公式:

\[\begin{aligned} f_{2n}(x) &\equiv f_n(x) - \frac{G(x, f_n(x))}{\frac{\partial G}{\partial y}(x, f_n(x))}\\ &\equiv f_n(x) - \frac{\ln f_n(x) - F(x)}{1/f_n(x)}\\ &\equiv f_n(x)(1 - \ln f_n(x) + F(x)) &\pmod {x^{2n}} \end{aligned} \]

单次倍增显然是 \(O(n\log n)\) 的。

/**
 * calculate f(x) = exp(g(x)) with length of len
 * if should be ensured that len is a power of 2
 *
 * exp(g(x)) = e^(g(x))
 *
 * f_len(x) = f_(len/2)(x)(1 - ln(f_(len/2)(x)) + g(x))
 */
void poly_exp(ll __restrict__ *f, ll __restrict__ *g, int len)
{
    static ll buffer[MAXN];
    for (int i = 0; i < (len << 1); i++)
    {
        f[i] = 0;
    }
    f[0] = 1;
    for (int t = 2; t <= len; t <<= 1)
    {
        int t2 = t << 1;
        poly_ln(buffer, f, t);
        buffer[0] = (1 + g[0] - buffer[0] + MOD) % MOD;
        for (int i = 1; i < t; i++)
        {
            buffer[i] = (g[i] - buffer[i] + MOD) % MOD;
        }
        for (int i = t; i < t2; i++)
        {
            buffer[i] = 0;
        }
        fft(f, t2, 1);
        fft(buffer, t2, 1);
        for (int i = 0; i < t2; i++)
        {
            (f[i] *= buffer[i]) %= MOD;
        }
        fft(f, t2, -1);
        for (int i = t; i < t2; i++)
        {
            f[i] = 0;
        }
    }
}

6.4 多项式开方

假设要求 \(f(x) = \sqrt{F(x)}\)。我们一般考虑先将 \(f_1\) 单独计算,在很多题目中,这一项为 \(1\)

如果发现 \([x^0]F(x) = 0\),令 \(k\) 表示 \(F(x)\) 最低位 \(0\) 的个数,则若 \(k\) 为奇数,\(F(x)\) 无平方根;若 \(k\) 为偶数,我们可以将这些项提取出来,对剩余的项进行开方,最后在前面加入 \(\frac{k}{2}\)\(0\) 即可。

那么假设 \(f_1 = 1\)。考虑令 \(G(x, y) = y^2 - F(x)\),代入:

\[\begin{aligned} f_{2n}(x) &\equiv f_n(x) - \frac{G(x, f_n(x))}{\frac{\partial G}{\partial y}(x, f_n(x))}\\ &\equiv f_n(x) - \frac{f_n^2(x) - F(x)}{2f_n(x)}\\ &\equiv \frac{f_n^2(x) + F(x)}{2f_n(x)} &\pmod {x^{2n}} \end{aligned} \]

便可以在 \(O(n\log n)\) 计算了。

如果直接将 \(f_1 = 0\) 代入计算会有什么后果呢?例如,我们计算 \(\sqrt{x^2}\),按理来说得到 \(\pm x\)。可是,通过牛顿算法,在 \(n = 2\) 的时候会推出 \(f_{2n} = x^3 - x\) 等其他解,这显然不是我们想要的。

/**
  * calculate f(x) = sqrt(g(x)) with the length of len
  * it should be ensured that g[0] = 1
  *
  * f(x) = (f^2(x) + g(x)) / 2f(x)
  */
void poly_sqrt_1(ll __restrict__ *f, ll __restrict__ *g, int len)
{
    static ll buffer[MAXN];
    for (int i = 0; i < (len << 1); i++)
    {
        f[i] = 0;
    }
    f[0] = 1;
    for (int t = 2; t <= len; t <<= 1)
    {
        int t2 = t << 1;
        poly_inv(buffer, f, t);
        for (int i = 0; i < t; i++)
        {
            (buffer[i] *= I2) %= MOD;
        }
        for (int i = t; i < t2; i++)
        {
            buffer[i] = 0;
        }
        square(f, t);
        for (int i = 0; i < t; i++)
        {
            (f[i] += g[i]) %= MOD;
        }
        fft(f, t2, 1);
        fft(buffer, t2, 1);
        for (int i = 0; i < t2; i++)
        {
            (f[i] *= buffer[i]) %= MOD;
        }
        fft(f, t2, -1);
        for (int i = t; i < t2; i++)
        {
            f[i] = 0;
        }
    }
}

7. 其他操作

7.1 带余除法

假设我们要计算 \(f(x) = g(x) q(x) + r(x)\),其中 \(\deg r \lt \deg g\)

显然我们不能直接让 \(f\) 乘上 \(g^{-1}\),因为这样无法消除 \(r\) 的影响。无法消除是因为,\(r\) 处于低次项,我们要尝试将它转到高次项位去。

这里有一个 trick,就是我们可以将 \(\frac{1}{x}\) 代入上述式子。假设 \(\deg f = n, \deg g = m\),我们让整个式子乘上 \(x^n\) 使所有次项都为正:

\[x^nf(\frac{1}{x}) = x^mg(\frac{1}{x})x^{n - m}q(\frac{1}{x}) + x^{n - m + 1} \times x^{m - 1}r(\frac{1}{x}) \]

观察 \(x^nf(\frac{1}{x})\),它实际上就是将 \(f\) 的各项系数高低翻转了一下。令翻转之后为 \(f_R(x)\),则可以得到:

\[f_R(x) = g_R(x)q_R(x) + x^{n - m + 1}r_R(x) \]

注意到 \(q_R(x)\) 的最高次数为 \(n - m\),而 \(x^{n - m + 1}r_R(x)\) 的最低次为 \(n - m + 1\),故我们只保留 \(f_R(x) \times g_R^{-1}(x)\) 的低 \(n - m + 1\) 项即可得到 \(q_R(x)\)。之后我们就可以通过 \(f(x) - g(x)q(x)\) 得到 \(r(x)\) 了。

/**
 * calculate division: f(x) = g(x) * q(x) + r(x), where deg r < deg g
 * if n < m, this function only modifies r, and q will stay unchanged
 * otherwise, only the first m elements of r and the first (n - m + 1) elements of q will be meaningful
 */
void poly_div(ll __restrict__ *f, int n, ll __restrict__ *g, int m, ll __restrict__ *q, ll __restrict__ *r)
{
    static ll buffer[MAXN];
    if (n < m) // 若 n < m,我们无法使用上述式子,但是此时仅有余数有值,直接复制即可。
    {
        for (int i = 0; i <= n; i++)
        {
            r[i] = f[i];
        }
        return;
    }
    reverse(f, f + n + 1);
    reverse(g, g + m + 1);
    int len = 1;
    while (len < (n + m + 2))
    {
        len <<= 1;
    }
    for (int i = n + 1; i < (len << 1); i++)
    {
        f[i] = 0;
    }
    for (int i = m + 1; i < len; i++)
    {
        g[i] = 0;
    }
    poly_inv(buffer, g, len);
    for (int i = len; i < (len << 1); i++)
    {
        buffer[i] = 0;
    }
    conv(f, buffer, len << 1, q);
    reverse(f, f + n + 1);
    reverse(g, g + m + 1);
    reverse(q, q + (n - m) + 1);
    for (int i = n - m + 1; i < len; i++)
    {
        q[i] = 0;
    }
    conv(q, g, len, r);
    for (int i = 0; i < m; i++)
    {
        r[i] = (f[i] - r[i] + MOD) % MOD;
    }
    for (int i = m; i < len; i++)
    {
        r[i] = 0;
    }
}

7.2 任意模数 MTT

现在我们要对任意的模数取模。实际上,即使我们不取模,两个系数量级在 \(10^9\) 左右的多项式乘起来后,系数最大是 \(10^{24}\)。因此,我们只要尝试用一种方式维护这个大数字即可。

一种方式是三模 NTT。考虑将此多项式分别模三个 NTT 模数,最后再通过中国剩余定理得到具体的系数。这个方法常数疑似过大,但确实是可行的。

另一种方式是将所有系数拆成 \(a_i \times 10^{12} + b_i\),则我们可以给 \(A, B\) 两个多项式列出一些式子,然后再分别 NTT。

由于这个实在太不常见了,就没写代码了。

8. 完整代码

#define MAXN (1 << 20)
#define MOD 998244353
#define I2 499122177

using ll = long long;

ll pow(ll b, ll p, ll m)
{
    ll r = 1;
    while (p)
    {
        if (p & 1)
        {
            r = r * b % m;
        }
        b = b * b % m;
        p >>= 1;
    }
    return r;
}

namespace POLY
{

    void fft(ll *f, int len, int on)
    {
        for (int i = 1, j = len >> 1; i < len - 1; i++)
        {
            if (i < j)
            {
                swap(f[i], f[j]);
            }
            int k = len >> 1;
            while (j >= k)
            {
                j -= k;
                k >>= 1;
            }
            j += k;
        }
        for (int h = 2; h <= len; h <<= 1)
        {
            ll w0 = pow(3, (MOD - 1) / h, MOD);
            for (int j = 0; j < len; j += h)
            {
                ll w = 1;
                for (int i = j; i < j + (h >> 1); i++)
                {
                    ll u = f[i], t = f[i + (h >> 1)] * w % MOD;
                    f[i] = (u + t) % MOD;
                    f[i + (h >> 1)] = (u + MOD - t) % MOD;
                    w = w * w0 % MOD;
                }
            }
        }
        if (on == -1)
        {
            reverse(f + 1, f + len);
            ll inl = pow(len, MOD - 2, MOD);
            for (int i = 0; i < len; i++)
            {
                (f[i] *= inl) %= MOD;
            }
        }
    }

    /**
     * calculate h(x) = f(x) * g(x)
     * it should be ensured that len is a power of 2
     * if you want to store the result to f, omit the h parameter
     */
    void conv(ll __restrict__ *f, ll __restrict__ *g, int len, ll __restrict__ *h)
    {
        fft(f, len, 1);
        fft(g, len, 1);
        for (int i = 0; i < len; i++)
        {
            h[i] = f[i] * g[i] % MOD;
        }
        fft(f, len, -1);
        fft(g, len, -1);
        fft(h, len, -1);
    }

    /**
     * calculate f(x) * g(x) and store it to f
     * it should be ensured that len is a power of 2
     */
    void conv(ll __restrict__ *f, ll __restrict__ *g, int len)
    {
        fft(f, len, 1);
        fft(g, len, 1);
        for (int i = 0; i < len; i++)
        {
            (f[i] *= g[i]) %= MOD;
        }
        fft(f, len, -1);
        fft(g, len, -1);
    }

    /**
     * calculate f(x) * f(x) and store it to f
     */
    void square(ll *f, int len)
    {
        fft(f, len, 1);
        for (int i = 0; i < len; i++)
        {
            (f[i] *= f[i]) %= MOD;
        }
        fft(f, len, -1);
    }

    /**
     * calculate the [l, r) power of f(x) = f(x) * g(x) (self convolution), f(0) = 1
     * it should be ensured that (r - l) is a power of 2
     */
    void self_conv(ll __restrict__ *f, ll __restrict__ *g, int l, int r)
    {
        static ll buffer[MAXN];
        if (l == r - 1)
        {
            if (!l)
            {
                f[l] = 1;
            }
            return;
        }
        int mid = (l + r) >> 1;
        self_conv(f, g, l, mid);
        for (int i = l; i < 2 * r - l; i++)
        {
            if (i < mid)
            {
                buffer[i - l] = f[i];
            }
            else
            {
                buffer[i - l] = 0;
            }
        }
        conv(buffer, g, (r - l) << 1);
        for (int i = mid; i < r; i++)
        {
            (f[i] += buffer[i - l]) %= MOD;
        }
        self_conv(f, g, mid, r);
    }

    /**
     * calculate f(x) = 1 / g(x) with length of len
     * it should be ensured that len is a power of 2
     */
    void poly_inv(ll __restrict__ *f, ll __restrict__ *g, int len)
    {
        static ll buffer[MAXN];
        f[0] = pow(g[0], MOD - 2, MOD);
        for (int i = 1; i < (len << 1); i++)
        {
            f[i] = 0;
        }
        for (int t = 2; t <= len; t <<= 1)
        {
            int t2 = 2 * t;
            for (int i = 0; i < t; i++)
            {
                buffer[i] = g[i];
            }
            for (int i = t; i < t2; i++)
            {
                buffer[i] = 0;
            }
            fft(f, t2, 1);
            fft(buffer, t2, 1);
            for (int i = 0; i < t2; i++)
            {
                f[i] = f[i] * (MOD + 2 - f[i] * buffer[i] % MOD) % MOD;
            }
            fft(f, t2, -1);
            for (int i = t; i < t2; i++)
            {
                f[i] = 0;
            }
        }
    }

    /**
     * calculate f = g' with length of len
     * it should be ensured that len is a power of 2
     * this algorithm will work if f = g
     */
    void poly_der(ll *f, ll *g, int len)
    {
        for (int i = 1; i < len; i++)
        {
            f[i - 1] = g[i] * i % MOD;
        }
        f[len - 1] = 0;
    }

    /**
     * calculate f = ∫g with length of len
     * it should be ensured that len is a power of 2
     * this algorithm will work if f = g
     *
     * it's name shouldn't be int because int is a c++ keyword
     */
    void poly_int(ll *f, ll *g, int len)
    {
        for (int i = len - 2; ~i; i--)
        {
            f[i + 1] = g[i] * pow(i + 1, MOD - 2, MOD) % MOD;
        }
        f[0] = 0;
    }

    /**
     * calculate f(x) = ln(g(x)) with length of len
     * it should be ensured that len is a power of 2
     *
     * ln(g(x)) = ∫(g'(x) / g(x))
     */
    void poly_ln(ll __restrict__ *f, ll __restrict__ *g, int len)
    {
        static ll buffer[MAXN];
        for (int i = len; i < (len << 1); i++)
        {
            g[i] = 0;
        }
        poly_inv(buffer, g, len); // this function clears buffer[len, 2len)
        poly_der(g, g, len);
        conv(buffer, g, len << 1, f);
        poly_int(f, f, len);
        poly_int(g, g, len);
        g[0] = pow(buffer[0], MOD - 2, MOD);
    }

    /**
     * calculate f(x) = exp(g(x)) with length of len
     * if should be ensured that len is a power of 2
     *
     * exp(g(x)) = e^(g(x))
     *
     * f_len(x) = f_(len/2)(x)(1 - ln(f_(len/2)(x)) + g(x))
     */
    void poly_exp(ll __restrict__ *f, ll __restrict__ *g, int len)
    {
        static ll buffer[MAXN];
        for (int i = 0; i < (len << 1); i++)
        {
            f[i] = 0;
        }
        f[0] = 1;
        for (int t = 2; t <= len; t <<= 1)
        {
            int t2 = t << 1;
            poly_ln(buffer, f, t);
            buffer[0] = (1 + g[0] - buffer[0] + MOD) % MOD;
            for (int i = 1; i < t; i++)
            {
                buffer[i] = (g[i] - buffer[i] + MOD) % MOD;
            }
            for (int i = t; i < t2; i++)
            {
                buffer[i] = 0;
            }
            fft(f, t2, 1);
            fft(buffer, t2, 1);
            for (int i = 0; i < t2; i++)
            {
                (f[i] *= buffer[i]) %= MOD;
            }
            fft(f, t2, -1);
            for (int i = t; i < t2; i++)
            {
                f[i] = 0;
            }
        }
    }

    /**
     * calculate f(x) = (g(x))^p with length of len
     * it should be ensured that g[0] = 1
     *
     * (g(x))^p = exp(p * ln(g(x)))
     */
    void poly_pow_1(ll __restrict__ *f, ll __restrict__ *g, ll p, int len)
    {
        static ll buffer[MAXN];
        poly_ln(buffer, g, len);
        for (int i = 0; i < len; i++)
        {
            (buffer[i] *= p) %= MOD;
        }
        poly_exp(f, buffer, len);
    }

    /**
     * "i won't write poly_pow because the mod of the polynomial and the mod of the coefficients are different"
     */

    /**
     * calculate f(x) = sqrt(g(x)) with the length of len
     * it should be ensured that g[0] = 1
     *
     * f(x) = (f^2(x) + g(x)) / 2f(x)
     */
    void poly_sqrt_1(ll __restrict__ *f, ll __restrict__ *g, int len)
    {
        static ll buffer[MAXN];
        for (int i = 0; i < (len << 1); i++)
        {
            f[i] = 0;
        }
        f[0] = 1;
        for (int t = 2; t <= len; t <<= 1)
        {
            int t2 = t << 1;
            poly_inv(buffer, f, t);
            for (int i = 0; i < t; i++)
            {
                (buffer[i] *= I2) %= MOD;
            }
            for (int i = t; i < t2; i++)
            {
                buffer[i] = 0;
            }
            square(f, t);
            for (int i = 0; i < t; i++)
            {
                (f[i] += g[i]) %= MOD;
            }
            fft(f, t2, 1);
            fft(buffer, t2, 1);
            for (int i = 0; i < t2; i++)
            {
                (f[i] *= buffer[i]) %= MOD;
            }
            fft(f, t2, -1);
            for (int i = t; i < t2; i++)
            {
                f[i] = 0;
            }
        }
    }

    /**
     * calculate division: f(x) = g(x) * q(x) + r(x), where deg r < deg g
     * if n < m, this function only modifies r, and q will stay unchanged
     * otherwise, only the first m elements of r and the first (n - m + 1) elements of q will be meaningful
     */
    void poly_div(ll __restrict__ *f, int n, ll __restrict__ *g, int m, ll __restrict__ *q, ll __restrict__ *r)
    {
        static ll buffer[MAXN];
        if (n < m)
        {
            for (int i = 0; i <= n; i++)
            {
                r[i] = f[i];
            }
            return;
        }
        reverse(f, f + n + 1);
        reverse(g, g + m + 1);
        int len = 1;
        while (len < (n + m + 2))
        {
            len <<= 1;
        }
        for (int i = n + 1; i < (len << 1); i++)
        {
            f[i] = 0;
        }
        for (int i = m + 1; i < len; i++)
        {
            g[i] = 0;
        }
        poly_inv(buffer, g, len);
        for (int i = len; i < (len << 1); i++)
        {
            buffer[i] = 0;
        }
        conv(f, buffer, len << 1, q);
        reverse(f, f + n + 1);
        reverse(g, g + m + 1);
        reverse(q, q + (n - m) + 1);
        for (int i = n - m + 1; i < len; i++)
        {
            q[i] = 0;
        }
        conv(q, g, len, r);
        for (int i = 0; i < m; i++)
        {
            r[i] = (f[i] - r[i] + MOD) % MOD;
        }
        for (int i = m; i < len; i++)
        {
            r[i] = 0;
        }
    }
}

需要引用头文件 <algorithm>

posted @ 2026-02-04 18:43  cosf  阅读(0)  评论(0)    收藏  举报