[模板] 多项式工业

FFT

#include <bits/stdc++.h>

using namespace std;

// Reads one signed decimal integer from stdin, skipping any non-digit
// characters in front of it. A '-' counts as the sign only when it directly
// precedes the first digit. Returns 0 if EOF is reached before any digit
// (the previous version looped forever at EOF).
int rd() {
	int ret = 0, f = 1;
	int c;  // int, not char: must be able to hold EOF and be a valid isdigit() argument
	while ((c = getchar()) != EOF && !isdigit(c))
		f = (c == '-') ? -1 : 1;
	while (isdigit(c)) {
		ret = ret * 10 + (c - '0');
		c = getchar();
	}
	return ret * f;
}


using ll = long long;
using ull = unsigned long long;

const double PI = std::acos(-1.0);
// Capacity of every FFT work buffer. It must be >= limit, i.e. roughly four
// times the longest input polynomial.
const int MAXN = 4000005;

// limit: current transform length (a power of two), set by init().
// tr: bit-reversal permutation for that length. Only tr[0 .. limit) is used,
// and limit is already bounded by the MAXN-sized buffers in MTT(), so MAXN
// entries are enough (the former MAXN << 2 wasted about 48 MB).
int limit, tr[MAXN];

// Minimal double-precision complex number used by the FFT.
struct CP {
	double x, y;  // real part, imaginary part

	CP(double xx = 0, double yy = 0) : x(xx), y(yy) {}

	CP operator + (CP const &B) const { return CP(x + B.x, y + B.y); }
	CP operator - (CP const &B) const { return CP(x - B.x, y - B.y); }

	// (x + iy)(B.x + iB.y) = (x*B.x - y*B.y) + i(x*B.y + y*B.x)
	CP operator * (CP const &B) const {
		const double re = x * B.x - y * B.y;
		const double im = x * B.y + y * B.x;
		return CP(re, im);
	}
};
// In-place iterative radix-2 FFT of f[0 .. limit).
// op == true : transform with root exp(+2*pi*i/span) at each level;
// op == false: the same with the conjugate root (inverse direction).
// The inverse is NOT divided by limit here; the caller does that.
// Relies on init() having set limit and the bit-reversal table tr.
void FFT(CP *f, bool op) {
	const int n = limit;
	// Bit-reversed reordering so the butterflies can work in place.
	for (int i = 0; i < n; ++i) {
		if (i < tr[i])
			std::swap(f[i], f[tr[i]]);
	}
	for (int span = 2; span <= n; span <<= 1) {
		const int half = span >> 1;
		const double ang = 2 * PI / span;
		// Principal span-th root of unity, conjugated for the inverse pass.
		const CP step(std::cos(ang), op ? std::sin(ang) : -std::sin(ang));
		for (int base = 0; base < n; base += span) {
			CP w(1, 0);  // running twiddle factor, advanced by repeated multiplication
			for (int j = base; j < base + half; ++j) {
				const CP t = w * f[j + half];
				f[j + half] = f[j] - t;
				f[j] = f[j] + t;
				w = w * step;
			}
		}
	}
}

// Chooses the transform length and builds the bit-reversal permutation.
// limit becomes the smallest power of two strictly greater than n.
// tr[i] is i with its log2(limit) low bits reversed; tr[0] keeps its value (0).
void init(int n) {
	int bits = 0;
	while ((1 << bits) <= n)
		++bits;
	limit = 1 << bits;
	for (int i = 1; i < limit; ++i) {
		const int highBit = (i & 1) ? (limit >> 1) : 0;
		tr[i] = (tr[i >> 1] >> 1) | highBit;
	}
}


