【LOJ2542】【PKUWC 2018】随机游走 min-max容斥 树上高斯消元

题目描述

  有一棵 \(n\) 个点的树。你从点 \(x\) 出发,每次等概率随机选择一条与所在点相邻的边走过去。

  有 \(q\) 次询问,每次询问给定一个集合 \(S\),求如果从 \(x\) 出发一直随机游走,直到点集 \(S\) 中所有点都至少经过一次的话,期望游走几步。

  特别地,点 \(x\)(即起点)视为一开始就被经过了一次。

  答案对 \(998244353\) 取模。

题解

  这道题要求点集 \(S\) 中所有点都至少经过一次的期望步数,直接做不好做,要先用一个 min-max 容斥转换成走到点集 \(S\) 中第一个点的期望步数:

\[\max(S)=\sum_{T\subseteq S,T\neq \varnothing}{(-1)}^{|T|+1}\min(T) \]

  然后就可以列方程高斯消元了。

  \(f_i\) 表示从 \(i\)走到最近的点所需要的最小步数。

\[\begin{align} f_i&=1+\frac{1}{d_i}f_{fa}+\frac{1}{d_i}\sum_v f_v \end{align} \]

  直接高斯消元是 \(O(n^3)\) 的,但是我们可以用一些技巧把这个过程加速到 \(O(n\log p)\)\(\log p\) 来自求逆元)。

  设 \(f_i=a_if_{fa}+b_i\)。特别的,如果 \(i\in S\),那么\(a_i=0,b_i=0\)

\[\begin{align} f_i&=1+\frac{1}{d_i}f_{fa}+\frac{1}{d_i}\sum_{v}(a_vf_i+b_v)\\ &=1+\frac{1}{d_i}f_{fa}+\frac{1}{d_i}(\sum_{v}a_vf_i+\sum_{v}b_v)\\ d_if_i&=d_i+f_{fa}+\sum_{v}a_vf_i+\sum_{v}b_v\\ (d_i-\sum_{v}a_v)f_i&=d_i+f_{fa}+\sum_{v}b_v\\ f_i&=\frac{1}{d_i-\sum_{v}a_v}f_{fa}+\frac{\sum_{v}b_v+d_i}{d_i-\sum_{v}a_v}\\ \end{align} \]

  这样就可以从下往上递推得到\(a_i,b_i\)

  那么答案就是 \(b_x\)

  然后就可以轻松算出询问每一个集合的答案了。

  时间复杂度:\(O(n2^n\log p+qn)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
	if(a>b)
		swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
int rd()
{
	int s=0,c,b=0;
	while(((c=getchar())<'0'||c>'9')&&c!='-');
	if(c=='-')
	{
		c=getchar();
		b=1;
	}
	do
	{
		s=s*10+c-'0';
	}
	while((c=getchar())>='0'&&c<='9');
	return b?-s:s;
}
void put(int x)
{
	if(!x)
	{
		putchar('0');
		return;
	}
	static int c[20];
	int t=0;
	while(x)
	{
		c[++t]=x%10;
		x/=10;
	}
	while(t)
		putchar(c[t--]+'0');
}
int upmin(int &a,int b)
{
	if(b<a)
	{
		a=b;
		return 1;
	}
	return 0;
}
int upmax(int &a,int b)
{
	if(b>a)
	{
		a=b;
		return 1;
	}
	return 0;
}
const ll p=998244353;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
ll f[100];
ll g[100];
vector<int> a[100];
int d[100];
int b[100];
void dfs(int x,int fa)
{
	if(b[x])
	{
		f[x]=g[x]=0;
		return;
	}
	f[x]=0;
	g[x]=d[x];
	ll k=d[x];
	for(auto v:a[x])
		if(v!=fa)
		{
			dfs(v,x);
			k=(k-f[v])%p;
			g[x]=(g[x]+g[v])%p;
		}
	k=fp(k,p-2);
	f[x]=k;
	g[x]=g[x]*k%p;
}
int n,q,rt;
ll s[1<<20];
int main()
{
	open("loj2542");
	scanf("%d%d%d",&n,&q,&rt);
	int x,y;
	for(int i=1;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		a[x].push_back(y);
		a[y].push_back(x);
		d[x]++;
		d[y]++;
	}
	for(int i=1;i<1<<n;i++)
	{
		int num=0;
		for(int j=1;j<=n;j++)
		{
			b[j]=((i>>(j-1))&1);
			num+=b[j];
		}
		dfs(rt,0);
		s[i]=g[rt];
		if(!(num&1))
			s[i]=-s[i];
	}
	for(int i=1;i<=n;i++)
		for(int j=0;j<1<<n;j++)
			if((j>>(i-1))&1)
				s[j]=(s[j]+s[j^(1<<(i-1))])%p;
	int k;
	for(int i=1;i<=q;i++)
	{
		scanf("%d",&k);
		x=0;
		for(int i=1;i<=k;i++)
		{
			scanf("%d",&y);
			x|=1<<(y-1);
		}
		printf("%lld\n",(s[x]+p)%p);
	}
	return 0;
}
posted @ 2018-06-01 15:46  ywwyww  阅读(981)  评论(0编辑  收藏  举报