多项式学习笔记(三): 多项式全家桶

1.多项式求逆

给你 \(A(x)\)\(A(x)B(x) \equiv 1 \pmod {x^n}\) 。 (模 \(x^n\) 是为了把高次项舍掉)

假设我们已经得到了满足 \(C(x)A(x) \equiv 1 \pmod {x^{n\over 2}}\) 的一个多项式 \(C\)

那么由题意可得 \(A(x)B(x)\equiv 1 \pmod {x^{n\over 2}}\)

两式联立可得:

\(B(x) \equiv C(x) \pmod {x^{n\over 2}}\)

\(B(x) - C(x) \equiv 0 \pmod {x^{n\over 2}}\)

两边同时平方可得:

\(B^2(x) + C^2(x) - 2B(x)C(x) \equiv 0 \pmod {x^{n}}\)

在同时乘上一个 \(A(x)\) 得:

\(A(x)B^2(x) + A(x)C^2(x)-2A(x)B(x)C(x)\equiv 0 \pmod {x^{n}}\)

然后由题意可得 \(A(x)B(x)\equiv 1 \pmod {x^n}\) ,代入化简可得:

\(B(x) + A(x)C^2(x)-2C(x) \equiv 0 \pmod {x^n}\)

\(B(x) = 2C(x) - A(x)C^2(x)\)

然后,我们每次都可以把项数减半递归求解, 如果项数为 \(1\) 的话结果显然是零次项的逆元。

复杂度 \(T(n) = T({n\over 2}) + nlogn = nlogn\)

Code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int N = 1e6+10;
const int p = 998244353;
int n,a[N],b[N],rev[N],c[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void NTT(int *a,int len,int opt)
{
    for(int i = 0; i < len; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < len; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < len; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(len,p-2);
        for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
    }
}
void Inv(int n,int *a,int *b)//求 A(x)B(x) = 1 mod x^n
{
    if(n == 1)//项数为1的情况
    {
        b[0] = ksm(a[0],p-2);
        return;
    }
    Inv((n+1)>>1,a,b);//递归求 C(x)
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++)//预处理NTT的反转数组
    {
        rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    }
    //注意,不能用 a 来做多项式乘法,因为如果拿 a 做了多项式乘法,那么 a 的值在递归过程中,就会发生改变。
    for(int i = 0; i < n; i++) c[i] = a[i];//把 a 赋给 c,用 c 来做多项式乘法
    for(int i = n; i < lim; i++) c[i] = 0;//多余的高次项舍去
    //此时的 B 数组存的是 B(x)A(x) = 1 mod x^{n/2},C数组存的是 A(x)
    NTT(c,lim,1); NTT(b,lim,1);//求 B 和 C 的点值表示法
    for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;//计算 B的点值
    NTT(b,lim,-1);//把B转化为系数表示法
    for(int i = n; i < lim; i++) b[i] = 0;//高次项舍去 
}
signed main()
{
    n = read();
    for(int i = 0; i < n; i++) a[i] = read();
    Inv(n,a,b);
    for(int i = 0; i < n; i++) printf("%lld ",b[i]);
    printf("\n");
    return 0;
}

2.多项式开根

\(B^2(x) \equiv A(x) \pmod {x^n}\)

假设,我们得到了满足 \(C^2(x) \equiv A(x) \pmod {x^{n\over 2}}\) 的一个多项式 \(C(x)\)

又因为 \(B^2(x) \equiv A(x) \pmod {x^{n\over 2}}\)

两式联立可得:

\(B^2(x) \equiv C^2(x) \pmod {x^{n\over 2}}\)

\(B^2(x)-C^2(x) \equiv 0 \pmod {x^{n\over 2}}\)

两边同时平方可得:

\(B^4(x) + C^4(x) - 2B^2(x)C^2(x) \equiv 0 \pmod {x^n}\)

两边同时加上 \(4B^2(x)C^2(x)\) 可得:

\(B^4(x) + C^4(x) + 2B^2(x)C^2(x) \equiv 4B^2(x)C^2(x) \pmod {x^n}\)

\((B^2(x) + C^2(x))^2 \equiv 4B^2(x)C^2(x) \pmod {x^n}\)

把右边的 \(4C^2(x)\) 除过去可得:

\({(B^2(x) + C^2(x))^2 \over 4C^2(x)} \equiv B^2(x) \pmod {x^n}\)

