[SDOI2015]序列统计(NTT,取对数(并非多项式对数!),卷积快速幂)

题面

在这里插入图片描述

题解

题意要求的是求长度为 N N N 的数列个数,满足

a 1 ⋅ a 2 ⋅ a 3 ⋅ . . . ⋅ a N ≡ x        (  ⁣ ⁣ ⁣ m o d    M ) a_1\cdot a_2\cdot a_3\cdot ... \cdot a_N \equiv x \;\;\;(\!\!\!\mod M) a1a2a3...aNx(modM)

这不好做,我们得变一下。

我们注意到 M M M 是质数,也就是说它一定有原根,
取原根为 g g g, 那么每一个 [ 1 , M − 1 ] [1,M-1] [1,M1] 内的数都可以表示成 g k  ⁣ ⁣ m o d    M g^k\!\!\mod M gkmodM
于是原式为

g k 1 ⋅ g k 2 ⋅ g k 3 ⋅ . . . ⋅ g k N ≡ g K        (  ⁣ ⁣ ⁣ m o d    M ) g^{k_1}\cdot g^{k_2}\cdot g^{k_3}\cdot ... \cdot g^{k_N} \equiv g^{K} \;\;\;(\!\!\!\mod M) gk1gk2gk3...gkNgK(modM)
⇔ k 1 + k 2 + k 3 + . . . + k N ≡ K        (  ⁣ ⁣ ⁣ m o d    φ ( M ) ) \Leftrightarrow k_1+k_2+k_3+...+k_N\equiv K\;\;\;(\!\!\!\mod φ(M)) k1+k2+k3+...+kNK(modφ(M))

为什么非要是原根呢? 因为这样就可以保证在 [ 1 , M − 1 ] [1,M-1] [1,M1] [ 0 , φ ( M ) − 1 ] [0,φ(M)-1] [0,φ(M)1] 之间形成一一映射,即唯一对应关系,上式的等价才成立。

成功把数列积变成对数和!


当我们输入了 S S S 后,我们就可以知道哪些对数是可以在数列中取的了。

f ( x ) f(x) f(x) 为对数 x 在数列中是否出现(0/1),
那么长度为 2 的数列积为 g x g^x gx 的方案数就是

S u m 2 ( x ) = ∑ i = 0 x f ( i ) f ( x − i )          +          ∑ i = x φ ( M ) f ( i ) f ( x + φ ( M ) − i ) Sum_2(x)=\sum_{i=0}^{x}f(i)f(x-i) \;\;\;\;+\;\;\;\;\sum_{i=x}^{φ(M)}f(i)f(x+φ(M)-i) Sum2(x)=i=0xf(i)f(xi)+i=xφ(M)f(i)f(x+φ(M)i)

相当于把 f ∗ f f*f ff (卷积) 后面从 φ ( M ) φ(M) φ(M) 开始的系数都加到前面 ( S 2 ( x )   +  ⁣ ⁣ =   S 2 ( φ ( M ) + x ) S_2(x)\,+\!\!=\,S_2(φ(M)+x) S2(x)+=S2(φ(M)+x)),不妨将其叫做一次 特殊的卷积“ ∗ s *_{s} s (瞎定义的,方便理解),即

S u m 2 = f ∗ s f Sum_2=f*_sf Sum2=fsf

同理可得,

S u m 3 = S u m 2 ∗ s f = f ∗ s f ∗ s f S u m n = S u m n − 1 ∗ s f = ( ∗ s ) f n Sum_3=Sum_2*_sf=f*_sf*_sf\\ Sum_n=Sum_{n-1}*_sf=(*_s)f^n Sum3=Sum2sf=fsfsfSumn=Sumn1sf=(s)fn

于是我们可以用快速幂来卷积了。


满怀希望地提交了,灰心丧气地得了个WA

我们再仔细地看范围,发现 x x x S i S_i Si 都可以为零!

那完蛋了,我们只能用对数表示 [ 1 , M − 1 ] [1,M-1] [1,M1] 中的数

但是我们可以把它特判掉,如果 x x x 不为零,那么 S S S 中的零就不管它,如果 x x x 等于零,那就意味着数列只需要满足其中有零就足够了,若此时 S S S 中有零,我们用随便乱选的方案数减去没有零的方案数,容易得到

a n s = ∣ S ∣ N − ( ∣ S ∣ − 1 ) N ans=|S|^N-(|S|-1)^N ans=SN(S1)N

CODE

