题解 [HAOI2018] 染色
题目链接
题目描述
对一个长度为 \(n\) 的序列染色,颜色一共有 \(m\) 种,对于每一种染色方案,设有 \(u\) 种颜色恰好出现了 \(s\) 次,则将答案加 \(w_u\) ,求答案
其中 \(n \leq 10^7, m \leq 10^5, s \leq 150\),并将答案对 \(1004535809\) 取模
分析
下面的式子虽然有点多,但每一步推导都非常简单,应该可以在不使用草稿纸的情况下看懂
首先,\(1004535809\) 的出现直接暗示了我们这题需要 NTT
然后肯定就是生成函数推柿子题了
考虑枚举恰好出现了 \(s\) 次的颜色有多少种
就假设枚举到 \(i\) 种
那么对于这 \(i\) 种颜色,他们的指数型生成函数都是
那对于另外的 \(m - i\) 种颜色,由于他们不能恰好出现 \(s\) 次,那他们的生成函数就应该是
那么,假设选定了 \(i\) 种颜色使他们恰好出现 \(s\) 次,这样的序列个数就是
这个 \(n!\) 是因为这是指数型生成函数
再考虑从 \(m\) 种颜色中选出 \(i\) 种颜色作为恰好出现 \(s\) 次的颜色的方案数,所以答案就是
对于最后那一项,我们用二项式定理展开,有
代入得答案为
这里观察到如果 \(j\) 从 \(i\) 开始枚举的话式子会简单许多
把 \(\frac{1}{(s!)^i}\) 和 \([x^{n - si}]\) 写进去
而我们又知道
所以
代进去
观察发现两个二项式系数显然可以消掉一个阶乘,于是我们把二项式系数展开
其实到这里已经差不多可以看出来了,我们再变形一下
可以发现 \(\frac{(-1)^{j-i}}{(j-i)!}\) 的取值只与 \(j-i\) 有关,\(\frac{(m-j)^{n-sj}}{(s!)^j(m-j)!(n-sj)!}\) 只与 \(j\) 有关
我们分别将其设为 \(g_{j-i}\) 与 \(f_j\)
那么
可以发现这里有一个极其类似卷积的式子,对于这样的式子,我们的操作是将 \(f\) 翻转,即让 \(f_j\) 与 \(f_{m-j}\) 交换,那么对于新的 \(f_j\),他对应的原来的 \(f\) 值就是 \(f_{m-j}\),所以他应该乘的 \(g\) 值就是 \(g_{m-j-i}\),到这里卷积的形式就出来了
我们再写一遍答案
其中
如果 \(n-si < 0\) 则令 \(f_i = 0\) 即可
使用 NTT 求解卷积,复杂度为 \(O(n + m\log m)\)
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e7 + 10;
const int M = 400010;
const LL p = 1004535809;
LL n, m, s, ge = 3, ig, mx = 1, ans, w[M], f[M], g[M], fac[N];
LL ksm(LL v, LL tms)
{
LL res = 1;
for(; tms; tms >>= 1, v = v * v % p) if(tms & 1) res = res * v % p;
return res;
}
void NTT(LL *a, LL mx, LL flag)
{
for(int i = 0, j = 0; i < mx; i++){
if(i < j) swap(a[i], a[j]);
for(int l = mx >> 1; (j ^= l) < l; l >>= 1);
}
for(int i = 2; i <= mx; i <<= 1){
LL now = (i >> 1), step = ksm((flag > 0 ? ge : ig), (p - 1) / i);
for(int j = 0; j < mx; j += i){
LL temp = 1;
for(int k = 0; k < now; k++){
LL G = a[j + k], H = temp * a[j + k + now] % p;
a[j + k] = (G + H) % p, a[j + k + now] = (G - H + p) % p;
temp = step * temp % p;
}
}
}
LL inv = ksm(mx, p - 2);
if(flag < 0) for(int i = 0; i < mx; i++) a[i] = a[i] * inv % p;
}
int main()
{
ios :: sync_with_stdio(false), cin.tie(0);
ig = ksm(ge, p - 2);
cin >> n >> m >> s;
for(int i = 0; i <= m; i++) cin >> w[i];
fac[0] = 1;
for(int i = 1; i <= max(n, m); i++) fac[i] = fac[i - 1] * i % p;
for(int i = 0; i <= m; i++)
g[i] = (((i & 1) ? -1 : 1) * ksm(fac[i], p - 2) + p) % p;
for(int i = 0; i <= m; i++)
if(s * i > n) f[i] = 0; //特判f[i]=0
else f[i] = ksm((ksm(fac[s], i) * fac[m - i] % p) * fac[n - s * i] % p, p - 2) * ksm(m - i, n - s * i) % p;
for(int i = 0; (i << 1) <= m; i++) swap(f[i], f[m - i]); //翻转f
while(mx <= m) mx <<= 1; mx <<= 1;
NTT(f, mx, 1);
NTT(g, mx, 1);
for(int i = 0; i < mx; i++) f[i] = f[i] * g[i] % p;
NTT(f, mx, -1);
for(int i = 0; i <= m; i++)
if(s * i <= n) ans = (ans + ((w[i] * ksm(fac[i], p - 2) % p) * f[m - i] % p) + p) % p;
ans = ans * (fac[n] * fac[m] % p) % p;
printf("%lld\n", ans);
return 0;
}

浙公网安备 33010602011771号