\(B(x) \equiv {B^2(x) + C^2(x)\over 2C(x)} \pmod {x^n}\)

又因为 \(B^2(x) \equiv A(x) \pmod {x^n}\) ,代入可得:

\(B(x) \equiv {A(x) + C^2(x)\over 2C(x)} \pmod {x^n}\)

还是像求逆一样每次项数减半,递归求解,当项数为 \(1\) 的时候答案为 \(\sqrt {常数项}\)

多项式求逆加NTT即可。

复杂度 \(O(nlogn)\)

Code(常数爆炸):

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
#define int long long
const int N = 1e6+10;
const int p = 998244353;
int n,a[N],b[N],c[N],d[N],rev[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void NTT(int *a,int len,int opt)//NTT 板子
{
    for(int i = 0; i < len; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < len; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < len; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(len,p-2);
        for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
    }
}
void Inv(int n,int *a,int *b)//多项式求逆板子
{
    if(n == 1)
    {
        b[0] = ksm(a[0],p-2);
        return;
    }
    Inv((n+1)>>1,a,b);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < n; i++) c[i] = a[i];
    for(int i = n; i < lim; i++) c[i] = 0;
    NTT(c,lim,1); NTT(b,lim,1);
    for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
    NTT(b,lim,-1);
    for(int i = n; i < lim; i++) b[i] = 0;//记得清空
}
void sqrt(int n,int *a,int *b)
{
    if(n == 1)//项数为 1的情况
    {
        b[0] = (int) sqrt(a[0]);
        return;
    } 
    sqrt((n+1)>>1,a,b);    
	Inv(n,b,d);//这里求 mod x^n 下的逆元,而不是 mod x^lim 下的逆元 
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < n; i++) c[i] = a[i];//用c数组代替a来做多项式乘法
    for(int i = n; i < lim; i++) c[i] = 0;
    //这里 b 数组存的是 C^2(x) = A(x) mod x^{n/2}
    // c数组 存的是 A(x), d数组存的是 C(x) 的乘法逆
    NTT(b,lim,1); NTT(c,lim,1); NTT(d,lim,1);
    int inv2 = ksm(2,p-2);
    for(int i = 0; i < lim; i++) b[i] = (b[i] * b[i] % p + c[i] % p) * d[i] % p * inv2 % p;//根据柿子算出 B(x) 的点值
    NTT(b,lim,-1);//转换为系数表示法
    for(int i = n; i < lim; i++) b[i] = 0;   
    for(int i = 0; i < lim; i++) d[i] = 0;//多次调用要清空
} 
signed main()
{
    n = read();
    for(int i = 0; i < n; i++) a[i] = read();
    sqrt(n,a,b);
    for(int i = 0; i < n; i++) printf("%lld ",b[i]);
    return 0;
}

3.多项式求导

\(A(x) = \displaystyle\sum_{i=0}^{n} a_ix^i\) , 则 \(A^\prime(x) = \displaystyle\sum_{i=0}^{n} ia_{i}x^{i-1}\)

void qiudao(int len,int *a,int *b)
{
    for(int i = 1; i < len; i++) b[i-1] = i * a[i] % p;
    b[len-1] = 0;
}

5.多项式积分

\(A(x) = \displaystyle\sum_{i=0}^{n}a_ix^i\) ,则 \(\int A(x) = \displaystyle\sum_{i=1}^{n} {a_i\over i+1} x^{i+1}\)

