suxxsfe

一言(ヒトコト)

P3321 [SDOI2015]序列统计

https://www.luogu.com.cn/problem/P3321

暴力 dp 的话,就是 \(f_{i,x}\) 表示填了前 \(i\) 个数,乘积为 \(x\) 的有多少种,那么 \(f_{i,x}\rightarrow f_{i+1,x\cdot S_k}\)
发现如果把后面下标里的那个乘改成加就是普通的循环卷积
而原根有性质 \(a^x\) 可以取遍 \(0,1,\cdots,p-1\) 的所有数(其中 \(a\)\(p\) 的原根,\(x\in [0,p-2]\)
所以只要找到 \(m\) 的原根,把它当作底数来在膜意义下取 \(\log\) 即可变成加号

最小原根是不会大于 \(p^{0.25}\) 级别的,所以直接从小到大枚举并判断
\(k\)\(p-1\) 的一个质因数,那么若对于任意的 \(k\),都有 \(a^{\frac{p-1}{k}}\not \equiv 1\bmod p\),那么 \(a\) 就是 \(p\) 的一个原根

那么每个 \(S_i\) 就可以为产生的那个多项式的第 \(\log S_i\) 项加一
因为指数需要模 \(\varphi(m)\) 也就是 \(m-1\),所以产生的这个多项式是 \(m-1\) 次的
那么对这个多项式做 \(n\) 次方即可
因为是循环卷积,所以需要像普通快速幂那样做一个 \(O(m\log m\log n)\) 的东西,而不能写那种先取 \(\log\)\(\operatorname{exp}\) 的一 \(\log\) 做法

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<assert.h>
#define getChar getchar()
inline int read(){
	register int x=0,y=1;register char c=getChar;
	while(c<'0'||c>'9') y&=(c!='-'),c=getChar;
	while(c>='0'&&c<='9') x=x*10+(c^48),c=getChar;
	return y?x:-x;
}
#define N 8006
#define G 3
#define MOD 1004535809
inline long long power(long long a,long long b,int mod=MOD){
	long long ans=1;
	while(b){
		if(b&1) ans=ans*a%mod;
		b>>=1;a=a*a%mod;
	}
	return ans;
}
inline int getFac(int n,int *fac){
	int o=0;
	for(int i=2;i*i<=n;i++)if(!(n%i)){
		fac[++o]=i;
		while(!(n%i)) n/=i;
	}
	if(n>1) fac[++o]=n;
	return o;
}
inline int getRoot(int p){
	static int fac[N];
	int n=getFac(p-1,fac);
	for(int i=2;;i++){
		for(int j=1;j<=n;j++)if(power(i,(p-1)/fac[j],p)==1) goto NEX;
		return i;
NEX:;
	}
	return -1;
}
int rev[N*4];
inline int init(int n){
	int max=1;while(max<n) max<<=1;
	for(int i=0;i<max;i++) rev[i]=rev[i>>1]>>1,rev[i]|=(i&1)?(max>>1):0;
	return max;
}
inline void ntt(int n,long long *a,int type){
	for(int i=0;i<n;i++)if(rev[i]<i) std::swap(a[i],a[rev[i]]);
	for(int h=1;h<n;h<<=1){
		long long gn=power(G,(MOD-1)/(h<<1)),g,o;
		if(!type) gn=power(gn,MOD-2);
		for(int i=0,j;i<n;i+=h<<1){
			for(g=1,j=i;j<i+h;j++,g=g*gn%MOD){
				o=g*a[j+h]%MOD;
				a[j+h]=(a[j]-o+MOD)%MOD;a[j]=(a[j]+o)%MOD;
			}
		}
	}
	if(!type){
		long long inv=power(n,MOD-2);
		for(int i=0;i<n;i++) a[i]=a[i]*inv%MOD;
	}
}
inline void shrink(int len,int n,long long *f){
	for(int i=0;i<n;i++) f[i]=(f[i]+f[i+n])%MOD;
	std::memset(f+n,0,(len-n)*sizeof f[0]);
}
inline void mul(int len,int n,long long *f,long long *g=NULL){
	ntt(len,f,1);
	if(g) ntt(len,g,1); 
	if(g) for(int i=0;i<len;i++) f[i]=f[i]*g[i]%MOD;
	else for(int i=0;i<len;i++) f[i]=f[i]*f[i]%MOD;
	ntt(len,f,0);shrink(len,n,f);
	if(g) ntt(len,g,0);
}
inline void power(int n,long long *f,int b){
	int len=init(n*2);
	static long long ans[N*4];
	std::memset(ans,0,len*sizeof ans[0]);ans[0]=1;
	while(b){
		if(b&1) mul(len,n,ans,f);
		mul(len,n,f);b>>=1;
	}
	std::memcpy(f,ans,sizeof ans);
}
int _log[N];
int main(){
	int m=read(),n=read(),x=read(),s=read();
	static int a[N];
	for(int i=1;i<=s;i++) a[i]=read();
	int root=getRoot(n);
	for(int i=0,x=1;i<n-1;i++,x=(long long)x*root%n) _log[x]=i;
	static long long f[N*4];
	for(int i=1;i<=s;i++)if(a[i]) f[_log[a[i]]]++;
	power(n-1,f,m);
	printf("%lld\n",f[_log[x]]);
	return 0;
}
posted @ 2022-08-07 14:53  suxxsfe  阅读(33)  评论(0编辑  收藏  举报