【SDOI2015】序列统计

题面

https://www.luogu.org/problem/P3321

题解

首先贡献是$f[a_ib_i]+=f1[a_i]\times f2[b_i]$,用原根变成$f[a_i+b_i]+=f1[a_i]\times f2[b_i]$,即形成一个新的映射。

开个桶,即求这个多项式的$n$次幂。

$NTT+$分治快速幂。

自己编的$NTT$口诀:

上倍增,中加二倍,下加加。

上界$1,0,0$,下界$li,li,mi$

原根$3$,$2^{mid}$次单位根$w0=phi(p)/2mid$($FFT$:$w0=\{cos(pi/mid),sin(pi/mid)\times opt\}$)

加($mid$)乘($w$),加($mid$)减($y$)。

$-1$时,$reverse(1,lim)$除以$lim$

NTT(copy from Gloid orz Gloid)

没啥区别就是模意义下的。那用原根代替复数就好了。

原根随便都能求出来。

模数需要为2^n*k+1的形式且为质数。因为需要求得2^i次根,也即原根的(p-1)/2^i次,这个东西显然需要是整数。

常用模数有

998244353=2^23*119+1

1004535809=2^21*479+1

469762049=2^26*7+1

都能跑几百万项。

IDFT时用原根的逆元。最后乘项数的逆元。

#include<cmath>
#include<stack>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ri register int
#define N 100000
#define mod 1004535809
#define LL long long

using namespace std;

inline int read() {
  int ret=0; char ch=getchar();
  while (ch<'0' || ch>'9') ch=getchar();
  while (ch>='0' && ch<='9') ret*=10,ret+=(ch-'0'),ch=getchar();
  return ret;
}

int n,m,x,t;
int f[N],s[N];
int a[N],mp[N],b[N],ftr[30];
int lim=1,l=0,r[N],ret[N];

int pow(int a,int b,int p) {
  int ret=1;
  for (;b;b>>=1,a=a*1LL*a%p) if (b&1) ret=ret*1LL*a%p;
  return ret;
}

int getg(int m) {
  int tmp=m-1;
  int c=0;
  for (ri i=2;i*i<=tmp;i++) if (tmp%i==0) {
    ftr[++c]=i;
    while (tmp%i==0) tmp/=i;
  }
  if (tmp>1) ftr[++c]=tmp;
  for (ri g=2;g<=m-1;g++) {
    bool fl=0;
    for (ri j=1;j<=c;j++) if (pow(g,(m-1)/ftr[j],m)==1) {fl=1;break;}
    if (!fl) return g;
  }
  return -1;
}

void NTT(int *p,int opt) {
  for (ri i=0;i<lim;i++) if (i<r[i]) swap(p[i],p[r[i]]);
  for (ri i=1;i<lim;i<<=1) {
    int w0=pow(3,(mod-1)/(i<<1),mod);
    for (ri j=0;j<lim;j+=2*i) {
      int w=1;
      for (ri k=0;k<i;k++,w=w*1LL*w0%mod) {
        int x=p[j+k],y=p[i+j+k]*1LL*w%mod;
        p[j+k]=(x+y)%mod; p[i+j+k]=(x-y+mod)%mod;
      }
    }
  }
  if (opt==-1) {
    reverse(&p[1],&p[lim]);
    int inv=pow(lim,mod-2,mod);
    for (ri i=0;i<lim;i++) p[i]=p[i]*1LL*inv%mod;
  }
}

void mul(int *a1,int *a2,int *c) {
  memset(a,0,sizeof(a)); memset(b,0,sizeof(b));
  for (ri i=0;i<m-1;i++) a[i]=a1[i],b[i]=a2[i];
  NTT(a,1); NTT(b,1);
  for (ri i=0;i<lim;i++) a[i]=a[i]*1LL*b[i]%mod;
  NTT(a,-1);
  memset(ret,0,sizeof(ret));
  for (ri i=0;i<m-1;i++) ret[i]=(a[i]+a[i+m-1])%mod;
  for (ri i=0;i<m-1;i++) c[i]=ret[i];
}

int main() {
  n=read(); m=read(); x=read(); t=read();
  int g=getg(m);
  for (ri i=0;i<m-1;i++) mp[pow(g,i,m)]=i;
  while (lim<=2*(m-2)) lim<<=1,l++;
  for (ri i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
  for (ri i=1;i<=t;i++) {
    int xx=read()%m;
    if (xx) f[mp[xx]]++;
  }
  s[mp[1]]=1;
  for (;n;n>>=1,mul(f,f,f)) if (n&1) mul(s,f,s);
  printf("%d\n",s[mp[x]]);
}

 

posted @ 2019-09-06 23:41  HellPix  阅读(177)  评论(0编辑  收藏  举报