int MTT(int *a, int *b, int n, int m, int *res) {
	static CP tmpa[MAXN], tmpb[MAXN],tmp[MAXN];
	for (int i = 0; i < n; i++)
		tmpa[i].x = a[i];
	for (int i = 0; i < m; i++)//less
		tmpb[i].x = b[i];
	init(n + m);
	FFT(tmpa, 1);

	FFT(tmpb, 1);
	for (int i = 0; i < limit; i++)
		tmp[i] = tmpa[i] * tmpb[i] ;
	FFT(tmp, 0);
	for(int i=0;i<limit;i++)
		res[i] = (int)(tmp[i].x/limit+0.5);
	return n + m - 1;
}

// Bump-allocated storage shared by every polynomial; tot is the next free slot.
int B[MAXN * 4], tot;

// A polynomial viewed as the slice a[0 .. len) of B.
struct P {
	int *a, len, len2;  // len2 is not used anywhere in this program

	// Claims the next _len slots of B. When input is non-zero the
	// coefficients are read from std::cin; otherwise they are left untouched.
	void init(int _len, int input = 1) {
		a = B + tot;
		len = _len;
		tot += _len;
		if (!input)
			return;
		for (int i = 0; i < _len; ++i)
			std::cin >> a[i];
	}
};

// Returns lhs * rhs as a newly allocated polynomial with
// lhs.len + rhs.len - 1 coefficients.
P mul(const P &lhs, const P &rhs) {
	const int outLen = lhs.len + rhs.len - 1;
	P product;
	product.init(outLen, 0);  // allocate only; do not read input
	MTT(lhs.a, rhs.a, lhs.len, rhs.len, product.a);
	return product;
}

int n, m;

// Reads n and m, then the n + 1 and m + 1 coefficients of two polynomials
// from stdin, and prints the coefficients of their product separated by spaces.
void work() {
	// Millions of numbers may go through iostreams, so detach them from C stdio
	// first. This must happen before any I/O, and no C stdio call follows.
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	std::cin >> n >> m;
	P a, b;
	a.init(n + 1);
	b.init(m + 1);
	P ans = mul(a, b);
	for (int i = 0; i < ans.len; i++)
		std::cout << ans.a[i] << ' ';
	std::cout << '\n';
}

// Program entry: run the FFT multiplication demo once.
int main() {
	work();
	return 0;
}

NTT

#include <bits/stdc++.h>

using namespace std;

// Reads one signed decimal integer from stdin, skipping any non-digit
// characters in front of it. A '-' counts as the sign only when it directly
// precedes the first digit. Returns 0 if EOF is reached before any digit
// (the previous version looped forever at EOF).
int rd() {
  int ret = 0, f = 1;
  int c;  // int, not char: must be able to hold EOF and be a valid isdigit() argument
  while ((c = getchar()) != EOF && !isdigit(c))
    f = (c == '-') ? -1 : 1;
  while (isdigit(c)) {
    ret = ret * 10 + (c - '0');
    c = getchar();
  }
  return ret * f;
}

typedef long long ll;
typedef unsigned long long ull;

const int MAXN = 200005;
const int MOD = 998244353, _G = 3;  // NTT modulus and the root used with it

int limit, r[MAXN << 2];  // transform length and its bit-reversal table

// Computes a^b mod MOD by binary exponentiation; b must be non-negative.
// The default exponent MOD - 2 gives the modular inverse of a (Fermat's
// little theorem; requires a % MOD != 0).
// The base is first reduced into [0, MOD). Without that, a base above about
// 3e9 overflowed in a * a, and a negative base produced a negative result.
ll qpow(ll a, ll b = MOD - 2) {
  a %= MOD;
  if (a < 0)
    a += MOD;
  ll ans = 1;
  while (b) {
    if (b & 1)
      ans = ans * a % MOD;
    a = a * a % MOD;
    b >>= 1;
  }
  return ans;
}
int bec;  // not referenced anywhere in this program
const int invG = qpow(_G);  // inverse of _G, used by the inverse NTT

