P6667 [清华集训2016] 如何优雅地求和

题意描述

给你一个 \(m\) 次多项式 \(f(x)\), 让你求:

\(\displaystyle\sum_{k=0}^{n} f(k){n\choose k} x^k(1-x)^{n-k}\)

其中 \(f(x)\) 由点值形式给出。

数据范围:\(1\leq n\leq 10^9,1\leq m\leq 2\times 10^4\)

solution

下降幂多项式+NTT。

先推一下式子。

考虑把 \(f(x)\) 转化为下降幂多项式,即设 \(f(x)=\displaystyle\sum_{i=0}^{m}a_ix^i=\sum_{i=0}^{m}b_ix^{\underline i}\),则有:

\(\displaystyle\sum_{k=0}^{n} \sum_{i=0}^{m}a_ik^i{n\choose k}x^k(1-x)^{n-k}\)

\(=\displaystyle\sum_{k=0}^{n}\sum_{i=0}^{m}b_ik^{\underline i}{n\choose k}x^k(1-x)^{n-k}\)

交换一下求和顺序可得:

\(=\displaystyle\sum_{i=0}^{m}b_i\sum_{k=0}^{n}k^{\underline i}{n\choose k}x^k(1-x)^{n-k}\)

\(=\displaystyle\sum_{i=0}^{m}b_i\sum_{k=0}^{n}{k!\over (k-i)!} {n!\over k!(n-k)!} x^k(1-x)^{n-k}\)

\(=\displaystyle\sum_{i=0}^{m}b_i\sum_{k=0}^{n} {n!(n-i)!\over (k-i)!(n-k)!(n-i)!}x^k(1-x)^{n-k}\)

\(=\displaystyle\sum_{i=0}^{m}b_i\sum_{k=0}^{n} n^{\underline i}{n-i\choose k-i}x^k(1-x)^{n-k}\)

\(=\displaystyle\sum_{i=0}^{m}b_in^{\underline i}\sum_{k=0}^{n}{n-i\choose k-i}x^k(1-x)^{n-k}\)

\(=\displaystyle\sum_{i=0}^{m}b_in^{\underline i}x^{i}\sum_{k=0}^{n}{n-i\choose k-i}x^{k-i}(1-x)^{n-k}\)

后面的用二项式定理可得:

\(=\displaystyle\sum_{i=0}^{m}b_in^{\underline i}x^i(x+1-x)^{n-i}\)

\(=\displaystyle\sum_{i=0}^{m}b_in^{\underline i}x^i\)

求出 \(b_i\) 之后,就可以 \(O(m)\) 地计算答案了。

考虑怎么求 \(b_i\) ,可以根据点值的 \(\text{EGF}\) 来求。

\(\text{EGF}(f(x)) = \displaystyle\sum_{i=0}^{m}{f_ix^i\over i!}\)

\(\text{EGF}(f(x)) = \displaystyle\sum_{i=0}^{m}{x^i\over i!}\sum_{j=0}^{m}b_ji^{\underline j}\)

\(\text{EGF}(f(x)) = \displaystyle\sum_{i=0}^{m}{x^i\over i!}\sum_{j=0}^{i}b_j{i!\over (i-j)!}\)(\(j>i\) 时 \(i^{\underline j}=0\))

\(\text{EGF}(f(x)) = \displaystyle\sum_{i=0}^{m}x^i\sum_{j=0}^{i} b_j{1\over (i-j)!}\)

交换一下求和顺序则有:

\(\displaystyle \text{EGF}(f(x)) = \sum_{i=0}^{m}b_i\sum_{j=i}^{m}x^j{1\over (j-i)!}\)

\(\text{EGF}(f(x)) = \displaystyle\sum_{i=0}^{m}b_ix^{i}\sum_{j=i}^{m}x^{j-i}{1\over (j-i)!}\)

\(\text{EGF}(f(x))=\displaystyle\sum_{i=0}^{m}b_ix^i\sum_{j=0}^{m-i}{1\over j!}x^j\)

\(\text{EGF}(f(x))\equiv\displaystyle\sum_{i=0}^{m}b_ix^i\cdot e^x \pmod{x^{m+1}}\)

设 \(G(x) = \displaystyle\sum_{i=0}^{\infty} b_ix^i\) 为 \(b_i\) 的 \(\text{OGF}\)(\(i>m\) 时 \(b_i=0\)),则有:

\(\text{EGF}(f(x))\equiv G(x)\cdot e^x \pmod{x^{m+1}}\)

\(G(x) \equiv \text{EGF}(f(x))\cdot e^{-x} \pmod{x^{m+1}}\)

把 \(f(x)\) 的 \(\text{EGF}\) 与 \(e^{-x}\) 卷积(做多项式乘法)就可以得到 \(G(x)\)。

再提取系数就可以得到 \(b_i\),即:\(b_i = [x^i]G(x)\)。

