洛谷4512:【模板】多项式除法——题解

https://www.luogu.org/problemnew/show/P4512

题面见原题。

模板题就不说什么了,解释可以看:http://blog.miskcoo.com/2015/05/polynomial-division

就有一点注意,因为我们要求的d数组的最高项是n-m的,所以求反数组的时候其长度也要限制在n-m内不然会出问题。

#include<cstdio>
#include<cctype>
#include<cstring>
#include<vector>
#include<cmath>
#include<algorithm>
#include<iostream>
using namespace std;
typedef long long ll;
const ll P=998244353;
const int G=3;
const int N=1e6+5;
inline int read(){
    int X=0,w=0;char ch=0;
    while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
    while(isdigit(ch))X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    return w?-X:X;
}
ll qpow(ll a,ll n,ll p){
    ll res=1;
    while(n){
    if(n&1)res=res*a%p;
    a=a*a%p;n>>=1;
    }
    return res;
}
void MTT(ll a[],int n,int on){
    for(int i=1,j=n>>1;i<n-1;i++){
        if(i<j)swap(a[i],a[j]);
        int k=n>>1;
        while(j>=k){j-=k;k>>=1;}
        if(j<k)j+=k;
    }
    for(int i=2;i<=n;i<<=1){
    ll res=qpow(G,(P-1)/i,P);
        for(int j=0;j<n;j+=i){
        ll w=1;
            for(int k=j;k<j+i/2;k++){
                ll u=a[k],t=w*a[k+i/2]%P;
                a[k]=(u+t)%P;
                a[k+i/2]=(u-t+P)%P;
                w=w*res%P;
            }
        }
    }
    if(on==-1){
    ll inv=qpow(n,P-2,P);
    a[0]=a[0]*inv%P;
    for(int i=1;i<=n/2;i++){
        a[i]=a[i]*inv%P;
        if(i!=n-i)a[n-i]=a[n-i]*inv%P;
        swap(a[i],a[n-i]);
    }
    }
}
void inv(int deg,ll a[],ll b[]){
    static ll t[N];
    if(deg==1){
    b[0]=qpow(a[0],P-2,P);
    return;
    }
    inv((deg+1)>>1,a,b);
    int n=1;
    while(n<(deg<<1))n<<=1;
    for(int i=0;i<deg;i++)t[i]=a[i];
    for(int i=deg;i<n;i++)t[i]=0;
    MTT(t,n,1);MTT(b,n,1);
    for(int i=0;i<n;i++)
    b[i]=b[i]*(2-b[i]*t[i]%P+P)%P;
    MTT(b,n,-1);
    for(int i=deg;i<n;i++)b[i]=0;
}
//a(x)=b(x)d(x)+r(x)
void division(int n,int m,ll a[],ll b[],ll d[],ll r[]){
    static ll t1[N],t2[N];
    int nn=1,l=n-m+1;
    while(nn<(l<<1))nn<<=1;
    for(int i=0;i<m;i++)t1[i]=b[m-i-1];
    for(int i=l;i<nn;i++)t1[i]=0;
    inv(l,t1,t2);
    for(int i=l;i<nn;i++)t2[i]=0;
    MTT(t2,nn,1);
    
    for(int i=0;i<n;i++)t1[i]=a[n-i-1];
    for(int i=l;i<nn;i++)t1[i]=0;
    MTT(t1,nn,1);

    for(int i=0;i<nn;i++)t1[i]=t1[i]*t2[i]%P;
    MTT(t1,nn,-1);
    for(int i=0;i<l-i-1;i++)swap(t1[i],t1[l-i-1]);
    for(int i=0;i<l;i++)d[i]=t1[i];

    nn=1;
    while(nn<n)nn<<=1;
    for(int i=l;i<nn;i++)t1[i]=0;
    MTT(t1,nn,1);
    for(int i=0;i<m;i++)t2[i]=b[i];
    for(int i=m;i<nn;i++)t2[i]=0;
    MTT(t2,nn,1);
    for(int i=0;i<nn;i++)t1[i]=t1[i]*t2[i]%P;
    MTT(t1,nn,-1);
    for(int i=0;i<m-1;i++)r[i]=((a[i]-t1[i])%P+P)%P;
}
int n,m;
ll f[N],g[N],q[N],r[N];
int main(){
    n=read()+1,m=read()+1;
    for(int i=0;i<n;i++)f[i]=read();
    for(int i=0;i<m;i++)g[i]=read();
    division(n,m,f,g,q,r);
    for(int i=0;i<n-m+1;i++)printf("%lld ",q[i]);
    puts("");
    for(int i=0;i<m-1;i++)printf("%lld ",(r[i]+P)%P);
    puts("");
    return 0;
}

+++++++++++++++++++++++++++++++++++++++++++

+本文作者:luyouqi233。               +

+欢迎访问我的博客:http://www.cnblogs.com/luyouqi233/ +

+++++++++++++++++++++++++++++++++++++++++++

 

posted @ 2018-05-08 13:26  luyouqi233  阅读(366)  评论(0编辑  收藏  举报