void jifen(int len,int *a,int *b)
{
    for(int i = 1; i < len; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
    b[0] = 0;
}

6.多项式 ln

\(B(x) \equiv lnA(x) \pmod {x^n}\)

\(F(x) = lnA(x)\) ,则 对等式两边同时求导可得:

\(B^\prime(x) \equiv F^\prime(x) \pmod {x^n}\)

根据复合函数求导公式 \(f^\prime(g(x)) = f^\prime(g(x)) g^\prime(x)\) 可得:

\(B^\prime(x) \equiv {A^\prime (x)\over A(x)} \pmod {x^n}\)

先求出 \(A(x)\) 的导函数和乘法逆,在相乘得到 \(B^\prime(x)\) ,最后在积分回去即可。

多项式求逆,多项式求导,多项式积分,多项式乘法。

复杂度 \(O(nlogn)\)

code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 998244353;
const int N = 1e6+10;
int n,a[N],b[N],c[N],rev[N],A[N],B[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void NTT(int *a,int len,int opt)
{
    for(int i = 0; i < len; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < len; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < len; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(len,p-2);
        for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
    }
}
void Inv(int n,int *a,int *b)
{
    if(n == 1)
    {
        b[0] = ksm(a[0],p-2);
        return;
    }
    Inv((n+1)>>1,a,b);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < n; i++) c[i] = a[i];
    for(int i = n; i < lim; i++) c[i] = 0;
    NTT(b,lim,1); NTT(c,lim,1);
    for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
    NTT(b,lim,-1);
    for(int i = n; i < lim; i++) b[i] = 0;
}
void qiudao(int len,int *a,int *b)
{
    for(int i = 1; i < len; i++) b[i-1] = i * a[i] % p;
    b[len-1] = 0;
}
void jifen(int len,int *a,int *b)
{
    for(int i = 1; i < len; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
    b[0] = 0;
}
void Ln(int n,int *a,int *b)
{
    Inv(n,a,A); qiudao(n,a,B);//A 存的是 a的乘法逆,B存的是 a的导函数
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    NTT(A,lim,1); NTT(B,lim,1);
    for(int i = 0; i < lim; i++) B[i] = B[i] * A[i] % p;
    NTT(B,lim,-1); jifen(lim,B,b);//B存的是 b 的导函数
    for(int i = n; i < lim; i++) b[i] = 0;
}
signed main()
{
    n = read();
    for(int i = 0; i < n; i++) a[i] = read();
    Ln(n,a,b);
    for(int i = 0; i < n; i++) printf("%lld ",b[i]);
    return 0;
}

7.多项式除法

给你一个 \(n\) 次多项式 \(A(x)\) 和一个 \(m\) 次的多项式 \(B(x)\),求多项式 \(C(x)\)\(D(x)\) 满足:

  1. \(C(x)\) 的次数为 \(n-m\), \(D(x)\) 的次数小于 \(m\)
  2. \(A(x) = C(x) * B(x) + D(x)\)

\(f(x)\) 是一个 \(n\) 次多项式,则定义 \(inv(f(x)) = x^nf({1\over x})\)

\(inv(f(x)) = x^n f({1\over x}) = x^n(a_0+a_1x^{-1}+...a_nx^{-n}) = a_{n} + a_{n-1}x^1 + a_{n-2}x^2+....a_{1}x^{n-1} + a_0x^{n}\)

所以 \(inv(f(x))\) 其实就是把 \(f(x)\) 的系数反转过来得到的结果。

\(\because A(x) = C(x) * B(x) + D(x)\)

所以有 \(inv(A(x)) = inv(C(x) * B(x) + D(x))\)

展开可得:

\(x^nA({1\over x}) = x^{n} (C({1\over x}) * B({1\over x}) + D({1\over x}))\)

\(x^nA({1\over x}) = x^mB({1\over x}) x^{n-m} C({1\over x}) + x^{n-m+1} x^{m-1} D({1\over x})\)

在转化为 \(inv(f(X))\) 可得:

\(inv(A(x)) = inv(B(x))inv(C(x)) + x^{n-m+1}inv(D(x))\)

两边同时模上 \(x^{n-m+1}\) 可得:

\(invA(x) \equiv inv(B(x))inv(C(x)) \pmod {x^{n-m+1}}\)

\(inv(C(x)) \equiv {inv(A(x))\over invB(x)} \pmod {x^{n-m+1}}\)

多项式乘法和多项式求逆可以求出来 \(inv(C(x))\), 在把系数反转得到 \(C(x)\).

最后把 \(C(x)\) 代入原式可得到 \(D(x)\).

复杂度 \(O(nlogn)\)

一定要注意清空数组(我这个沙比就因为这个卡在了50分好多次 )

Code:

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 998244353;
const int N = 1e6+10;
int n,m,rev[N],a[N],b[N],c[N],d[N],A[N],B[N],invB[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void NTT(int *a,int len,int opt)
{
    for(int i = 0; i < len; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < len; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < len; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(len,p-2);
        for(int i = 0; i < len; i++) a[i] = (a[i] * inv % p + p) % p;
    }
}
void Inv(int n,int *a,int *b)
{
    if(n == 1)
    {
        b[0] = ksm(a[0],p-2);
        return;
    }
    Inv((n+1)>>1,a,b);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < n; i++) c[i] = a[i];
    for(int i = n; i < lim; i++) c[i] = 0;
    NTT(c,lim,1); NTT(b,lim,1);
    for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
    NTT(b,lim,-1); 
    for(int i = n; i < lim; i++) b[i] = 0;
}
void mul(int n,int m,int *a,int *b)
{
	int lim = 1, tim = 0;
	while(lim < (n<<1)) lim <<= 1, tim++;
	for(int i = 0; i <lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
	NTT(a,lim,1); NTT(b,lim,1);
	for(int i = 0; i < lim; i++) a[i] = a[i] * b[i] % p;
	NTT(a,lim,-1);
	for(int i = n; i < lim; i++) a[i] = 0; 
}
void Chu(int n,int m,int *a,int *b)
{
    for(int i = 0; i < n; i++) A[i] = a[n-i-1];//A 数组存的是 inv(A(x))
    for(int i = 0; i < m; i++) B[i] = b[m-i-1];//B 数组存的是 inv(B(x))
    Inv(n-m+1,B,invB); 
    for(int i = n-m+1; i < (n<<2); i++) A[i] = invB[i] = 0;
    mul(n-m+1,n-m+1,A,invB); 
    for(int i = 0; i < n-m+1; i++) c[i] = (A[n-m-i] % p + p) % p;
	for(int i = 0; i < n-m+1; i++) printf("%lld ",c[i]); 
    printf("\n");
    for(int i = n-m+1; i < (n<<2); i++) c[i] = 0;
    mul(n,n,c,b);
    for(int i = 0; i < m-1; i++) d[i] = ((a[i] - c[i]) % p + p) % p;
    for(int i = 0; i < m-1; i++) printf("%lld ",d[i]);
} 
signed main()
{
    n = read() + 1; m = read() + 1;
    for(int i = 0; i < n; i++) a[i] = read();
    for(int i = 0; i < m; i++) b[i] = read();
    Chu(n,m,a,b);
    return 0;
}

8.多项式 exp

\(B(x) \equiv e^{A(x)} \pmod {x^n}\)

\(C(x) \equiv e^{A(x)} \pmod {x^{n\over 2}}\) ,则 \(B(x) = C(x) (1-lnC(x) + A(x))\)

多项式求逆即可。

注意:每次求 \(exp\) 的时候,一定要把求 \(ln\) 所有用到的数组都清空掉。

code:

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 998244353;
const int N = 1e6+10;
int n,a[N],b[N],c[N],invB[N],invA[N],A[N],B[N],rev[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
} 
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void NTT(int *a,int len,int opt)
{
    for(int i = 0; i < len; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < len; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < len; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(len,p-2);
        for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
    }
}
void Inv(int n,int *a,int *b)
{
    if(n == 1)
    {
        b[0] = ksm(a[0],p-2);
        return;
    }
    Inv((n+1)>>1,a,b);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < n; i++) c[i] = a[i];
    for(int i = n; i < lim; i++) c[i] = 0;
    NTT(c,lim,1); NTT(b,lim,1);
    for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
    NTT(b,lim,-1);
    for(int i = n; i < lim; i++) b[i] = 0;
}
void qiudao(int n,int *a,int *b)
{
    for(int i = 1; i < n; i++) b[i-1] = i * a[i] % p;
    b[n-1] = 0;
}
void jifen(int n,int *a,int *b)
{
    for(int i = 1; i < n; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
    b[0] = 0;
}
void Ln(int n,int *a,int *b)
{
    Inv(n,a,invA); qiudao(n,a,A);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    NTT(invA,lim,1); NTT(A,lim,1);
    for(int i = 0; i < lim; i++) B[i] = invA[i] * A[i] % p;
    NTT(B,lim,-1); jifen(lim,B,b);
    for(int i = n; i < lim; i++) b[i] = 0;
}
void Exp(int n,int *a,int *b)
{
    if(n == 1)
    {
        b[0] = 1;
        return;
    }
    Exp((n+1)>>1,a,b);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < lim; i++) B[i] = A[i] = invA[i]= invB[i] = 0;
    Ln(n,b,invB); 
    for(int i = 0; i < n; i++) c[i] = a[i];
    for(int i = n; i < lim; i++) c[i] = 0;
    NTT(c,lim,1); NTT(invB,lim,1); NTT(b,lim,1);
    for(int i = 0; i < lim; i++) b[i] = b[i] * (1LL - invB[i] + c[i] + p) % p;
    NTT(b,lim,-1);
    for(int i = n; i < lim; i++) b[i] = 0;
}
signed main()
{
    n = read();
    for(int i = 0; i < n; i++) a[i] = read();
    Exp(n,a,b);
    for(int i = 0; i < n; i++) printf("%lld ",b[i]);
    printf("\n");
    return 0;
}

9.多项式快速幂

\(B(x) \equiv A^k(x) \pmod {x^n}\)

做法1: 倍增多项式乘法

和普通的快速幂一样,只不过在相乘的时候是把两个多项式乘起来。

常数比较大。

做法2: 多项式求 ln 多项式exp

等式两边同时取对数可得:

\(ln B(x) \equiv klnA(x) \pmod {x^n}\)

在同时取指数可得:

\(B(x) \equiv e^{klnA(x)} \pmod {x^n}\)

多项式求逆求出 \(lnA(x)\) ,把每一项乘个 \(k\), 最后在 exp回去即可。

复杂度 \(O(nlogn)\)

注意:每次求 exp 的时候,一定要把 求 \(ln\) 用到的数组清空。

Code

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define int long long
const int p = 998244353;
const int N = 1e6+10;
int n,k,a[N],b[N],c[N],rev[N],A[N],B[N],invA[N],invB[N],F[N];
char s[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void NTT(int *a,int len,int opt)
{
    for(int i = 0; i < len; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < len; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < len; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(len,p-2);
        for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
    }
}
void Inv(int n,int *a,int *b)
{
    if(n == 1)
    {
        b[0] = ksm(a[0],p-2);
        return;
    }
    Inv((n+1)>>1,a,b);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < n; i++) c[i] = a[i];
    for(int i = n; i < lim; i++) c[i] = 0;
    NTT(c,lim,1); NTT(b,lim,1);
    for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
    NTT(b,lim,-1);
    for(int i = n; i < lim; i++) b[i] = 0;
}
void qiudao(int n,int *a,int *b)
{
    for(int i = 1; i < n; i++) b[i-1] = i * a[i] % p;
    b[n-1] = 0;
}
void jifen(int n,int *a,int *b)
{
    for(int i = 1; i < n; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
    b[0] = 0;
}
void Ln(int n,int *a,int *b)
{
    Inv(n,a,invA); qiudao(n,a,A);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    NTT(invA,lim,1); NTT(A,lim,1);
    for(int i = 0; i < lim; i++) B[i] = A[i] * invA[i] % p;
    NTT(B,lim,-1); jifen(n,B,b);
    for(int i = n; i < lim; i++) b[i] = 0;
} 
void Exp(int n,int *a,int *b)
{
    if(n == 1)
    {
        b[0] = 1;
        return;
    }
    Exp((n+1)>>1,a,b);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < lim; i++) B[i] = A[i] = invA[i] = invB[i] = 0;//这里一定要清空ln所有用到的所有的数组,我在exp板子的时候只清空了两个数组既然还过了,就nm离谱
    Ln(n,b,invB);
    for(int i = 0; i < n; i++) c[i] = a[i];
    for(int i = n; i < lim; i++) c[i] = 0;
    NTT(b,lim,1); NTT(invB,lim,1); NTT(c,lim,1);
    for(int i = 0; i < lim; i++) b[i] = b[i] * (1LL - invB[i] + c[i] + p) % p;
    NTT(b,lim,-1);
    for(int i = n; i < lim; i++) b[i] = 0; 
}
void kuaisumi(int n,int k,int *a)
{
    Ln(n,a,F);//F 存的是 ln(A(x))
    for(int i = 0; i < n; i++) F[i] = F[i] * k % p;
    Exp(n,F,b);//exp回去
}
signed main()
{
    n = read(); scanf("%s",s+1);
    for(int i = 1; i <= (int) strlen(s+1); i++) k = (k * 10 + s[i] - '0') % p;
    for(int i = 0; i < n; i++) a[i] = read();
    kuaisumi(n,k,a);
    for(int i = 0; i < n; i++) printf("%lld ",b[i]);
    printf("\n");
    return 0;
}
posted @ 2021-01-11 06:35  genshy  阅读(249)  评论(0编辑  收藏  举报