#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 8005
#define LL long long
#define DB double
#define ENDL putchar('\n')
LL read() {
	LL f=1,x=0;char s = getchar();
	while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
	while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
	return f * x;
}
const int MOD = 1004535809;
int n,m,i,j,s,o,k;
int xm[MAXN<<2],om;
int rev[MAXN<<2];
int qkpow(int a,int b,int MD) {
	int res = 1; while(b > 0) {
		if(b & 1) res = res *1ll* a % MD;
		a = a *1ll* a % MD; b >>= 1;
	}return res;
}
int findroot(int p) {
	for(int i = 2;i < p;i ++) {
		bool flag = 1;
		for(int j = 2;j*1ll*j <= p-1;j ++) {
			if((p-1) % j == 0) {
				if(qkpow(i,j,p) == 1) {
					flag = 0; break;
				}
				else if(qkpow(i,(p-1)/j,p) == 1) {
					flag = 0; break;
				}
			}
		}
		if(flag) return i;
	}
	return 3;
}
const int RM = 3;
void NTT(int *s,int n) {
	for(int i = 1;i < n;i ++) {
		rev[i] = ((rev[i>>1] >> 1) | ((i&1) ? (n>>1):0));
		if(rev[i] < i) swap(s[rev[i]],s[i]);
	}
	om = qkpow(RM,(MOD-1)/n,MOD); xm[0] = 1;
	for(int k = 1;k < n;k ++) xm[k] = xm[k-1]*1ll*om % MOD;
	for(int k = 2,t = (n>>1);k <= n;k <<= 1,t >>= 1)
		for(int j = 0;j < n;j += k)
			for(int i = j,l = 0;i < j+(k>>1);i ++,l += t) {
				int A = s[i],B = s[i+(k>>1)];
				s[i] = (A + xm[l]*1ll*B%MOD) % MOD, s[i+(k>>1)] = (A +MOD- xm[l]*1ll*B%MOD) % MOD;
			}
	return ;
}
void INTT(int *s,int n) {
	for(int i = 1;i < n;i ++) {
		rev[i] = ((rev[i>>1] >> 1) | ((i&1) ? (n>>1):0));
		if(rev[i] < i) swap(s[rev[i]],s[i]);
	}
	om = qkpow(qkpow(RM,(MOD-1)/n,MOD),MOD-2,MOD); xm[0] = 1;
	for(int k = 1;k < n;k ++) xm[k] = xm[k-1]*1ll*om % MOD;
	for(int k = 2,t = (n>>1);k <= n;k <<= 1,t >>= 1)
		for(int j = 0;j < n;j += k)
			for(int i = j,l = 0;i < j+(k>>1);i ++,l += t) {
				int A = s[i],B = s[i+(k>>1)];
				s[i] = (A + xm[l]*1ll*B%MOD) % MOD, s[i+(k>>1)] = (A +MOD- xm[l]*1ll*B%MOD) % MOD;
			}
	int invn = qkpow(n,MOD-2,MOD);
	for(int i = 0;i <= n;i ++) s[i] = s[i] *1ll* invn % MOD;
	return ;
}
int lo[MAXN],ROOT;
int A[MAXN<<2],C[MAXN<<2];
int main() {
	int N = read();n = read();
	int xx = read();m = read();
	ROOT = findroot(n);
	for(int i = 0,j = 1;i < n-1;i ++,j = j *1ll* ROOT % n) {
		lo[j] = i;
	}
	C[0] = 1;
	bool flag = 0;
	for(int i = 1;i <= m;i ++) {
		s = read(); 
		if(s) A[lo[s]] ++;
		else flag = 1;
	}n --;
	if(xx == 0) {
		int ans = (qkpow(m,N,MOD) +MOD- qkpow(m-flag,N,MOD)) % MOD;
		printf("%d\n",ans);
		return 0;
	}
	else xx = lo[xx]; 
	int le = 1; while(le <= n*2) le <<= 1;
	NTT(A,le);
	while(N > 0) {
		if(N & 1) {
			NTT(C,le); 
			for(int i = 0;i <= le;i ++) C[i] = C[i] *1ll* A[i] % MOD;
			INTT(C,le);
			for(int i = n;i <= le;i ++) (C[i-n] += C[i]) %= MOD,C[i] = 0;
		}
		for(int i = 0;i <= le;i ++) A[i] = A[i]*1ll*A[i] % MOD;
		INTT(A,le);
		for(int i = n;i <= le;i ++) (A[i-n] += A[i]) %= MOD,A[i] = 0;
		NTT(A,le);
		N >>= 1;
	}
	printf("%d\n",C[xx]);
	return 0;
}
posted @ 2021-02-02 19:08  DD_XYX  阅读(14)  评论(0)    收藏  举报