// In-place iterative NTT of g[0 .. limit) modulo MOD.
// op == true : forward transform using powers of _G;
// op == false: inverse transform using powers of invG, including the final
//              multiplication by limit^{-1}.
// Relies on init() for limit and r. Entries of g are assumed to lie in
// [0, MOD); the lazy-reduction bound below depends on it.
void NTT(int *g, bool op) {
  int n = limit;
  // Sized like the bit-reversal table r (MAXN << 2), which is how far limit
  // can reach. The former MAXN << 1 buffers overflowed before r did.
  static unsigned long long f[MAXN << 2], w[MAXN << 2];
  for (int i = 0; i < n; i++)
    w[i] = 1;
  for (int i = 0; i < n; i++)
    f[i] = g[r[i]];

  for (int l = 1; l < n; l <<= 1) {
    // (2l)-th root of unity (or its inverse) for this level.
    ull tG = qpow(op ? _G : invG, (MOD - 1) / (l + l));
    for (int i = 1; i < l; i++)
      w[i] = w[i - 1] * tG % MOD;
    for (int k = 0; k < n; k += l + l) {
      for (int p = 0; p < l; p++) {
        int tt = w[p] * f[k | l | p] % MOD;
        f[k | l | p] = f[k | p] + MOD - tt;
        f[k | p] += tt;
      }
    }

    // Lazy reduction. f is not reduced per butterfly and grows by less than MOD
    // per level, so after level j (l == 2^j) every f < (j + 2) * MOD. The
    // product w[p] * f stays below 2^64 only while that factor is at most 18,
    // hence a full reduction right after level 17.
    if (l == (1 << 17))
      for (int i = 0; i < n; i++)
        f[i] %= MOD;
  }
  if (!op) {
    ull invn = qpow(n);
    for (int i = 0; i < n; i++)
      g[i] = f[i] % MOD * invn % MOD;
  } else {
    for (int i = 0; i < n; i++)
      g[i] = f[i] % MOD;
  }
}

// Picks limit = the smallest power of two exceeding n and builds the
// matching bit-reversal table r (r[0] keeps its value, 0).
void init(int n) {
  for (limit = 1; limit <= n; limit <<= 1) {
    // only grows limit
  }
  const int top = limit >> 1;
  for (int i = 1; i < limit; ++i)
    r[i] = (r[i >> 1] >> 1) | ((i & 1) ? top : 0);
}


// Multiplies a (n coefficients) by b (m coefficients) modulo MOD with NTT and
// writes the n + m - 1 product coefficients to res.
// Inputs may be any int; they are reduced into [0, MOD) first because NTT()
// assumes that range. Returns the number of coefficients written (0 if
// either input is empty).
int MTT(int *a, int *b, int n, int m, int *res) {
  // Same capacity as the NTT tables (MAXN << 2). limit exceeds MAXN once
  // n + m is large enough, which overflowed the old MAXN-sized buffers.
  static int tmpa[MAXN << 2], tmpb[MAXN << 2];
  if (n <= 0 || m <= 0)
    return 0;
  init(n + m);
  // The static buffers persist across calls, so zero the padding explicitly.
  for (int i = 0; i < limit; i++) {
    tmpa[i] = i < n ? (a[i] % MOD + MOD) % MOD : 0;
    tmpb[i] = i < m ? (b[i] % MOD + MOD) % MOD : 0;
  }
  NTT(tmpa, 1);
  NTT(tmpb, 1);
  for (int i = 0; i < limit; i++)
    tmpa[i] = 1ll * tmpa[i] * tmpb[i] % MOD;
  NTT(tmpa, 0);
  // res only has room for n + m - 1 values. The old code ran the inverse
  // transform directly in res and so wrote limit of them.
  const int len = n + m - 1;
  for (int i = 0; i < len; i++)
    res[i] = tmpa[i];
  return len;
}

// Backing store for every polynomial; tot marks the first unused slot.
int B[MAXN * 4], tot;

// Polynomial whose coefficients live in the slice a[0 .. len) of B.
struct P {
  int *a, len, len2;  // len2 is never read or written by this program

