# LOJ#6703. 小 Q 的序列

### 题解

设 $f_{i, j}$ 为前 $i$ 个数中选了 $j$ 个的所有方案的权值和，那么有：

$f_{i, j} = f_{i - 1, j} + (a_i + j) f_{i - 1, j - 1}$

设 $F_i(x) = \sum_j f_{i, j} x^j$，于是可以得出 $F_i = (1 + x + a_ix) F_{i - 1} + x^2 F'_{i-1}$。

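上式中的 $x^2 F'_{i-1}$ 一项不便处理，于是换一种状态：设 $f_{i, j}$ 为前 $i$ 个数中有 $j$ 个**没有**被选的方案的权值和。若第 $i$ 个数被选，它是被选中的第 $i - j$ 个数，因此

$f_{i, j} = f_{i - 1, j - 1} + (a_i + i - j) f_{i - 1, j}$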

令 $b_i = a_i + i$，那么这个 dp 用生成函数 $F_i(x) = \sum_j f_{i, j} x^j$ 的形式表示就是

$F_i = xF_{i - 1} + b_i F_{i - 1} - xF'_{i-1}$

设 $H_i = F_iG$，其中 $G$ 是一个待定的函数。我们希望 $H_i$ 满足去掉 $xF_{i-1}$ 一项后的更简单的递推：

$H_i = b_i H_{i-1} - xH'_{i-1}$

将 $H_i = F_iG$ 代入上式并展开导数，得到：

$F_iG = b_iF_{i-1}G - x(F'_{i-1}G+F_{i-1}G')$

另一方面，把 $F_i$ 的递推式两边同乘 $G$，又有：

$F_iG = xF_{i - 1}G + b_i F_{i - 1}G - xF'_{i-1}G$

对比两式得 $-xF_{i-1}G' = xF_{i-1}G$，即 $G' = -G$，取 $G = e^{-x}$ 即可。

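将 $H_i$ 的递推式按系数展开可知 $h_{i, j} = (b_i - j) h_{i - 1, j}$。由 $F_0 = 1$ 得 $H_0 = G = e^{-x}$，即 $h_{0, j} = \frac{(-1)^j}{j!}$，于是 $h_{n, j} = h_{0, j} \prod_{1 \leq i \leq n} (b_i - j)$。对多项式 $\prod_{i=1}^{n} (b_i - x)$ 在 $x = 0, 1, \dots, n$ 处多点求值即可得到所有 $h_{n, j}$；最后 $F_n = H_n e^{x}$，答案为 $\sum_{j < n} f_{n, j}$（去掉一个都不选的情况）。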

### 代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)

// Reads one decimal integer (optionally preceded by '-') from stdin, skipping
// any other leading characters. Returns 0 if end of input is reached before a
// number starts, instead of looping forever.
inline int read()
{
    int data = 0, w = 1;
    int ch = getchar();  // int, not char, so EOF is distinguishable from bytes
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if (ch == '-') w = -1, ch = getchar();
    while (ch >= '0' && ch <= '9') data = data * 10 + (ch - '0'), ch = getchar();
    return data * w;
}

// Array capacity used throughout, and the NTT-friendly prime modulus.
const int N = 300000 + 10;
const int Mod = 998244353;

// Maps a value in (-Mod, Mod) into [0, Mod) by adding Mod to negatives.
// Used as upd(x + y - Mod) for modular addition and upd(x - y) for subtraction.
inline int upd(const int &x) { return x < 0 ? x + Mod : x; }

// Returns x^y mod Mod by binary exponentiation (y >= 0).
int fastpow(int x, int y)
{
    long long base = x;
    long long result = 1;
    while (y != 0)
    {
        if (y & 1) result = result * base % Mod;
        base = base * base % Mod;
        y >>= 1;
    }
    return static_cast<int>(result);
}

// Polynomial arithmetic modulo Mod = 998244353 via number-theoretic transforms.
// All routines share the transform length L / invL set by Init() and use
// function-local static scratch buffers of N entries, so they are not reentrant
// and every transform length L must stay <= N.
namespace Poly
{
// w[i + j] (i a power of two, 0 <= j < i) = t^j with t = 3^(Mod / 2i): the powers
// of a primitive 2i-th root of unity (3 is a primitive root of 998244353).
// Entries w[1 .. Len) are already filled; Len only grows.
int w[N], L, invL, Len = 1;
// Sets L to the smallest power of two >= n and invL = L^{-1} mod Mod, then
// extends the root table w so every block below L is available.
void Init(int n)
{
for (L = 1; L < n; L <<= 1); invL = fastpow(L, Mod - 2);
// i aliases Len, so blocks computed by earlier calls are not recomputed.
for (int &i = Len, t; i < L; i <<= 1)
{
w[i] = 1, t = fastpow(3, Mod / (i << 1));
for (int j = 1; j < i; j++) w[i + j] = 1ll * w[i + j - 1] * t % Mod;
}
}

// In-place forward transform of length L (decimation in frequency, spans from
// L down to 2). No bit-reversal permutation is done, so the output is in
// bit-reversed order, which is exactly the input order IDFT expects.
// Entries of p are expected to lie in [0, Mod).
void DFT(int *p)
{
for (int i = L >> 1, s = L; i; i >>= 1, s >>= 1)
for (int j = 0; j < L; j += s) for (int k = 0, o = i; k < i; ++k, ++o)
{
int x = p[j + k], y = p[i + j + k];
p[j + k] = upd(x + y - Mod), p[i + j + k] = 1ll * w[o] * upd(x - y) % Mod;
}
}

// Inverse of DFT: decimation-in-time butterflies (bit-reversed input -> natural
// order) with the same forward roots; reversing p[1 .. L) then turns that forward
// transform into the inverse one, and every entry is scaled by 1/L.
void IDFT(int *p)
{
for (int i = 1, s = 2; i < L; i <<= 1, s <<= 1)
for (int j = 0; j < L; j += s) for (int k = 0, o = i; k < i; ++k, ++o)
{
int x = p[j + k], y = 1ll * w[o] * p[i + j + k] % Mod;
p[j + k] = upd(x + y - Mod), p[i + j + k] = upd(x - y);
}
std::reverse(p + 1, p + L);
for (int i = 0; i < L; i++) p[i] = 1ll * p[i] * invL % Mod;
}

// Computes b[0, n) = a^{-1} mod x^n by Newton iteration b <- 2b - a*b^2
// (a[0] must be invertible). Each level first solves for b[0, (n+1)/2), then only
// writes b[(n+1)/2, n) = -(a*b*b) there, which equals the Newton update only if
// those entries of b were zero. They are read by the memcpy into d below, so the
// caller must pass b with b[1, n) zeroed.
// A cyclic length L >= 1.5n suffices: a*b*b has about 2n coefficients, and the
// wrap-around lands below index (n+1)/2, outside the range that is kept.
void Inv(int *a, int *b, int n)
{
if (n == 1) return (void) (*b = fastpow(*a, Mod - 2));
static int c[N], d[N]; Inv(a, b, (n + 1) >> 1), Init(n * 1.5 + 0.5);
std::memset(c, 0, L << 2), std::memcpy(c, a, n << 2), DFT(c);
std::memset(d, 0, L << 2), std::memcpy(d, b, n << 2), DFT(d);
for (int i = 0; i < L; i++) c[i] = 1ll * c[i] * d[i] % Mod * d[i] % Mod; IDFT(c);
for (int i = (n + 1) >> 1; i < n; i++) b[i] = upd(-c[i]);
}

// c[0, n + m - 1) = a[0, n) * b[0, m). The inputs are copied into private buffers
// before c is written, so c may alias a or b.
void Mul(const int *a, const int *b, int *c, int n, int m)
{
static int f[N], g[N]; Init(n + m - 1);
std::memset(f, 0, L << 2), std::memcpy(f, a, n << 2), DFT(f);
std::memset(g, 0, L << 2), std::memcpy(g, b, m << 2), DFT(g);
for (int i = 0; i < L; i++) f[i] = 1ll * f[i] * g[i] % Mod; IDFT(f);
std::memcpy(c, f, (n + m - 1) << 2);
}

// Transposed multiplication: c[i] = sum_j a[i + j] * b[j] (terms with i + j >= n
// are absent) for 0 <= i < k, read off as coefficients m-1 .. m+k-2 of
// a * reverse(b). k is presumably meant to be <= n; larger k reads past the
// product — NOTE(review): confirm callers keep k <= n.
void MulT(const int *a, const int *b, int *c, int n, int m, int k)
{
static int f[N], g[N]; std::memcpy(f, a, n << 2), std::memcpy(g, b, m << 2);
std::reverse(g, g + m), Mul(f, g, f, n, m), std::memcpy(c, f + m - 1, k << 2);
}
}

// std::vector front-ends for Poly::Mul / Poly::MulT. The output vector must be
// sized by the caller: for Mul, h.size() >= f.size() + g.size() - 1; for MulT,
// h.size() is the number of transposed coefficients that are produced.
inline void Mul(const std::vector<int> &f, const std::vector<int> &g, std::vector<int> &h)
{
    const int fn = static_cast<int>(f.size()), gn = static_cast<int>(g.size());
    Poly::Mul(f.data(), g.data(), h.data(), fn, gn);
}
inline void MulT(const std::vector<int> &f, const std::vector<int> &g, std::vector<int> &h)
{
    const int fn = static_cast<int>(f.size()), gn = static_cast<int>(g.size());
    const int hn = static_cast<int>(h.size());
    Poly::MulT(f.data(), g.data(), h.data(), fn, gn, hn);
}

// Subproduct tree over the evaluation points, indexed like a segment tree
// (root 1, children 2x and 2x+1).
// v[x]: ascending coefficients of prod_{i=l..r} (1 - a_i x) for the node's range.
// w[x]: per-node working polynomial (r - l + 1 coefficients) filled in by Div.
std::vector<int> v[N * 4], w[N * 4];

// Builds v[] for the points a[l..r] under node x and sizes w[] for Div.
void getPoly(const int *a, int x, int l, int r)
{
    if (l != r)
    {
        const int mid = (l + r) / 2, lc = 2 * x, rc = 2 * x + 1;
        getPoly(a, lc, l, mid);
        getPoly(a, rc, mid + 1, r);
        v[x].resize(r - l + 2);
        w[x].resize(r - l + 1);
        Mul(v[lc], v[rc], v[x]);
        return;
    }
    // Leaf: the linear factor 1 - a_l x; w starts as a single zero coefficient.
    v[x] = {1, upd(-a[l])};
    w[x] = {0};
}

// Pushes the transposed values down the subproduct tree built by getPoly: each
// child's w is w[x] transposed-multiplied by its sibling's v. At a leaf the single
// remaining coefficient is the value at point l and is written to ans[l].
void Div(int *ans, int x, int l, int r)
{
    if (l != r)
    {
        const int mid = (l + r) / 2, lc = x * 2, rc = lc + 1;
        MulT(w[x], v[lc], w[rc]);  // right child uses the left sibling's product
        MulT(w[x], v[rc], w[lc]);  // left child uses the right sibling's product
        Div(ans, lc, l, mid);
        Div(ans, rc, mid + 1, r);
        return;
    }
    ans[l] = w[x].front();
}

// Multipoint evaluation: for the polynomial with coefficients f[0, n), stores its
// value at a[i] into ans[i] for every 1 <= i <= n - 1. Uses the transposed method:
// build the subproduct tree of (1 - a_i x), transposed-multiply f by the inverse of
// the root product (mod x^n), then push the result down the tree with Div.
void Solve(const int *f, const int *a, int *ans, int n)
{
    // No evaluation points; getPoly(a, 1, 1, n - 1) would otherwise never reach a leaf.
    if (n < 2) return;
    static int g[N];
    // Poly::Inv reads the not-yet-computed upper part of its output buffer and
    // needs it zeroed; clearing it here keeps repeated calls to Solve correct.
    std::memset(g, 0, n * sizeof(int));
    getPoly(a, 1, 1, n - 1), Poly::Inv(&v[1][0], g, n);
    Poly::MulT(f, g, &w[1][0], n, n, w[1].size()), Div(ans, 1, 1, n - 1);
}

int n, a[N], b[N], ans[N], fac[N], inv[N];

// Returns the ascending coefficients of prod_{i=l..r} (a_i - x), computed by
// divide and conquer (each leaf is the linear factor a_l - x).
std::vector<int> Prod(int l = 1, int r = n)
{
    if (l == r) return {a[l], Mod - 1};
    const int mid = (l + r) / 2;
    const std::vector<int> left = Prod(l, mid);
    const std::vector<int> right = Prod(mid + 1, r);
    std::vector<int> res(r - l + 2);
    Mul(left, right, res);
    return res;
}

int main()
{
#ifndef ONLINE_JUDGE
    file(cpp);
#endif
    n = read();

    // Factorials fac[0..n] and inverse factorials inv[0..n].
    fac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % Mod;
    inv[n] = fastpow(fac[n], Mod - 2);
    for (int i = n; i >= 1; --i) inv[i - 1] = 1ll * inv[i] * i % Mod;

    // a[i] becomes b_i = a_i + i (mod Mod); b[i] = i are the evaluation points.
    // ans[0] = prod b_i is the value of prod (b_i - x) at x = 0.
    ans[0] = 1;
    for (int i = 1; i <= n; ++i)
    {
        a[i] = upd(read() + i - Mod);
        b[i] = i;
    }
    for (int i = 1; i <= n; ++i) ans[0] = 1ll * ans[0] * a[i] % Mod;

    // ans[j] = prod_i (b_i - j) for j = 1..n.
    {
        const std::vector<int> poly = Prod();
        Solve(poly.data(), b, ans, n + 1);
    }

    // h_{n,j} = (-1)^j / j! * prod_i (b_i - j).
    for (int j = 0; j <= n; ++j)
    {
        const int scaled = 1ll * ans[j] * inv[j] % Mod;
        ans[j] = (j & 1) ? upd(-scaled) : scaled;
    }

    // F_n = H_n * e^x; the answer sums f_{n,j} over j < n (at least one chosen).
    std::vector<int> f(ans, ans + n + 1);
    std::vector<int> g(inv, inv + n + 1);
    std::vector<int> h(f.size() + g.size() - 1);
    Mul(f, g, h);
    int res = 0;
    for (int i = 0; i < n; ++i) res = upd(res + h[i] - Mod);
    printf("%d\n", res);
    return 0;
}

posted @ 2021-01-26 11:51  xgzc  阅读(133)  评论(0编辑  收藏  举报