分零食[JSOI2012]

题目描述

题面太长
\(n\)个人,\(m\)颗糖,要求给前若干个人分糖(把所有糖分完),如果一个人得到了\(x\)颗糖,那他的欢乐度就是\(Ox^2+Sx+U\),一个分糖方案的总欢乐度是所有分到糖的人的欢乐度的乘积,求所有可行分糖方案的总欢乐度的总和。

题解

首先有一个显然的dp方程:

\(dp[i][j]\)表示给前\(i\)个人分了\(j\)颗糖,设\(f(x)=Ox^2+Sx+U\)
\(dp[i][j]=\sum\limits_{k=1}^{j-i+1} dp[i-1][j-k]*f(k)\)

答案即为\(\sum\limits_{i=1}^{n} dp[i][m]\)

考虑如何优化这个式子

如果把\(dp[i-1][j-k]*f(k)\)看作一个卷积形式的话,我们会发现\(dp[i]\)这一个数组就是\(dp[i-1]\)\(f\)的卷积

我们把\(f\)以及\(dp[i]\)看作一个多项式,那么有\(f=f(1)x+f(2)x^2+f(3)x^3+\dots+f(m)x^m\)

由于\(dp[i]=dp[i-1]*f\),所以显然\(dp[i]=f^i\) (这里的乘方是指\(i\)次卷积,不是指\(i\)次方。。。)

可以使用FFT优化卷积

但是\(n\le 10^8\),这样还是跑不过,还需要进行优化:

卷积满足交换律,所以可以进行快速幂优化。但是我们要求的答案是\(\sum\limits_{i=1}^{n} dp[i][m]\),如果用快速幂的话没法计算答案啊

所以我们再定义一个多项式\(sum[i]=\sum\limits_{j=1}^i dp[i]\),也就是前缀和

首先要意识到,根据上面的定义\(dp[i]=f^i\),那么显然\(dp[a+b]=dp[a]*dp[b]\)

然后看一下\(sum\)是怎么快速幂递推的

假设现在在计算\(sum[x]\)\(dp[x]\)

\(x\)为偶,

\(dp[x]=dp[\frac{x}{2}]*dp[\frac{x}{2}]\)

\(sum[x]=sum[\frac{x}{2}]+\sum\limits_{i=x/2+1}^x dp[i]\)
\(=sum[\frac{x}{2}]+dp[\frac{x}{2}]\sum\limits_{i=1}^{x/2} dp[i]\)
\(=sum[\frac{x}{2}]+dp[\frac{x}{2}]*sum[\frac{x}{2}]\)

\(x\)为奇

那就先把\(dp[x-1]\)\(sum[x-1]\)用上面的方法算出来,然后\(dp[x]=dp[x-1]*f\)\(sum[x]=sum[x-1]+dp[x]\)

最后的答案就是\(sum[n][m]\)

时间复杂度\(O(m\log m\log n)\)

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const ll mod = 998244353, G = 3, invg = 332748118;
int N, m, lim, l, rev[100005], tot;
ll dp[100005], sum[100005], F[100005], tmp[100005], invn, O, S, U, p;
ll aa[100005], bb[100005];

inline ll fpow(ll x, ll t) {
	ll ret = 1;
	for (; t; t >>= 1, x = x * x % mod) if (t & 1) ret = ret * x % mod;
	return ret;
}

void NTT(ll *c, int tp) {
	for (int i = 0; i < lim; i++) {
		if (i < rev[i]) swap(c[i], c[rev[i]]);
	}
	for (int mid = 1; mid < lim; mid <<= 1) {
		int r = mid<<1; 
		ll wn = fpow(~tp?G:invg, (mod-1)/r);
		for (int j = 0; j < lim; j += r) {
			ll w = 1;
			for (int k = 0; k < mid; k++, w = w * wn % mod) {
				ll x = c[j+k], y = w * c[j+k+mid] % mod;
				c[j+k] = (x + y) % mod;
				c[j+k+mid] = (x - y + mod) % mod;
			}
		}
	}
	if (tp == -1) {
		for (int i = 0; i < lim; i++) {
			c[i] = c[i] * invn % mod;
		}
	}
}

inline ll calc(ll x) {
	return (x * x % p * O % p + x * S % p + U) % p;
}

void Mul(ll *a, ll *b, ll *c) {
	for (int i = 0; i < lim; i++) aa[i] = a[i], bb[i] = b[i];
	NTT(aa, 1); NTT(bb, 1);
	for (int i = 0; i < lim; i++) aa[i] = aa[i] * bb[i] % mod;
	NTT(aa, -1);
	for (int i = 0; i <= m; i++) c[i] = aa[i] % p; 	
}

void solve(int n) {
	if (n == 1) {
		for (int i = 0; i <= m; i++) dp[i] = sum[i] = F[i];
		return;
	}
	solve(n >> 1);
	tot++;
	Mul(dp, sum, tmp);
	Mul(dp, dp, dp);
	for (int i = 0; i <= m; i++) {
		sum[i] = (sum[i] + tmp[i]) % p;
	}
	if (n & 1) {
		Mul(dp, F, tmp);
		for (int i = 0; i <= m; i++) {
			dp[i] = tmp[i];
			sum[i] = (sum[i] + tmp[i]) % p;
		}
	}
}

int main() {
	scanf("%d %lld %d %lld %lld %lld", &m, &p, &N, &O, &S, &U);
	lim = 1;
	while (lim <= m + m) {
		lim <<= 1; 
		l++;
	}
	invn = fpow(lim, mod-2);
	for (int i = 0; i < lim; i++) {
		rev[i] = (rev[i>>1]>>1)|((i&1)<<(l-1));
	}
	for (int i = 1; i <= m; i++) F[i] = calc(i);
	solve(N);
	printf("%lld\n", sum[m]);
	return 0;
} 
posted @ 2020-06-16 22:59  AK_DREAM  阅读(136)  评论(0编辑  收藏  举报