复杂度:\(O(m\log m)\)。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<cmath>
#include<queue>
using namespace std;
#define int long long
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; the NTT below uses 3 as its root.
const int p = 998244353;
// Capacity of every global array. The largest index ever touched is the NTT
// length: the power of two chosen in main to cover the convolution of two
// degree-m polynomials. With m <= 2*10^4 that length is at most 2^17, so 2^18
// leaves ample headroom while avoiding the ~240 MB of static storage that a
// 5e6 bound would reserve for these six arrays.
const int N = (1<<18)+10;
// n, m, x : input parameters;  ans : result accumulator
// a, b    : NTT work buffers (EGF of the point values, and e^{-x})
// f       : point values f(0..m)
// jz, inv : factorials and inverse factorials up to m (filled by YYCH)
// rev     : bit-reversal permutation used by NTT
int n,m,x,ans,a[N],b[N],f[N],jz[N],inv[N],rev[N];
// Fast integer reader from stdin.
// Skips every non-digit character before the number; a '-' seen anywhere in
// that skipped prefix makes the result negative. Returns 0 if end of input is
// reached before any digit.
// `ch` must be an integer type (not char): a char cannot represent EOF
// distinctly, which made the skip loop spin forever on truncated input.
inline int read()
{
	int s = 0,w = 1;
	int ch = getchar();
	while((ch < '0' || ch > '9') && ch != EOF)
	{
		if(ch == '-') w = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9')
	{
		s = s * 10 + ch - '0';
		ch = getchar();
	}
	return s * w;
}
// Modular fast power: returns a^b mod p for b >= 0 by square-and-multiply
// over the bits of b.
// NOTE(review): assumes a is already reduced mod p (or small enough that a*a
// fits in the 64-bit `int` this file #defines) — confirm at call sites.
int ksm(int a,int b)
{
	int result = 1;
	int base = a;
	int e = b;
	while(e)
	{
		if(e & 1) result = result * base % p;
		base = base * base % p;
		e >>= 1;
	}
	return result;
}
// Precomputes factorials jz[0..m] and inverse factorials inv[0..m] modulo p.
// Only 1/m! is computed with a modular power (Fermat inverse via ksm(., p-2));
// the remaining inverses follow downward from 1/(j-1)! = j * (1/j!).
void YYCH()
{
	jz[0] = inv[0] = 1;
	int i = 1;
	while(i <= m)
	{
		jz[i] = jz[i-1] * i % p;
		++i;
	}
	inv[m] = ksm(jz[m],p-2);
	for(int j = m; j > 1; --j) inv[j-1] = inv[j] * j % p;
}
// Iterative in-place number-theoretic transform modulo p, using 3 as the root.
//   a   : lim coefficients, transformed in place
//   lim : transform length (a power of two); the global rev[] must already
//         hold the bit-reversal permutation for this length
//   opt : -1 selects the inverse transform (inverse roots, then divide by lim);
//         any other value performs the forward transform
void NTT(int *a,int lim,int opt)
{
	// Reorder into bit-reversed positions so the butterflies can run bottom-up.
	for(int idx = 0; idx < lim; idx++)
		if(idx < rev[idx]) swap(a[idx],a[rev[idx]]);
	for(int half = 1; half < lim; half <<= 1)
	{
		const int len = half << 1;
		// Primitive len-th root of unity mod p (inverted for the inverse transform).
		int step = ksm(3,(p-1)/len);
		if(opt == -1) step = ksm(step,p-2);
		for(int base = 0; base < lim; base += len)
		{
			int *lo = a + base, *hi = a + base + half;
			int w = 1;
			for(int k = 0; k < half; k++, w = w * step % p)
			{
				const int even = lo[k];
				const int odd = w * hi[k] % p;
				lo[k] = (even + odd) % p;
				hi[k] = (even - odd + p) % p;
			}
		}
	}
	if(opt != -1) return;
	// Inverse transform: scale every coefficient by 1/lim.
	const int scale = ksm(lim,p-2);
	for(int idx = 0; idx < lim; idx++) a[idx] = a[idx] * scale % p;
}
// Reads n, m, x and the point values f(0..m), then prints
//   sum_{k=0}^{n} f(k) C(n,k) x^k (1-x)^{n-k}   (mod p)
// via the identity  answer = sum_{i=0}^{m} b_i * n^{\underline i} * x^i,
// where b_i are the falling-factorial coefficients of f, obtained as
// [x^i] ( sum_i f(i) x^i / i! ) * e^{-x}.
signed main()
{
	n = read(); m = read(); x = read(); YYCH();
	// Normalize x into [0, p) so the running power below stays in range.
	x = (x % p + p) % p;
	for(int i = 0; i <= m; i++) f[i] = (read() % p + p) % p;
	// a = EGF of the point values (f(i)/i!), b = e^{-x} truncated to degree m.
	for(int i = 0; i <= m; i++)
	{
		a[i] = f[i] * inv[i] % p;
		b[i] = (i&1) ? (p-inv[i]) : inv[i];
	}
	// The product has degree 2m, so a cyclic transform of length > 2m is
	// wrap-around free. lim >= 2 also keeps tim >= 1, so the shift by
	// (tim-1) below is never negative.
	int lim = 1, tim = 0;
	while(lim <= (m<<1) || lim < 2) lim <<= 1, tim++;
	for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
	NTT(a,lim,1); NTT(b,lim,1);
	for(int i = 0; i < lim; i++) a[i] = a[i] * b[i] % p;
	NTT(a,lim,-1);
	// a[i] now holds b_i. Accumulate b_i * n^{\underline i} * x^i, keeping both
	// the falling factorial and the power of x as running products.
	int fall = 1, xpow = 1;
	for(int i = 0; i <= m; i++)
	{
		ans = (ans + a[i] * fall % p * xpow % p) % p;
		fall = fall * (n-i) % p; // becomes 0 at i == n, so later terms vanish
		xpow = xpow * x % p;
	}
	printf("%lld\n",ans);
	return 0;
}
posted @ 2021-04-04 07:06  genshy  阅读(164)  评论(0编辑  收藏  举报