ARC120F2 Wine Thief 线性做法
由于我比较菜,会把式子写的比较仔细。
伟大的 alpha1022 指出如下事实,即
我们无非是要计算
不妨设置哑元 \(t\),令 \(t^i\) 代替 \(A_i\),则我们只需要计算
\(t^{KD}\) 所在一项是没有贡献的,只要注意到每个 \(x\) 至少绑定一个 \(t\)。
所以无非就是
那么我们主要需要解决给定一个稀疏分式 \(u(t)\),求解
这自然关于 \(u\) D-Finite,因此也关于 \(t\) D-Finite,故可以做到 \(O(N)\)。
伟大的 alpha1022 在这里停止了祂的教诲。现在让我试着冒充一下伟大的先知。
设其为 \(F(t)=G(u)\) ,先求其满足的微分方程。首先,我们有
这个式子有一点小 BUG,但我们先不管他,最后再来处理。
写成微分方程,并补上 \(j=0\) 就是
然后我们代入 \(u=\frac{t-t^D}{1-t^D}\),结合 \(G'(u)=(G(u))'/u'\),有这样依托构思
好可怕。但是我们没必要这样做。
伟大的 alpha1022 悄悄指出,可以直接同时维护 \(uG\) 和 \(u(1-u)G'/u'\)。设 \(A(t) = uG(u), B(t) = \frac {u(1-u)}{u'}G'(u)\),有
提取系数就是:
类似地:
提取系数就是:
注意到这里是有一个 \(F_i\) 的,但是无伤大雅,处理的时候注意一下就好了。具体地,在转移的时候先把其他项做了,这一项在转移到 \(F\) 的时候可以移到左边去,再把系数除过来,类似期望 DP 的套路。
这题本来应该是做完了。但是还有一点小问题。在写出微分方程的时候我们没有关心 \(j=n+k\) 的情况。当然,写成微分方程本身就是可能丢失一些信息的,我们可以直接用第一行的式子求出这个单项。这样,问题就圆满解决了。
总结一下:
- 复杂的式子可以考虑小的一块当成整体一起转移;
- 微分方程只会努力比对系数(当然也可以写出导函数之后提升次数啥的);
- 写式子的时候要注意一下特殊情况,免得出锅。
- 两个多项式做点乘,我们可以像这样写哑演算。
- 菜就多练。
#include <cstdio>
const int N = 1e6, mod = 998244353;
using ll = long long;
inline void add(int &a, int b) {a += b; a >= mod && (a -= mod);}
inline void sub(int &a, int b) {a -= b; a < 0 && (a += mod);}
inline int mul(int a, int b) {return (ll)a * b % mod;}
inline int qpow(int a, int b) {
int ret = 1;
for(; b; b >>= 1, a = mul(a, a))
b & 1 && (ret = mul(ret, a));
return ret;
}
int n, k, d, ans;
int m, lim, a[N], res[N], f[N], g[N], h[N];
// f : F; g : uF; h : u(1-u)/u' F;
int fac[N << 1 + 5], ifac[N << 1 + 5], inv[N << 1 + 5];
inline int binom(int n, int m) {
return m < 0 || m > n ? 0 : (ll)fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
int main() {
scanf("%d %d %d", &n, &k, &d);
fac[0] = 1, lim = n << 1;
for(int i = 1; i <= lim; ++i) fac[i] = mul(fac[i - 1], i);
ifac[lim] = qpow(fac[lim], mod - 2);
for(int i = lim; i; --i) ifac[i - 1] = mul(ifac[i], i), inv[i] = mul(fac[i - 1], ifac[i]);
m = n - 1 - (k - 1) * d;
f[0] = mul(m + k, binom(m + k - 1, k - 1));
for(int i = 1; i <= k; ++i)
add(res[m + k], mul(binom(m + k - (i - 1) * (d - 1), i - 1), binom(n - m - k - 1 - (k - i) * (d - 1), k - i)));
for(int i = 0; i < n; ++i) {
if(i)
g[i] = f[i - 1],
h[i] = mul(mod - i + 1, f[i - 1]);
if(i >= d - 1)
sub(h[i], mul(i - d + 1, f[i - d + 1])), add(h[i], mul(d, h[i - d + 1]));
if(i >= d)
sub(g[i], f[i - d]), add(g[i], g[i - d]),
add(h[i], mul(i - d, f[i - d])), sub(h[i], mul(d - 1, h[i - d]));
if(i == m + k) {
f[i] = res[i];
if(i >= d) sub(f[i], res[i - d]);
} else {
add(f[i], mul(g[i], m)), add(f[i], h[i]);
res[i] = f[i] = mul(f[i], i < m + k ? inv[m + k - i] : mod - inv[i - m - k]);
if(i >= d) add(res[i], res[i - d]);
}
add(h[i], mul(i, f[i]));
}
for(int i = 0; i < n; ++i)
scanf("%d", &a[i]), add(ans, mul(a[i], res[i]));
printf("%d", ans);
return 0;
}