  // Reserves the next _len entries of B for this polynomial. If input is
  // non-zero, fills them from std::cin.
  void init(int _len, int input = 1) {
    a = B + tot;
    len = _len;
    tot += _len;
    if (input == 0)
      return;
    for (int i = 0; i < _len; ++i)
      std::cin >> a[i];
  }
};

// Allocates a result polynomial of length lhs.len + rhs.len - 1 and fills it
// with the product lhs * rhs (mod MOD).
P mul(const P &lhs, const P &rhs) {
  P out;
  out.init(lhs.len + rhs.len - 1, /*input=*/0);
  MTT(lhs.a, rhs.a, lhs.len, rhs.len, out.a);
  return out;
}

int n, m;

// Reads n and m, then the n + 1 and m + 1 coefficients of two polynomials,
// and prints the coefficients of their product (mod MOD) separated by spaces.
void work() {
  // Large inputs go through iostreams, so decouple them from C stdio first.
  // This must run before any I/O, and no C stdio call follows.
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  std::cin >> n >> m;
  P a, b;
  a.init(n + 1);
  b.init(m + 1);
  P ans = mul(a, b);
  for (int i = 0; i < ans.len; i++)
    std::cout << ans.a[i] << ' ';
  std::cout << '\n';
}

// Entry point: run the NTT multiplication demo once.
int main() {
  work();
  return 0;
}

MTT/任意模数NTT

#include<bits/stdc++.h>

using namespace std;

// Reads one signed decimal integer from stdin, skipping any non-digit
// characters in front of it. A '-' counts as the sign only when it directly
// precedes the first digit. Returns 0 if EOF is reached before any digit
// (the previous version looped forever at EOF).
int rd(){
  int ret = 0, f = 1;
  int c;  // int, not char: must be able to hold EOF and be a valid isdigit() argument
  while ((c = getchar()) != EOF && !isdigit(c))
    f = (c == '-') ? -1 : 1;
  while (isdigit(c)) {
    ret = ret * 10 + (c - '0');
    c = getchar();
  }
  return ret * f;
}

typedef long long ll;

const int inf = 1<<30;
const int MAXN = 100005;
//const int MOD = 998244353;
int MOD;  // modulus, assigned at runtime
// Split base: each coefficient c is handled as the pair (c / BASE, c % BASE).
const int BASE = 1 << 15;
// pi at full long double precision. acos(-1.0) picks the *double* overload,
// which capped every twiddle factor at double accuracy even though the FFT
// here works in long double.
const long double Pi = std::acos(-1.0L);

// Extended-precision complex number used by the split-coefficient FFT.
struct CP {
  long double x, y;  // real part, imaginary part
  CP (long double xx = 0, long double yy = 0) : x(xx), y(yy) {}
} P1[MAXN << 2], P2[MAXN << 2], Q[MAXN << 2];  // FFT work buffers

CP operator + (CP a, CP b) {
  return CP(a.x + b.x, a.y + b.y);
}

CP operator - (CP a, CP b) {
  return CP(a.x - b.x, a.y - b.y);
}

// (a.x + i a.y)(b.x + i b.y)
CP operator * (CP a, CP b) {
  return CP(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}
int limit, r[MAXN << 2];  // transform length and bit-reversal table (set by init)

// Returns a^b mod MOD (the runtime modulus) by square-and-multiply;
// b must be non-negative. Not called anywhere in this program.
ll qpow(ll a, ll b) {
  ll result = 1;
  for (; b; b >>= 1, a = a * a % MOD)
    if (b & 1)
      result = result * a % MOD;
  return result;
}

// In-place iterative radix-2 FFT over A[0 .. limit) in long double.
// type = 1 : root exp(+i*pi/half) at each level (forward);
// type = -1: conjugate root, i.e. the inverse transform without the 1/limit scale.
// Uses limit and the bit-reversal table r prepared by init().
void FFT(CP *A, int type) {
  for (int i = 0; i < limit; i++) {
    if (i < r[i])
      std::swap(A[i], A[r[i]]);
  }
  for (int half = 1; half < limit; half <<= 1) {
    const CP root(std::cos(Pi / half), type * std::sin(Pi / half));
    const int block = half << 1;
    for (int start = 0; start < limit; start += block) {
      CP w(1, 0);  // running twiddle factor
      for (int k = 0; k < half; k++) {
        CP lo = A[start + k];
        CP hi = w * A[start + half + k];
        A[start + k] = lo + hi;
        A[start + half + k] = lo - hi;
        w = w * root;
      }
    }
  }
}

// Sets limit to the smallest power of two greater than n and fills r with
// the matching bit-reversal permutation (r[0] keeps its value, 0).
void init(int n) {
  int lg = 0;
  while ((1 << lg) <= n)
    ++lg;
  limit = 1 << lg;
  for (int i = 1; i < limit; i++)
    r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lg - 1));
}

