# 【BZOJ4555】【TJOI2016】【HEOI2016】求和

## 解法

题目要求计算 $f(n)=\sum_{i=0}^{n}\sum_{j=0}^{i}S(i,j)\,2^j j!$，其中 $S(i,j)$ 为第二类斯特林数。由于 $j>i$ 时 $S(i,j)=0$，可以把 $j$ 的上界放宽到 $n$，再代入 $S(i,j)=\frac{1}{j!}\sum_{k=0}^{j}(-1)^k{j \choose k}(j-k)^i$：

$$
\begin{aligned}
f(n) &= \sum_{i = 0}^{n}\sum_{j = 0}^{i}S(i, j)\,2^jj!\\
&= \sum_{i = 0}^{n}\sum_{j = 0}^{n}S(i, j)\,2^jj!\\
&= \sum_{i = 0}^{n}\sum_{j = 0}^{n}2^jj!\,\frac{1}{j!}\sum_{k = 0}^{j}(-1)^k{j \choose k}(j-k)^i\\
&= \sum_{i = 0}^{n}\sum_{j = 0}^{n}2^j\sum_{k = 0}^{j}(-1)^k\frac{j!}{k!(j-k)!}(j-k)^i\\
&= \sum_{i = 0}^{n}\sum_{j = 0}^{n}2^jj!\sum_{k = 0}^{j}(-1)^k\frac{(j-k)^i}{k!(j-k)!}\\
&= \sum_{i = 0}^{n}\sum_{j = 0}^{n}2^jj!\sum_{k = 0}^{j}\frac{(-1)^k}{k!}\cdot\frac{(j-k)^i}{(j-k)!}\\
&= \sum_{j = 0}^{n}2^jj!\sum_{k = 0}^{j}\frac{(-1)^k}{k!}\sum_{i = 0}^{n}\frac{(j-k)^i}{(j-k)!}
\end{aligned}
$$

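记 $a_k=\frac{(-1)^k}{k!}$，$b_m=\sum_{i=0}^{n}\frac{m^i}{m!}$，则答案为 $\sum_{j=0}^{n}2^jj!\sum_{k=0}^{j}a_kb_{j-k}$，内层求和是 $a$ 与 $b$ 的卷积，可以用 NTT 计算。

$b$ 可以用等比数列求和公式求出：$b_0=1$，$b_1=n+1$，当 $m\ge 2$ 时 $b_m=\frac{m^{n+1}-1}{(m-1)\,m!}$。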

## 代码

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cstdlib>
#include <algorithm>

using namespace std;

// 64-bit signed integer type used for all modular arithmetic in this file.
typedef long long LL;

// Modulus applied to every computation below (998244353 = 119 * 2^23 + 1).
const LL mod = 998244353LL;

// Capacity of the fixed-size arrays declared in this file.
const int N = 400010;

// Modular exponentiation by repeated squaring.
// Returns a^n reduced modulo `mod` for n >= 0; the base is reduced modulo
// `mod` before the loop starts, and n == 0 yields 1.
inline LL power(LL a, LL n, LL mod)
{
	LL result = 1;
	for (a %= mod; n; n >>= 1)
	{
		// Multiply in the current square whenever the low bit of n is set.
		if (n & 1)
			result = result * a % mod;
		a = a * a % mod;
	}
	return result;
}

// Modular addition: returns (a + b) mod `mod` for operands already in [0, mod).
// The comparison is `>=` so that a sum of exactly `mod` wraps to 0 instead of
// handing the out-of-range value `mod` back to the caller.
inline LL Plus(LL a, LL b)
{
	const LL sum = a + b;
	return sum >= mod ? sum - mod : sum;
}

// Modular subtraction: computes a - b and adds `mod` once when the difference
// is negative.
inline LL Minus(LL a, LL b)
{
	LL diff = a - b;
	if (diff < 0)
		diff += mod;
	return diff;
}

struct Mul
{	int Len, Bit;

LL wn[N];

int rev[N];

void getReverse()
{	for (int i = 0; i < Len; i++)
rev[i] = (rev[i>>1] >> 1) | ((i&1) * (Len >> 1));
}

void NTT(LL * a, int opt)
{	getReverse();
for (int i = 0; i < Len; i++)
if (i < rev[i]) swap(a[i], a[rev[i]]);
int cnt = 0;
for (int i = 2; i <= Len; i <<= 1)
{	cnt++;
for (int j = 0; j < Len; j += i)
{	LL w = 1LL;
for (int k = 0; k < (i>>1); k++)
{	LL x = a[j + k];
LL y = (w * a[j + k + (i>>1)]) % mod;
a[j + k] = Plus(x, y);
a[j + k + (i>>1)] = Minus(x, y);
w = (w * wn[cnt]) % mod;
}
}
}
if (opt == -1)
{	reverse(a + 1, a + Len);
LL num = power(Len, mod-2, mod);
for (int i = 0; i < Len; i++)
a[i] = (a[i] * num) % mod;
}
}

void getLen(int l)
{	Len = 1, Bit = 0;
for (; Len <= l; Len <<= 1) Bit++;
}

void init()
{	for (int i = 0; i < 23; i++)
wn[i] = power(3, (mod-1) / (1LL << i), mod);
}
} Calc;

LL fac[N], ifac[N];

LL A[N], B[N], C[N];

int main()
{	int n;
scanf("%d", &n);
fac[0] = 1;
for (int i = 1; i <= n; i++)
fac[i] = fac[i-1] * i % mod;
ifac[n] = power(fac[n], mod-2, mod);
for (int i = n-1; i >= 0; i--)
ifac[i] = ifac[i+1] * (i+1) % mod;
for (int i = 0; i <= n; i++)
A[i] = (i & 1 ? Minus(mod, 1) : 1) * ifac[i] % mod;
B[0] = 1;
B[1] = n + 1;
for (int i = 2; i <= n; i++)
B[i] = (power(i, n+1, mod) + mod - 1) % mod * power(i-1, mod-2, mod) % mod * ifac[i] % mod;
Calc.init();
Calc.getLen(n * 2 + 1);
Calc.NTT(A, 1);
Calc.NTT(B, 1);
for (int i = 0; i < Calc.Len; i++)
C[i] = A[i] * B[i] % mod;
Calc.NTT(C, -1);
LL Ans = 0;
for (int i = 0; i <= n; i++)
Ans = Plus(Ans, (power(2LL, i, mod) * fac[i] % mod * C[i] % mod));
printf("%lld\n", Ans);
return 0;
}

