【uoj34】 多项式乘法

http://uoj.ac/problem/34 (题目链接)

题意

  求两个多项式的乘积

Solution

  挂个FFT板子。

细节

  FFT因为要满足$n$是$2$的幂,所以注意数组大小。

代码

// uoj34
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<cstdio>
#include<cmath>
#define LL long long
#define inf 2147483640
#define Pi acos(-1.0)
#define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout);
using namespace std;
 
typedef complex<double> E;
const int maxn=300010;
E a[maxn],b[maxn];
int n,m;
 
namespace FFT {
    int rev[maxn],L;
    void DFT(E *a,int f) {
        for (int i=0;i<n;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
        for (int i=1;i<n;i<<=1) {
            E wn(cos(Pi/i),f*sin(Pi/i));
            for (int p=i<<1,j=0;j<n;j+=p) {
                E w(1,0);
                for (int k=0;k<i;k++,w*=wn) {
                    E x=a[j+k],y=w*a[j+k+i];
                    a[j+k]=x+y;a[j+k+i]=x-y;
                }
            }
        }
		if (f==-1) for (int i=0;i<n;i++) a[i].real()/=n;
    }
    void main() {
        m=n+m;
        for (n=1;n<=m;n<<=1) L++;   //一定是<=,因为这里的m是最高次幂
        for (int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1) | ((i&1)<<(L-1));
        DFT(a,1);DFT(b,1);
        for (int i=0;i<n;i++) a[i]=a[i]*b[i];
        DFT(a,-1);
    }
}
int main() {
    scanf("%d%d",&n,&m);
    for (int i=0,x;i<=n;i++) scanf("%d",&x),a[i]=x;
    for (int i=0,x;i<=m;i++) scanf("%d",&x),b[i]=x;
    FFT::main();
    for (int i=0;i<=m;i++) printf("%d ",(int)(a[i].real()+0.5));
    return 0;
}

Solution

  ${NTT}$,适用于对一些形如 ${p=C*2^k+1}$的数取模,且${2^k>n}$(当然也可以将不取模但结果不会超过某个范围视作取模)的多项式乘法问题。

  一些常见的${NTT}$模数:

  ${998244353=119*2^{23}+1}$,原根为${3}$

  ${1004535809=479*2^{21}+1}$,原根为${3}$。

  ${15*2^{112}+1}$,原根为${1111}$。

  详情请见Xlightgod的博客

代码

// uoj34
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<cmath>
#define LL long long
#define inf 2147483640
#define MOD 998244353
#define Pi acos(-1.0)
#define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout);
using namespace std;
  
const int maxn=300010;
int a[maxn],b[maxn],rev[maxn],n,m,L;
 
int power(int a,int b) {
    int res=1;
    while (b) {
        if (b&1) res=1LL*res*a%MOD;
        a=1LL*a*a%MOD;b>>=1;
    }
    return res;
}
void NTT(int *a,int f) {
    for (int i=0;i<n;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
    for (int i=1;i<n;i<<=1) {
        int gn=power(3,(MOD-1)/(i<<1));   //这里除的是i<<1
        for (int p=i<<1,j=0;j<n;j+=p) {
            int g=1;
            for (int k=0;k<i;k++,g=1LL*g*gn%MOD) {
                int x=a[k+j],y=1LL*g*a[k+j+i]%MOD;
                a[k+j]=(x+y)%MOD;a[k+j+i]=(x-y+MOD)%MOD;
            }
        }
    }
    if (f==-1) {
		int ev=power(n,MOD-2);reverse(a+1,a+n);   //reverse的是[1,n)
		for (int i=0;i<n;i++) a[i]=1LL*a[i]*ev%MOD;
	}
}
int main() {
    scanf("%d%d",&n,&m);
    for (int i=0;i<=n;i++) scanf("%d",&a[i]);
    for (int i=0;i<=m;i++) scanf("%d",&b[i]);
    m=n+m;
    for (n=1;n<=m;n<<=1) L++;
    for (int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1) | ((i&1)<<(L-1));
    NTT(a,1);NTT(b,1);
    for (int i=0;i<n;i++) a[i]=1LL*a[i]*b[i]%MOD;
    NTT(a,-1);
    for (int i=0;i<=m;i++) printf("%d ",a[i]);
    return 0;
}

Solution3

  听说还有任意模数的${NTT}$,比如说对${1000000007}$取模,那么这显然是不能直接${NTT}$的,直接${FFT}$转成整型取模的时候又会爆LL。我用的是毛爷爷的做法,把一个数${x}$拆成${x=a*M+b}$,${M}$是模数的算术平方根。这样就能避免爆LL了。

  具体见上面那个链接:Xlightgod。

代码

// uoj34
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<complex>
#include<cstdio>
#include<cmath>
#include<queue>
#define LL long long
#define inf 1ll<<60
#define MOD 1000000007
#define M 32768
#define Pi acos(-1.0)
#define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout);
using namespace std;
 
typedef complex<double> E;
const int maxn=300010;
E a[maxn],b[maxn],c[maxn],d[maxn],A[maxn],B[maxn],C[maxn];
int n,m,L,rev[maxn];

void FFT(E *a,int f) {
	for (int i=0;i<n;i++) if (rev[i]>i) swap(a[i],a[rev[i]]);
	for (int i=1;i<n;i<<=1) {
		E wn(cos(Pi/i),f*sin(Pi/i));
		for (int p=i<<1,j=0;j<n;j+=p) {
			E w(1,0);
			for (int k=0;k<i;k++,w*=wn) {
				E x=a[j+k],y=w*a[j+k+i];
				a[j+k]=x+y;a[j+k+i]=x-y;
			}
		}
	}
	if (f==-1) for (int i=0;i<n;i++) a[i].real()=a[i].real()/n+0.5;   //这里的0.5一定要除了再加上去
}

int main() {
	scanf("%d%d",&n,&m);
	for (int x,i=0;i<=n;i++) {
		scanf("%d",&x);
		a[i]=x>>15;b[i]=x&(M-1);
	}
	for (int x,i=0;i<=m;i++) {
		scanf("%d",&x);
		c[i]=x>>15;d[i]=x&(M-1);
	}
	m=n+m;
	for (n=1,L=-1;n<=m;n<<=1) L++;
	for (int i=0;i<n;i++) rev[i]=(rev[i>>1]>>1) | ((i&1)<<L);
	FFT(a,1);FFT(b,1);FFT(c,1);FFT(d,1);
	for (int i=0;i<n;i++) {
		A[i]=a[i]*c[i];
		B[i]=a[i]*d[i]+b[i]*c[i];
		C[i]=b[i]*d[i];
	}
	FFT(A,-1);FFT(B,-1);FFT(C,-1);
	for (int i=0;i<=m;i++) {
		LL x=(LL)A[i].real()%MOD,y=(LL)B[i].real()%MOD,z=(LL)C[i].real()%MOD;
		printf("%lld ",((x<<30)+(y<<15)+z)%MOD);
	}
	return 0;
}

  

posted @ 2017-01-21 09:18  MashiroSky  阅读(413)  评论(0编辑  收藏  举报