int MTT(int *a, int *b, int n, int m, int *res, int MOD) {
  init(n + m);
  for (int i = 0; i < n; i++) {
    P1[i] = {a[i] / BASE, a[i] % BASE};
    P2[i] = {a[i] / BASE, -a[i] % BASE};
  }
  for (int i = n; i < limit; i++) {
    P1[i] = {0, 0}, P2[i] = {0, 0};
  }
  for (int i = 0; i < m; i++) {
    Q[i] = {b[i] / BASE, b[i] % BASE};
  }
  for (int i = m; i < limit; i++) {
    Q[i] = {0, 0};
  }
  FFT(P1, 1), FFT(P2, 1), FFT(Q, 1);
  for (int i = 0; i < limit; i++) {
    Q[i].x /= limit, Q[i].y /= limit;
    P1[i] = P1[i] * Q[i], P2[i] = P2[i] * Q[i];
  }
  FFT(P1, -1), FFT(P2, -1);
  for (int i = 0; i < n + m - 1; i++) {
    long long a1b1, a1b2, a2b1, a2b2;
    a1b1 = (long long)floor((P1[i].x + P2[i].x) / 2 + 0.5) % MOD;
    a1b2 = (long long)floor((P1[i].y + P2[i].y) / 2 + 0.5) % MOD;
    a2b1 = (long long)floor((P1[i].y - P2[i].y) / 2 + 0.5) % MOD;
    a2b2 = (long long)floor((P2[i].x - P1[i].x) / 2 + 0.5) % MOD;
    res[i] = ((a1b1 * BASE + (a1b2 + a2b1)) * BASE + a2b2) % MOD;
    res[i] = (res[i] + MOD) % MOD;
  }
  return n + m - 1;
}

// Bump-allocated coefficient storage shared by all polynomials; tot = next free slot.
int B[MAXN*8],tot;

// A polynomial stored as the slice a[0 .. len) of B.
struct P{
  int *a,len;
  // Claims the next _len slots of B and fills them with integers read by rd().
  void init(int _len){
    len=_len;
    a=B+tot;
    for(int i=0;i<len;i++)
      a[i]=rd();
    tot+=len;
  }
  // Replaces *this with (*this) * rhs modulo the global MOD.
  // The product is longer than either factor, so it goes into freshly claimed
  // storage. Writing it over a[] (as before) ran past this polynomial's slice
  // and clobbered whatever was allocated after it.
  void mul(const P& rhs){
    int *out=B+tot;
    len=MTT(a,rhs.a,len,rhs.len,out,MOD);
    a=out;
    tot+=len;
  }
};

int n, m;

// Reads n, m and the modulus, then the n + 1 and m + 1 coefficients of two
// polynomials, and prints their product modulo MOD separated by spaces.
int main(){
  n = rd();
  m = rd();
  MOD = rd();
  P x, y;
  x.init(n + 1);
  y.init(m + 1);
  x.mul(y);
  for (int idx = 0; idx < x.len; ++idx)
    printf("%d ", x.a[idx]);
  return 0;
}

FWT

#include<bits/stdc++.h>

using namespace std;

