[SDOI2015]序列统计(NTT,取对数(并非多项式对数!),卷积快速幂)
题面
题解
题意要求的是求长度为 N N N 的数列个数,满足
a 1 ⋅ a 2 ⋅ a 3 ⋅ . . . ⋅ a N ≡ x ( m o d M ) a_1\cdot a_2\cdot a_3\cdot ... \cdot a_N \equiv x \;\;\;(\!\!\!\mod M) a1⋅a2⋅a3⋅...⋅aN≡x(modM)
这不好做,我们得变一下。
我们注意到
M
M
M 是质数,也就是说它一定有原根,
取原根为
g
g
g, 那么每一个
[
1
,
M
−
1
]
[1,M-1]
[1,M−1] 内的数都可以表示成
g
k
m
o
d
M
g^k\!\!\mod M
gkmodM,
于是原式为
g
k
1
⋅
g
k
2
⋅
g
k
3
⋅
.
.
.
⋅
g
k
N
≡
g
K
(
m
o
d
M
)
g^{k_1}\cdot g^{k_2}\cdot g^{k_3}\cdot ... \cdot g^{k_N} \equiv g^{K} \;\;\;(\!\!\!\mod M)
gk1⋅gk2⋅gk3⋅...⋅gkN≡gK(modM)
⇔
k
1
+
k
2
+
k
3
+
.
.
.
+
k
N
≡
K
(
m
o
d
φ
(
M
)
)
\Leftrightarrow k_1+k_2+k_3+...+k_N\equiv K\;\;\;(\!\!\!\mod φ(M))
⇔k1+k2+k3+...+kN≡K(modφ(M))
为什么非要是原根呢? 因为这样就可以保证在 [ 1 , M − 1 ] [1,M-1] [1,M−1] 和 [ 0 , φ ( M ) − 1 ] [0,φ(M)-1] [0,φ(M)−1] 之间形成一一映射,即唯一对应关系,上式的等价才成立。
成功把数列积变成对数和!
当我们输入了 S S S 后,我们就可以知道哪些对数是可以在数列中取的了。
设
f
(
x
)
f(x)
f(x) 为对数 x 在数列中是否出现(0/1),
那么长度为 2 的数列积为
g
x
g^x
gx 的方案数就是
S u m 2 ( x ) = ∑ i = 0 x f ( i ) f ( x − i ) + ∑ i = x φ ( M ) f ( i ) f ( x + φ ( M ) − i ) Sum_2(x)=\sum_{i=0}^{x}f(i)f(x-i) \;\;\;\;+\;\;\;\;\sum_{i=x}^{φ(M)}f(i)f(x+φ(M)-i) Sum2(x)=∑i=0xf(i)f(x−i)+∑i=xφ(M)f(i)f(x+φ(M)−i)
相当于把
f
∗
f
f*f
f∗f (卷积) 后面从
φ
(
M
)
φ(M)
φ(M) 开始的系数都加到前面 (
S
2
(
x
)
+
=
S
2
(
φ
(
M
)
+
x
)
S_2(x)\,+\!\!=\,S_2(φ(M)+x)
S2(x)+=S2(φ(M)+x)),不妨将其叫做一次 特殊的卷积“
∗
s
*_{s}
∗s” (瞎定义的,方便理解),即
S u m 2 = f ∗ s f Sum_2=f*_sf Sum2=f∗sf
同理可得,
S u m 3 = S u m 2 ∗ s f = f ∗ s f ∗ s f S u m n = S u m n − 1 ∗ s f = ( ∗ s ) f n Sum_3=Sum_2*_sf=f*_sf*_sf\\ Sum_n=Sum_{n-1}*_sf=(*_s)f^n Sum3=Sum2∗sf=f∗sf∗sfSumn=Sumn−1∗sf=(∗s)fn
于是我们可以用快速幂来卷积了。
满怀希望地提交了,灰心丧气地得了个WA
我们再仔细地看范围,发现 x x x 和 S i S_i Si 都可以为零!
那完蛋了,我们只能用对数表示 [ 1 , M − 1 ] [1,M-1] [1,M−1] 中的数
但是我们可以把它特判掉,如果 x x x 不为零,那么 S S S 中的零就不管它,如果 x x x 等于零,那就意味着数列只需要满足其中有零就足够了,若此时 S S S 中有零,我们用随便乱选的方案数减去没有零的方案数,容易得到
a n s = ∣ S ∣ N − ( ∣ S ∣ − 1 ) N ans=|S|^N-(|S|-1)^N ans=∣S∣N−(∣S∣−1)N
CODE
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 8005
#define LL long long
#define DB double
#define ENDL putchar('\n')
LL read() {
LL f=1,x=0;char s = getchar();
while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
return f * x;
}
const int MOD = 1004535809;
int n,m,i,j,s,o,k;
int xm[MAXN<<2],om;
int rev[MAXN<<2];
int qkpow(int a,int b,int MD) {
int res = 1; while(b > 0) {
if(b & 1) res = res *1ll* a % MD;
a = a *1ll* a % MD; b >>= 1;
}return res;
}
int findroot(int p) {
for(int i = 2;i < p;i ++) {
bool flag = 1;
for(int j = 2;j*1ll*j <= p-1;j ++) {
if((p-1) % j == 0) {
if(qkpow(i,j,p) == 1) {
flag = 0; break;
}
else if(qkpow(i,(p-1)/j,p) == 1) {
flag = 0; break;
}
}
}
if(flag) return i;
}
return 3;
}
const int RM = 3;
void NTT(int *s,int n) {
for(int i = 1;i < n;i ++) {
rev[i] = ((rev[i>>1] >> 1) | ((i&1) ? (n>>1):0));
if(rev[i] < i) swap(s[rev[i]],s[i]);
}
om = qkpow(RM,(MOD-1)/n,MOD); xm[0] = 1;
for(int k = 1;k < n;k ++) xm[k] = xm[k-1]*1ll*om % MOD;
for(int k = 2,t = (n>>1);k <= n;k <<= 1,t >>= 1)
for(int j = 0;j < n;j += k)
for(int i = j,l = 0;i < j+(k>>1);i ++,l += t) {
int A = s[i],B = s[i+(k>>1)];
s[i] = (A + xm[l]*1ll*B%MOD) % MOD, s[i+(k>>1)] = (A +MOD- xm[l]*1ll*B%MOD) % MOD;
}
return ;
}
void INTT(int *s,int n) {
for(int i = 1;i < n;i ++) {
rev[i] = ((rev[i>>1] >> 1) | ((i&1) ? (n>>1):0));
if(rev[i] < i) swap(s[rev[i]],s[i]);
}
om = qkpow(qkpow(RM,(MOD-1)/n,MOD),MOD-2,MOD); xm[0] = 1;
for(int k = 1;k < n;k ++) xm[k] = xm[k-1]*1ll*om % MOD;
for(int k = 2,t = (n>>1);k <= n;k <<= 1,t >>= 1)
for(int j = 0;j < n;j += k)
for(int i = j,l = 0;i < j+(k>>1);i ++,l += t) {
int A = s[i],B = s[i+(k>>1)];
s[i] = (A + xm[l]*1ll*B%MOD) % MOD, s[i+(k>>1)] = (A +MOD- xm[l]*1ll*B%MOD) % MOD;
}
int invn = qkpow(n,MOD-2,MOD);
for(int i = 0;i <= n;i ++) s[i] = s[i] *1ll* invn % MOD;
return ;
}
int lo[MAXN],ROOT;
int A[MAXN<<2],C[MAXN<<2];
int main() {
int N = read();n = read();
int xx = read();m = read();
ROOT = findroot(n);
for(int i = 0,j = 1;i < n-1;i ++,j = j *1ll* ROOT % n) {
lo[j] = i;
}
C[0] = 1;
bool flag = 0;
for(int i = 1;i <= m;i ++) {
s = read();
if(s) A[lo[s]] ++;
else flag = 1;
}n --;
if(xx == 0) {
int ans = (qkpow(m,N,MOD) +MOD- qkpow(m-flag,N,MOD)) % MOD;
printf("%d\n",ans);
return 0;
}
else xx = lo[xx];
int le = 1; while(le <= n*2) le <<= 1;
NTT(A,le);
while(N > 0) {
if(N & 1) {
NTT(C,le);
for(int i = 0;i <= le;i ++) C[i] = C[i] *1ll* A[i] % MOD;
INTT(C,le);
for(int i = n;i <= le;i ++) (C[i-n] += C[i]) %= MOD,C[i] = 0;
}
for(int i = 0;i <= le;i ++) A[i] = A[i]*1ll*A[i] % MOD;
INTT(A,le);
for(int i = n;i <= le;i ++) (A[i-n] += A[i]) %= MOD,A[i] = 0;
NTT(A,le);
N >>= 1;
}
printf("%d\n",C[xx]);
return 0;
}