【数学】【多项式】多项式求逆
写在前面
多项式求逆
前置知识:NTT
多项式求逆
给定一个 \(n-1\) 次多项式 \(F\left(x\right)\)(要求常数项 \(F\left(0\right)\not\equiv 0\)),求一个多项式 \(G\left(x\right)\),使得 \(F\left(x\right)G\left(x\right)\equiv 1\left(\bmod x^n\right)\),其中所有系数均对 \(998244353\) 取模。
考虑递归求解。
假定现在已经求出了 \(G_0\left(x\right)\),满足
\[F\left(x\right)G_0\left(x\right)\equiv 1\left(\bmod x^{\lceil\frac{n}{2}\rceil}\right)\tag 1
\]
根据要求的 \(G\left(x\right)\) 的定义,显然有
\[F\left(x\right)G\left(x\right) \equiv 1\left(\bmod x^{\lceil\frac{n}{2}\rceil}\right)\tag 2
\]
\((2) - (1)\),得
\[F\left(x\right)\left(G\left(x\right) - G_0\left(x\right)\right) \equiv 0 \left(\bmod x^{\lceil\frac{n}{2}\rceil}\right)
\]
因为 \(F\left(x\right)\) 的常数项非零,它在模 \(x^{\lceil\frac{n}{2}\rceil}\) 意义下可逆。两边同乘 \(G_0\left(x\right)\) 并利用 \((1)\) 式 \(F\left(x\right)G_0\left(x\right)\equiv 1\),所以有
\[G\left(x\right) - G_0\left(x\right) \equiv 0 \left(\bmod x^{\lceil\frac{n}{2}\rceil}\right)
\]
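两边平方。左边是 \(x^{\lceil\frac{n}{2}\rceil}\) 的倍数,其平方便是 \(x^{2\lceil\frac{n}{2}\rceil}\) 的倍数;又因为 \(2\lceil\frac{n}{2}\rceil \ge n\),模数可以提升为 \(x^n\),得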
\[G^2\left(x\right) - 2G\left(x\right)G_0\left(x\right) + G_0^2\left(x\right) \equiv 0\left(\bmod x^n\right)
\]
两边同乘 \(F\left(x\right)\),并利用 \(F\left(x\right)G\left(x\right)\equiv 1\left(\bmod x^n\right)\) 化简,得
\[G\left(x\right) - 2G_0\left(x\right) + F\left(x\right)G_0^2\left(x\right) \equiv 0\left(\bmod x^n\right)
\]
移项整理
\[G\left(x\right) \equiv 2G_0\left(x\right) - F\left(x\right)G_0^2\left(x\right) \left(\bmod x^n\right)
\]
递归的边界是 \(n=1\),此时 \(G\left(x\right)\) 只有常数项 \(F\left(0\right)^{-1}\)(用费马小定理求逆元即可);之后按上式自下而上递推即可。
代码:
int rev[Maxn];
// Builds the bit-reversal permutation for an NTT of length len (a power of two):
// rev[i] is i with its log2(len) low bits reversed. rev[0] is never written, so it
// keeps the zero it got from static initialization.
void Setrev(int len) {
    const int topBit = len >> 1;
    for(int idx = 1; idx < len; ++idx)
        rev[idx] = (rev[idx >> 1] >> 1) | ((idx & 1) ? topBit : 0);
}
// In-place iterative NTT of p[0..len) modulo Mod. len must be a power of two and
// Setrev(len) must have been called beforehand.
// type == 0: forward transform using root g[0];
// type == 1: inverse transform using root g[1], followed by multiplication by len^{-1}.
void ntt(LL p[], int len, int type) {
    // Put the coefficients into bit-reversed index order.
    for(int i = 0; i < len; ++i)
        if(i < rev[i]) swap(p[i], p[rev[i]]);
    // Each pass merges adjacent blocks of size `half` into blocks of size `span`.
    for(int half = 1; (half << 1) <= len; half <<= 1) {
        const int span = half << 1;
        // span-th root of unity (its inverse when type == 1).
        const LL step = qpow(g[type], (Mod - 1) / span);
        for(int base = 0; base < len; base += span) {
            LL w = 1;
            for(int t = 0; t < half; ++t) {
                LL &lo = p[base + t], &hi = p[base + t + half];
                const LL u = lo % Mod, v = w * hi % Mod;
                lo = (u + v) % Mod;
                hi = ((u - v) % Mod + Mod) % Mod;
                w = w * step % Mod;
            }
        }
    }
    if(type != 1) return;
    // Inverse transform: scale every coefficient by len^{-1}.
    const LL invLen = qpow(len, Mod - 2);
    for(int i = 0; i < len; ++i) p[i] = p[i] * invLen % Mod;
}
LL tmp[Maxn];
// Computes B(x) = A(x)^{-1} mod x^siz (coefficients mod Mod) by Newton doubling:
//   B = 2*B0 - A*B0^2  (mod x^siz),  where B0 = A^{-1} mod x^ceil(siz/2).
// Preconditions: A[0] != 0 (mod Mod); A[0..siz) is readable; B has room for the
// top-level NTT length (smallest power of two >= 2*siz), which must not exceed Maxn.
// On return B[0..siz) holds the inverse; for siz > 1 the entries of B from siz up to
// the NTT length are cleared. siz <= 0 is a no-op.
void polyinv(LL A[], LL B[], int siz) {
    if(siz <= 0) return;   // nothing to compute; also prevents endless recursion on siz == 0
    if(siz == 1) {B[0] = qpow(A[0], Mod - 2); return;}
    const int half = (siz + 1) >> 1;
    polyinv(A, B, half);
    // deg(A * B0^2) <= (siz - 1) + 2*(half - 1) <= 2*siz - 2, so any power of two
    // >= 2*siz keeps the cyclic convolution free of wrap-around.
    int len = 1;
    while(len < (siz << 1)) len <<= 1;
    for(int i = 0; i < siz; ++i) tmp[i] = A[i];
    for(int i = siz; i < len; ++i) tmp[i] = 0;
    // Only B[0..half) is meaningful here; clear the tail so stale buffer contents
    // cannot leak into the transform.
    for(int i = half; i < len; ++i) B[i] = 0;
    Setrev(len); ntt(tmp, len, 0); ntt(B, len, 0);
    // Point-wise Newton step: 2*B - B*B*A.
    for(int i = 0; i < len; ++i) B[i] = ((2ll * B[i] % Mod - B[i] * B[i] % Mod * tmp[i] % Mod) % Mod + Mod) % Mod;
    ntt(B, len, 1);
    // Terms of degree >= siz are not part of the inverse mod x^siz.
    for(int i = siz; i < len; ++i) B[i] = 0;
}
实现上的一些小细节
-
注意 NTT 的长度:它必须严格大于 \(B^2\left(x\right)A\left(x\right)\) 的次数(不超过 \(2n-2\)),否则循环卷积会发生回绕;在满足这一条件的前提下,长度取得稍长一些并不会影响多项式求逆的结果(只是更慢、占用更多数组空间)。
-
最后那一步记得把 B 数组中下标不小于 siz 的无用元素清空,否则这些残留的高次项会在上一层递归的 NTT 中参与运算,导致结果错误。
-
虽然看上去用了多次 NTT,但每层递归的规模减半,有 \(T\left(n\right)=T\left(\frac{n}{2}\right)+\mathcal O\left(n\log n\right)\),根据主定理(如有需要请自行搜索),复杂度仍旧是 \(\mathcal O\left(n \log n\right)\) 的。
完整代码
#include <bits/stdc++.h>
#define LL long long
using namespace std;
// Reads the next integer from stdin into res.
// Non-digit characters before the number are skipped; a '-' among them makes the result negative.
// If stdin runs out before any digit is found, res is set to 0 instead of looping forever on EOF.
template <typename Temp> inline void read(Temp & res) {
    Temp fh = 1; res = 0;
    // int, not char: keeps EOF distinguishable from a real byte and keeps isdigit's argument valid.
    int ch = getchar();
    for(; ch != EOF && !isdigit(ch); ch = getchar()) if(ch == '-') fh = -1;
    // isdigit(EOF) is false, so this loop also stops at end of input.
    for(; isdigit(ch); ch = getchar()) res = (res << 3) + (res << 1) + (ch ^ '0');
    res = res * fh;
}
// Capacity of every coefficient buffer; must be at least the largest NTT length used.
constexpr int Maxn = 262200;
// NTT-friendly prime modulus 998244353 = 119 * 2^23 + 1.
constexpr LL Mod = 998244353;
// g[0] = 3 is used as the forward-transform root; g[1] = 332748118 is its inverse
// mod Mod (3 * 332748118 = Mod + 1), used for the inverse transform.
constexpr LL g[2] = {3, 332748118};
// Returns A^P mod Mod by binary exponentiation.
// A may be any value, including negative or >= Mod. It is first reduced into [0, Mod),
// which keeps A * A from overflowing and the result non-negative.
// P must be non-negative: a negative P never reaches 0 under the right shift.
LL qpow(LL A, LL P) {
    A %= Mod;
    if(A < 0) A += Mod;
    LL res = 1;
    while(P) {
        if(P & 1) res = res * A % Mod;
        A = A * A % Mod;
        P >>= 1;
    }
    return res;
}
namespace Polynomial {
    // rev[i] is the bit-reversal of i for the length most recently passed to Setrev.
    int rev[Maxn];

    // Fills rev[0..len) for a power-of-two len. rev[0] is never written, so it keeps
    // the zero it got from static initialization.
    void Setrev(int len) {
        for(int i = 1; i < len; ++i) {
            rev[i] = rev[i >> 1] >> 1;
            if(i & 1) rev[i] |= (len >> 1);
        }
    }

    // In-place iterative NTT of p[0..len) modulo Mod. len must be a power of two and
    // Setrev(len) must have been called beforehand.
    // type == 0: forward transform using root g[0];
    // type == 1: inverse transform using root g[1], followed by multiplication by len^{-1}.
    void ntt(LL p[], int len, int type) {
        for(int i = 0; i < len; ++i) if(i < rev[i]) std::swap(p[i], p[rev[i]]);
        for(int h = 2; h <= len; h <<= 1) {
            // h-th root of unity (its inverse when type == 1).
            LL gn = qpow(g[type], (Mod - 1) / h);
            for(int j = 0; j < len; j += h) {
                LL gk = 1;
                for(int k = j; k < j + h / 2; ++k) {
                    LL e = p[k] % Mod, o = gk * p[k + h / 2] % Mod;
                    p[k] = (e + o) % Mod; p[k + h / 2] = ((e - o) % Mod + Mod) % Mod;
                    gk = gk * gn % Mod;
                }
            }
        }
        if(type == 1) {
            LL invl = qpow(len, Mod - 2);
            for(int i = 0; i < len; ++i) p[i] = p[i] * invl % Mod;
        }
    }

    // Multiplies A by B in place: on return, A[0..len) holds the cyclic product, where len
    // is the smallest power of two strictly greater than siz. B is overwritten with its NTT.
    // NOTE(review): presumably callers pass siz >= deg(A) + deg(B) with A and B already
    // zero-padded up to len; no caller in this file, so TODO confirm.
    void polymul(LL A[], LL B[], int siz) {
        int len = 1; while(siz) siz >>= 1, len <<= 1;
        Setrev(len); ntt(A, len, 0); ntt(B, len, 0);
        for(int i = 0; i < len; ++i) A[i] = A[i] * B[i] % Mod;
        ntt(A, len, 1);
    }

    LL tmp[Maxn];
    // Computes B(x) = A(x)^{-1} mod x^siz (coefficients mod Mod) by Newton doubling:
    //   B = 2*B0 - A*B0^2  (mod x^siz),  where B0 = A^{-1} mod x^ceil(siz/2).
    // Preconditions: A[0] != 0 (mod Mod); A[0..siz) is readable; B has room for the
    // top-level NTT length (smallest power of two >= 2*siz), which must not exceed Maxn.
    // With Maxn = 262200 this allows siz up to 131072.
    // On return B[0..siz) holds the inverse; for siz > 1 the entries of B from siz up to
    // the NTT length are cleared. siz <= 0 is a no-op.
    void polyinv(LL A[], LL B[], int siz) {
        if(siz <= 0) return;   // nothing to compute; also prevents endless recursion on siz == 0
        if(siz == 1) {B[0] = qpow(A[0], Mod - 2); return;}
        const int half = (siz + 1) >> 1;
        polyinv(A, B, half);
        // deg(A * B0^2) <= (siz - 1) + 2*(half - 1) <= 2*siz - 2, so any power of two
        // >= 2*siz keeps the cyclic convolution free of wrap-around.
        int len = 1;
        while(len < (siz << 1)) len <<= 1;
        for(int i = 0; i < siz; ++i) tmp[i] = A[i];
        for(int i = siz; i < len; ++i) tmp[i] = 0;
        // Only B[0..half) is meaningful here; clear the tail so stale buffer contents
        // cannot leak into the transform.
        for(int i = half; i < len; ++i) B[i] = 0;
        Setrev(len); ntt(tmp, len, 0); ntt(B, len, 0);
        // Point-wise Newton step: 2*B - B*B*A.
        for(int i = 0; i < len; ++i) B[i] = ((2ll * B[i] % Mod - B[i] * B[i] % Mod * tmp[i] % Mod) % Mod + Mod) % Mod;
        ntt(B, len, 1);
        // Terms of degree >= siz are not part of the inverse mod x^siz.
        for(int i = siz; i < len; ++i) B[i] = 0;
    }
}
int n, m;
LL a[Maxn], b[Maxn];
// Reads n followed by the n coefficients of A(x), then prints the n coefficients of
// A(x)^{-1} mod x^n, separated by spaces.
int main() {
    read(n);
    int i = 0;
    while(i < n) read(a[i++]);
    Polynomial::polyinv(a, b, n);
    for(i = 0; i < n; ++i) printf("%lld ", b[i]);
    return 0;
}

浙公网安备 33010602011771号