// Reads one signed decimal integer from stdin, skipping any non-digit
// characters in front of it. A '-' counts as the sign only when it directly
// precedes the first digit. Returns 0 if EOF is reached before any digit
// (the previous version looped forever at EOF).
int rd(){
  int ret = 0, f = 1;
  int c;  // int, not char: must be able to hold EOF and be a valid isdigit() argument
  while ((c = getchar()) != EOF && !isdigit(c))
    f = (c == '-') ? -1 : 1;
  while (isdigit(c)) {
    ret = ret * 10 + (c - '0');
    c = getchar();
  }
  return ret * f;
}

using ll = long long;
// Prime modulus; INV2 is the inverse of 2 modulo MOD.
constexpr int MOD = 998244353, INV2 = 499122177;
// 2x2 butterfly matrices for FWT. Each pair (u, v) becomes
// (c[0][0]*u + c[0][1]*v, c[1][0]*u + c[1][1]*v); the I* matrices undo them.
constexpr ll
  // OR : (u, v) -> (u, u + v), inverse (u, v - u)
  Cor[2][2]   = {{1, 0}, {1, 1}},
  ICor[2][2]  = {{1, 0}, {MOD - 1, 1}},
  // AND: (u, v) -> (u + v, v), inverse (u - v, v)
  Cand[2][2]  = {{1, 1}, {0, 1}},
  ICand[2][2] = {{1, MOD - 1}, {0, 1}},
  // XOR: (u, v) -> (u + v, u - v), inverse ((u + v)/2, (u - v)/2)
  Cxor[2][2]  = {{1, 1}, {1, MOD - 1}},
  ICxor[2][2] = {{INV2, INV2}, {INV2, MOD - INV2}};

// In-place fast Walsh-style transform of f[0 .. n) (n a power of two) modulo
// MOD. At every level each pair (f[i], f[i + half]) is multiplied by the 2x2
// matrix c, so c selects the OR / AND / XOR variant or its inverse.
// Assumes entries of f lie in [0, MOD) so the products cannot overflow.
void FWT(ll *f, const ll c[2][2], int n) {
  for (int half = 1; half < n; half <<= 1) {
    for (int start = 0; start < n; start += 2 * half) {
      for (int i = start; i < start + half; ++i) {
        const ll u = f[i], v = f[i + half];
        f[i] = (c[0][0] * u + c[0][1] * v) % MOD;
        f[i + half] = (c[1][0] * u + c[1][1] * v) % MOD;
      }
    }
  }
}

// Bitwise convolution: f becomes the convolution of f and g under the
// operation encoded by c, with ic as the inverse transform.
// g is left in transformed form.
void bitmul(ll *f, ll *g, const ll c[2][2], const ll ic[2][2], int n) {
  FWT(f, c, n);
  FWT(g, c, n);
  for (int i = 0; i < n; ++i) {
    f[i] *= g[i];
    f[i] %= MOD;
  }
  FWT(f, ic, n);
}

// Writes x[0 .. n) to stdout on one line, each value followed by a space.
void print(ll *x, int n) {
  int i = 0;
  while (i < n)
    printf("%lld ", x[i++]);
  putchar('\n');
}

const int MAXN = 2000006;
#define cpy(f,g,n) memcpy(f,g,sizeof(ll)*(n))
ll f[MAXN],g[MAXN],savf[MAXN],savg[MAXN];
int main(){
  int n=(1<<rd());
  for(int i=0;i<n;i++)f[i]=rd();
  for(int i=0;i<n;i++)g[i]=rd();
  cpy(savf,f,n);cpy(savg,g,n);
  bitmul(f,g,Cor,ICor,n);
  print(f,n);
  cpy(f,savf,n);cpy(g,savg,n);
  bitmul(f,g,Cand,ICand,n);
  print(f,n);
  cpy(f,savf,n);cpy(g,savg,n);
  bitmul(f,g,Cxor,ICxor,n);
  print(f,n);
  return 0;
}
posted @ 2021-08-04 16:45  GhostCai  阅读(51)  评论(0编辑  收藏  举报