P5643-[PKUWC2018]随机游走【min-max容斥,dp】

正题

题目链接:https://www.luogu.com.cn/problem/P5643


题目大意

给出\(n\)个点的一棵树,一个人从点\(x\)开始随机游走,然后\(Q\)次询问给出一个点集\(S\),求期望多少步这个人会经过这个点集中的所有点。

\(1\leq n\leq 18,1\leq Q\leq 5000\)


解题思路

整个点集都走完比较难统计,我们可以考虑用\(min-max\)容斥转为求走到其中一个点的期望步数。

我们设我们目前枚举的集合是\(S\),那么首先有\(f_x=0(x\in S)\)

然后有转移方程:

\[f_{x}=\frac{1}{deg_x}(f_{fa_x}+\sum_{x\rightarrow y}f_{y}) \]

惯例的我们设\(f_x=A_xf_{fa_x}+B_x\)

\[f_{x}=\frac{1}{deg_x}\left(f_{fa_x}+\sum_{x\rightarrow y}(A_yf_x+B_y)\right) \]

\[f_{x}=\frac{1}{deg_x}f_{fa_x}+\frac{sumAf_x}{deg}+\frac{1}{deg_x}sumB+1 \]

\[\frac{deg_x-sumA}{deg_x}f_x=\frac{1}{deg_x}f_{fa_x}+\frac{1}{deg_x}sumB+1 \]

\[f_x=\frac{1}{deg_x-sumA}f_{fa_x}+\frac{deg_x+sumB}{deg_x-sum_A} \]

这样我们就可以推出\(A\)\(B\),而\(B_x\)就是节点\(x\)\(f\)值,记\(g_S=B_x\)

那么根据min-max容斥如果我们询问集合\(S\)时答案就是

\[\sum_{T\sube S}(-1)^{|T|+1}g_T \]

用个高维前缀和就可以预处理所有集合的答案了。

时间复杂度:\(O(n2^n\log P)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=18,P=998244353;
struct node{
	ll to,next;
}a[N<<1];
ll n,Q,rt,tot,ls[N],deg[N];
ll A[N],B[N],c[1<<N],f[1<<N];
ll power(ll x,ll b){
	ll ans=1;
	while(b){
		if(b&1)ans=ans*x%P;
		x=x*x%P;b>>=1;
	}
	return ans;
}
void addl(ll x,ll y){
	a[++tot].to=y;
	a[tot].next=ls[x];
	ls[x]=tot;deg[y]++;
	return;
}
void dfs(ll x,ll fa,ll S){
	ll sumA=0,sumB=0;
	if((S>>x)&1){A[x]=B[x]=0;return;}
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa)continue;
		dfs(y,x,S);
		sumA=(sumA+A[y])%P;
		sumB=(sumB+B[y])%P;
	}
	ll inv=power((deg[x]-sumA+P)%P,P-2);
	A[x]=inv;B[x]=(deg[x]+sumB)*inv%P;
	return;
}
signed main()
{
	scanf("%lld%lld%lld",&n,&Q,&rt);rt--;
	for(ll i=1,x,y;i<n;i++){
		scanf("%lld%lld",&x,&y);x--;y--;
		addl(x,y);addl(y,x);
	}
	ll MS=(1<<n);
	for(ll s=1;s<MS;s++){
		memset(A,0,sizeof(A));
		memset(B,0,sizeof(B));
		dfs(rt,n,s);f[s]=B[rt];
	}
	for(ll s=1;s<MS;s++){
		c[s]=c[s-(s&-s)]+1;
		f[s]=((c[s]&1)?f[s]:(P-f[s]));
	}
	for(ll i=0;i<n;i++)
		for(ll s=MS-1;s>=0;s--)
			if((s>>i)&1)(f[s]+=f[s-(1<<i)])%=P;
	while(Q--){
		ll k,s=0;scanf("%lld",&k);
		for(ll i=0,x;i<k;i++)
			scanf("%lld",&x),s|=(1<<x-1);
		printf("%lld\n",f[s]);
	}
	return 0;
}
posted @ 2021-12-11 10:10  QuantAsk  阅读(29)  评论(0编辑  收藏  举报