习题:序列统计(NTT)
题目
思路
考虑这个问题的简化版本,如果是所有数的和\(mod\space m\),那么就可以利用多项式卷起来,
具体而言,定义\(dp[i]\)表示\(mod \space m\)下和为\(i\)有多少种方案,\(dp[i]=\sum_{i=1}^{n}dp[((i-s_i)\%m+m)\%m]\),这东西很显然是可以用\(ntt\)优化的
那么问题就转化成了怎么将乘法转换成为加法,
\(log\)不失为一个好方法,那么条件就为
\(\prod s_i\equiv x\pmod m\Leftrightarrow \prod g^{a_i}\equiv g^y\pmod m\)
考虑到需要用\(g\)去表示任意\(s_i\),这里\(g\)是\(m\)的原根,根据欧拉定理有
\(\sum a_i\equiv y \pmod {m-1}\)
然后就可以转换成为上面那一个问题了
代码
#include<iostream>
#include<cstring>
using namespace std;
namespace NTT
{
#define mod 1004535809
#define g 3
#define gi 334845270
int limit;
int l,r[20005];
long long qkpow(int a,int b)
{
if(b==0)
return 1;
if(b==1)
return a;
long long t=qkpow(a,b/2);
t=t*t%mod;
if(b&1)
t=t*a%mod;
return t;
}
void prepa(int n)
{
limit=1;
l=0;
while(limit<=n)
{
limit<<=1;
l++;
}
for(int i=0;i<limit;i++)
r[i]=((r[i>>1]>>1)|((i&1)<<(l-1)));
}
void ntt(long long *a,int ty)
{
for(int i=0;i<limit;i++)
if(i<r[i])
swap(a[i],a[r[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
long long wn=qkpow(ty?g:gi,(mod-1)/(mid<<1));
for(int r=mid<<1,j=0;j<limit;j+=r)
{
long long w=1;
for(int k=0;k<mid;k++,w=w*wn%mod)
{
long long x=a[j+k],y=w*a[j+k+mid]%mod;
a[j+k]=(x+y)%mod;
a[j+k+mid]=((x-y)%mod+mod)%mod;
}
}
}
}
#undef mod
#undef g
#undef gi
}
using namespace NTT;
const int mod=1004535809;
int n,m,x,len;
int a[20005];
bool vis[20005];
int rt,inv;
bool check(int cnt)
{
int t=cnt;
for(int i=1;i<=m-2;i++,t=t*cnt%m)
if(t==1)
return 0;
return 1;
}
int getlog(int x)
{
int t=0;
int lim=1;
while(1)
{
if(lim==x)
return t;
t++;
lim=lim*rt%m;
}
}
void mul(long long *a,long long *b)
{
prepa(m*2);
ntt(a,1);ntt(b,1);
for(int i=0;i<=limit;i++)
a[i]=a[i]*b[i]%mod;
ntt(a,0);ntt(b,0);
inv=qkpow(limit,mod-2);
for(int i=0;i<=2*m;i++)
{
a[i]=a[i]*inv%mod;
b[i]=b[i]*inv%mod;
}
for(int i=m;i<=2*m;i++)
a[i-m]=(a[i-m]+a[i])%mod;
for(int i=m;i<=limit;i++)
a[i]=b[i]=0;
}
void pf(long long *a)
{
prepa(m*2);
ntt(a,1);
//cout<<"a:";
for(int i=0;i<=limit;i++)
{
//cout<<a[i]<<' ';
a[i]=a[i]*a[i]%mod;
}
//cout<<'\n';
ntt(a,0);
inv=qkpow(limit,mod-2);
for(int i=0;i<=2*m;i++)
a[i]=a[i]*inv%mod;
for(int i=m;i<=2*m;i++)
a[i-m]=(a[i-m]+a[i])%mod;
for(int i=m;i<=limit;i++)
a[i]=0;
}
void qk(int n)
{
long long f[20005]={};f[0]=1;
long long g[20005]={};
for(int i=1;i<=len;i++)
g[a[i]]++;
while(n)
{
/*cout<<"g:";
for(int i=0;i<=m;i++)
cout<<g[i]<<' ';
cout<<'\n';*/
if(n&1)
mul(f,g);
pf(g);
n>>=1;
}
cout<<f[x];
}
int main()
{
cin>>n>>m>>x>>len;
for(int i=1;i<=len;i++)
{
cin>>a[i];
if(a[i]==0)
{
len--;
i--;
}
}
for(int i=2;i<m;i++)
if(check(i))
rt=i;
x=getlog(x);
for(int i=1;i<=len;i++)
a[i]=getlog(a[i])%(m-1);
m--;
qk(n);
return 0;
}
/*
709 53 9 27
0 3 4 5 7 9 10 15 16 17 19 22 23 24 25 28 29 30 31 33 37 38 40 41 42 50 52
*/

浙公网安备 33010602011771号