[ZJOI2022] 深搜 题解

题目描述

九条可怜是一个喜欢算法的女孩子,在众多算法中她尤其喜欢深度优先搜索(DFS)。

有一天,可怜得到了一棵有根树,树根为 \(\mathit{root}\),树上每个节点 \(x\) 有一个权值 \(a_x\)

在一棵树上从 \(x\) 出发,寻找 \(y\) 节点,如果使用深度优先搜索,则可描述为以下演算过程:

  1. 将递归栈设置为空。
  2. 首先将节点 \(x\) 放入递归栈中。
  3. 从递归栈中取出栈顶节点,如果该节点为 \(y\),则结束演算过程;否则,如果存在未访问的直接子节点,则以均等概率随机选择一个子节点加入递归栈中。
  4. 重复步骤 3,直到不存在未访问的直接子节点。
  5. 将上一级节点加入递归栈中,重复步骤 3。
  6. 重复步骤 5,直至当前一级节点为 \(x\),演算过程结束。

我们定义 \(f(x, y)\) 合法当且仅当 \(y\)\(x\) 的子树中。它的值为从 \(x\) 出发,对 \(x\) 的子树进行深度优先搜索寻找 \(y\) 期间访问过的所有节点(包括 \(x\)\(y\))权值最小值的期望。

九条可怜想知道对于所有合法的点对 \((x, y)\)\(\sum f(x, y)\) 的值。你只需要输出答案对 \(998244353\) 取模的结果。具体地,如果答案的最简分数表示为 \(\frac{a}{b}\),输出 \(a \times b^{-1} \bmod 998244353\)

提示

对于所有测试点,满足 \(1 \le T \le 100\)\(\sum n \le 8 \times {10}^5\)\(1 \le n \le 4 \times {10}^5\)\(1 \le \mathit{root}, u, v \le n\)\(1 \le a_i \le {10}^9\)

每个测试点的具体限制见下表:

测试点编号 \(\sum n \le\) \(n \le\) 特殊限制
\(1\) \(50\) \(10\)
\(2 \sim 4\) \(40000\) \(5000\)
\(5 \sim 10\) \(4 \times {10}^5\) \({10}^5\)
\(11\) \(8 \times {10}^5\) \(4 \times {10}^5\) 树的生成方式随机
\(12\) \(8 \times {10}^5\) \(4 \times {10}^5\) 树是一条链
\(13\) \(8 \times {10}^5\) \(4 \times {10}^5\) 根的度数为 \(n - 1\)
\(14 \sim 20\) \(8 \times {10}^5\) \(4 \times {10}^5\)

对于测试点 \(11\),树的生成方式为:以 \(1\) 为根,对于节点 \(i \in [2, n]\),从 \([1, i - 1]\) 中等概率随机选择一个点作为父亲。之后将编号随机重排。

题解

默认根的父亲为 \(0\)

算法1

对每个权值 \(val\) 计算 \(P_{val}\) 表示权值最小值 \(>=val\) 对答案的贡献,最后对 \(P\) 差分,答案即为 \(\sum a_xP_{a_x}\)

首先离散化权值,从小到大枚举权值 \(min\_val\) 并计算 \(P_{min\_val}\)

以下是计算 \(P_{min\_val}\)的过程。

对于点 \(i\),若其权值 \(a_i>=min\_val\),则其为黑点, \(color_i=1\),若其权值 \(a_i<min\_val\),则其为白点,\(color_i=0\)

\(f_x\) 表示从点 \(x\) 开始深搜的答案,则 \(P_{min\_val}=\sum f_x\),因此我们再设 \(g_x=\sum\limits_{y\in subtree(x)}f_y\)\(P_{min\_val}=g_{root}\)

\(g\) 的转移是容易的,若我们能计算出 \(f\),则 \(g_x=f_x+\sum\limits_{y \in son(x)}g_y\)

\(f\) 如何转移?

考虑 \(dfs\) 的过程,不难发现若路径合法,则路径上的点全是黑点,转移分类讨论即可。

具体地说,我们记 \(tag_x=[以x为根的子树全为黑点],cnt_x=\sum\limits_{y\in son(x)}(1-tag_y)\)\(cnt_x\) 的定义是 \(x\) 的儿子中,有几个儿子,以它们为根的子树内的点非全黑。

每次随机一个子树走入,若我们在全黑子树内停止,相当于这个全黑子树被走入的时间要排在所有非全黑子树前,由于概率均等,所以概率为 \(\frac{cnt_x!}{(1+cnt_x)|}=\frac{1}{1+cnt_x}\)

同理,若在非全黑子树内停止,相当于这个非全黑子树是最早被走入的非全黑子树,概率为 \(\frac{(cnt_x-1)!}{cnt_x!}=\frac{1}{cnt_x}\)

那么 \(f\) 的转移就有了,对于\(f_x\),先判断其是不是黑点,若是白点,值为 \(0\) ,否则先加上 \(f(x,x)=1\) ,然后转移。

写成式子就是 \(f_x=color_x*(1+\sum\limits_{y\in son(x)}\frac{1}{cnt_x+tag_y}f_y)\)

然后每次 \(min\_val\) 改变的时候,我们修改 \(color\) 后重新计算 \(tag,cnt,f,g\) 即可做到 \(O(n^2)\),可以获得 \(20\) 分。

代码

#include<bits/stdc++.h>
#define For(i,l,r) for(int i=(l);i<=(r);++i)
typedef long long ll;
const int mod=998244353;
const int N=400010;
using namespace std;
int n,root,tot;
int ver[N<<1],nxt[N<<1],head[N],b[N],cnt[N];
vector<int> id[N];
bool color[N],tag[N];
ll ans;
ll a[N],P[N],f[N],g[N],inv[N];
template<typename T1,typename T2>
void Add(T1 &a,T2 b){a+=b;if(a>=mod)a-=mod;return;}
template<typename T1,typename T2>
void Sub(T1 &a,T2 b){a-=b;if(a<0)a+=mod;return;}
void add(int x,int y)
{
	ver[++tot]=y;
	nxt[tot]=head[x];
	head[x]=tot;
}
void calc_tag(int x,int fa)
{
	tag[x]=color[x];
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if(y!=fa)
		{
			calc_tag(y,x);
			tag[x]&=tag[y];
		}
	}
}
void calc_cnt(int x,int fa)
{
	cnt[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if(y!=fa)
		{
			calc_cnt(y,x);
			cnt[x]+=(1-tag[y]);
		}
	}
}
void calc(int x,int fa)
{
	f[x]=1;
	g[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if(y!=fa)
		{
			calc(y,x);
			ll tmp=inv[cnt[x]+tag[y]];
			(tmp*=f[y])%=mod;
			Add(f[x],tmp);
			Add(g[x],g[y]);
		}
	}
	(f[x]*=color[x])%=mod;
	Add(g[x],f[x]);
}
void solve()
{
	tot=0;
	scanf("%d%d",&n,&root);
	{
		For(i,1,n)
			head[i]=0;
		For(i,1,n)
			id[i].clear();
	};
	For(i,1,n)
	{
		scanf("%lld",&a[i]);
		b[i]=a[i];
	}
	{
		sort(a+1,a+n+1);
		For(i,1,n)
			b[i]=(lower_bound(a+1,a+n+1,b[i])-a);
		For(i,1,n)
			id[b[i]].push_back(i);
	};
	For(i,1,(n-1))
	{
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	{
		For(i,1,n)
			color[i]=1;
	};
	{
		For(min_val,1,n)
		{
			if(id[min_val].empty())
				continue;
			calc_tag(root,-1);
			calc_cnt(root,-1);
			calc(root,-1);
			P[min_val]=g[root];
			for(auto j:id[min_val])
				color[j]=0;
		}
	};
	{
		For(i,1,(n-1))
			Sub(P[i],P[i+1]);
	};
	{
		ans=0;
		For(i,1,n)
		{
			ll E_i=P[i];
			(E_i*=a[i])%=mod;
			Add(ans,E_i);
		}
		printf("%lld\n",ans);
	};
	return;
}
int main()
{
	{
		inv[1]=1;
		For(i,2,(N-1))
		{
			inv[i]=mod;
			Sub(inv[i],(mod/i));
			(inv[i]*=inv[mod%i])%=mod;
		}
	};
	int T;
	scanf("%d",&T);
	while(T--)
		solve();
	return 0;
}

算法2

我们注意到每次修改一个点的 \(color\) ,这个量的总变化次数是点数级别的,暴力去做即可。

可能影响的信息还有这个点的 \(tag\) 值,若这个点有父亲,可能会影响其父亲的 \(cnt\) 值,不可能每次重新计算。

但是我们注意到\(tag_x=1\) 的时间是一段前缀,具体地,\(tag_x\) 在以 \(x\) 为根的子树内的最小值对应的点变为白色后从 \(1\) 变成 \(0\) ,因此在最开始预处理子树最小值后暴力修改 \(tag_x\) 及其父亲的 \(cnt\) 值即可。

但是可能会影响的 \(f,g\) 值,都是这个点到根的一条链。

使用动态 \(dp\) 的套路,轻重链剖分,对每个点维护轻儿子相关信息和,对于每条重链的链顶额外维护子树信息和即可,具体地说,下文的 \(f\_light,g\_light\) 表示轻儿子相关信息和,\(f,g\) 定义不变,就是子树信息和,\(sum\_f\) 是辅助转移的数组。

\(sum\_f[x][0]=\sum\limits_{y是x轻儿子} f[y]*[tag[y]==0],sum\_f[x][1]=\sum\limits_{y是x轻儿子} f[y]*[tag[y]==1]\)

以下 \(y\) 表示 \(x\) 的重儿子(若其存在。)。

那么 \(f\) 的转移就可以写成 \(f_x=color[x]*(1+\frac{1}{cnt[x]}sum\_f[x][0]+\frac{1}{cnt[x]+1}sum\_f[x][1]+\frac{1}{cnt[x]+tag[y]}f[y])\)

\(f\_light[x]=color[x]*(1+\frac{1}{cnt[x]}sum\_f[x][0]+\frac{1}{cnt[x]+1}sum\_f[x][1])\)\(g\_light[x]=f\_light[x]+\sum\limits_{z是x轻儿子}g[z]\)

\(f,g\) 的转移可以写成 \(f[x]=f\_light[x]+\frac{1}{cnt[x]+tag[y]}f[y],g[x]=g\_light[x]+\frac{1}{cnt[x]+tag[y]}f[y]\)

转移写成矩阵乘法形式后用线段树维护即可。

每次需要将点由黑变白,以及修改一些 \(tag_x,cnt_{fa_x}\),以及若 \(x\) 不是 \(fa_x\) 的重儿子,记得修改 \(sum\_f[fa_x][0],sum\_f[fa_x][1]\)

时间复杂度 \(O(n\ log^2\ n)\),至少可以获得 \(55\) 分。

代码

#include<bits/stdc++.h>
#define For(i,l,r) for(int i=(l);i<=(r);++i)
typedef long long ll;
const int mod=998244353;
const int N=400010;
using namespace std;
int n,root,tot;
int a[N],ver[N<<1],nxt[N<<1],head[N],b[N],cnt[N];
vector<int> id[N],tim_e[N];
bool color[N],tag[N];
ll ans;
ll P[N],f[N],g[N],f_light[N],g_light[N],inv[N],sum_f[N][2];
template<typename T1,typename T2>
void Add(T1 &a,T2 b){a+=b;if(a>=mod)a-=mod;return;}
template<typename T1,typename T2>
void Sub(T1 &a,T2 b){a-=b;if(a<0)a+=mod;return;}
void add(int x,int y)
{
	ver[++tot]=y;
	nxt[tot]=head[x];
	head[x]=tot;
}
int subtree_min[N];
void calc_subtree_min(int x,int fa)
{
	subtree_min[x]=b[x];
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if(y!=fa)
		{
			calc_subtree_min(y,x);
			subtree_min[x]=min(subtree_min[x],subtree_min[y]);
		}
	}
}
int dep[N],fa[N],siz_e[N],L_size[N],son[N];
void dfs1(int x,int Fa)
{
	fa[x]=Fa;
	dep[x]=(dep[Fa]+1);
	siz_e[x]=1;
	son[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if(y!=Fa)
		{
			dfs1(y,x);
			siz_e[x]+=siz_e[y];
			if((son[x]==0) || (siz_e[y]>siz_e[son[x]]))
				son[x]=y;
		}
	}
	L_size[x]=siz_e[x];
	if(son[x])
		L_size[x]-=siz_e[son[x]];
}
int cnt_;
int dfn[N],rnk[N],top[N],en_d[N];
void dfs2(int x,int Top)
{
	top[x]=Top;
	en_d[Top]=x;
	++cnt_;
	dfn[x]=cnt_;
	rnk[cnt_]=x;
	if(son[x])
		dfs2(son[x],Top);
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if((y!=fa[x]) && (y!=son[x]))
			dfs2(y,y);
	}
}
void get_f_light(int x)
{
	f_light[x]=1;
	ll tmp0=sum_f[x][0],tmp1=sum_f[x][1];
	(tmp0*=inv[cnt[x]])%=mod;
	(tmp1*=inv[cnt[x]+1])%=mod;
	Add(f_light[x],tmp0);
	Add(f_light[x],tmp1);
	(f_light[x]*=color[x])%=mod;
	return;
}
struct Matrix{ll a[3][3];}I;
Matrix operator * (Matrix mat1,Matrix mat2)
{
	Matrix res;
	For(i,0,2)
	{
		For(j,0,2)
			res.a[i][j]=0;
	}
	For(i,0,2)
	{
		For(j,0,2)
		{
			For(k,0,2)
			{
				ll tmp=mat1.a[i][k];
				(tmp*=mat2.a[k][j])%=mod;
				Add(res.a[i][j],tmp);
			}
		}
	}
	return res;
}
struct node{Matrix prod;}tree[N<<2];
Matrix get_mat(int x)
{
	Matrix res;
	For(i,0,2)
	{
		For(j,0,2)
			res.a[i][j]=0;
	}
	if(color[x])
	{
		res.a[0][0]=inv[cnt[x]+tag[son[x]]];
		res.a[1][0]=inv[cnt[x]+tag[son[x]]];
	}
	get_f_light(x);
	res.a[0][2]=f_light[x];
	res.a[1][2]=g_light[x];
	res.a[1][1]=1;
	res.a[2][2]=1;
	return res;
}
#define lson(x) (x<<1)
#define rson(x) (x<<1|1) 
void pushup(int x){tree[x].prod=(tree[lson(x)].prod*tree[rson(x)].prod);return;}
void build(int x,int l,int r)
{
	if(l==r)
	{
		tree[x].prod=get_mat(rnk[l]);
		return;
	}
	int mid=((l+r)>>1);
	build(lson(x),l,mid);
	build(rson(x),(mid+1),r);
	pushup(x);
}
Matrix query(int x,int l,int r,int L,int R)
{
	if(L<=l && r<=R)
		return tree[x].prod;
	Matrix res=I;
	int mid=((l+r)>>1);
	if(L<=mid)
		res=(res*query(lson(x),l,mid,L,R));
	if((mid+1)<=R)
		res=(res*query(rson(x),(mid+1),r,L,R));
	return res;
}
void modify(int x,int l,int r,int pos)
{
	if(l==r)
	{
		tree[x].prod=get_mat(rnk[pos]);
		return;
	}
	int mid=((l+r)>>1);
	if(pos<=mid)
		modify(lson(x),l,mid,pos);
	if((mid+1)<=pos)
		modify(rson(x),(mid+1),r,pos);
	pushup(x);
}
void calc_subtree(int x)
{
	Matrix res=query(1,1,n,dfn[x],dfn[en_d[x]]);
	f[x]=res.a[0][2];
	g[x]=res.a[1][2];
	return;
}
void modify(int x)
{
	while(x)
	{
		int top_x=top[x],fa_top_x=fa[top_x];
		if(fa_top_x)
		{
			Sub(g_light[fa_top_x],g[top_x]);
			Sub(g_light[fa_top_x],f_light[fa_top_x]);
			Sub(sum_f[fa_top_x][tag[top_x]],f[top_x]);
		}
		modify(1,1,n,dfn[x]);
		calc_subtree(top_x);
		if(fa_top_x)
		{
			Add(sum_f[fa_top_x][tag[top_x]],f[top_x]);
			get_f_light(fa_top_x);
			Add(g_light[fa_top_x],f_light[fa_top_x]);
			Add(g_light[fa_top_x],g[top_x]);
		}
		x=fa[top[x]];
	}
	return;
}
void change_color(int x)
{
	Sub(g_light[x],f_light[x]);
	color[x]=0;
	f_light[x]=0;
	modify(x);
	return;
}
void change(int x)
{
	tag[x]=0;
	int fa_x=fa[x];
	if(fa_x==0)
		return;
	if(son[fa_x]!=x)
	{
		calc_subtree(x);
		Sub(sum_f[fa_x][1],f[x]);
		Add(sum_f[fa_x][0],f[x]);
	}
	Sub(g_light[fa_x],f_light[fa_x]);
	++cnt[fa_x];
	get_f_light(fa_x);
	Add(g_light[fa_x],f_light[fa_x]);
	modify(fa_x);
	return;
}
void calc_init(int x)
{
	f[x]=0;
	g[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if(y!=fa[x])
		{
			calc_init(y);
			ll tmp=inv[cnt[x]+tag[y]];
			(tmp*=f[y])%=mod;
			Add(f[x],tmp);
			Add(g[x],g[y]);
			if(y!=son[x])
			{
				ll tmp0=f[y],tmp1=f[y];
				(tmp0*=(1-tag[y]))%=mod;
				(tmp1*=tag[y])%=mod;
				Add(sum_f[x][0],tmp0);
				Add(sum_f[x][1],tmp1);
			}
		}
	}
	Add(f[x],1);
	(f[x]*=color[x])%=mod;
	Add(g[x],f[x]);
}
void calc_light()
{
	For(x,1,n)
	{
		get_f_light(x);
		g_light[x]=g[x];
		if(son[x])
		{
			ll tmp=inv[cnt[x]+tag[son[x]]];
			(tmp*=f[son[x]])%=mod;
			Sub(g_light[x],tmp);
			Sub(g_light[x],g[son[x]]);
		}
	}
	return;
}
void calc()
{
	calc_init(root);
	calc_light();
	return;
}
void solve()
{
	scanf("%d%d",&n,&root);
	{
		tot=0;
		cnt_=0;
		For(i,1,n)
			head[i]=0;
		For(i,1,n)
			id[i].clear();
		For(i,1,n)
			tim_e[i].clear();
		For(i,1,n)
		{
			f[i]=0;
			g[i]=0;
			sum_f[i][0]=0;
			sum_f[i][1]=0;
			f_light[i]=0;
			g_light[i]=0;
		}
	};
	For(i,1,n)
	{
		scanf("%d",&a[i]);
		b[i]=a[i];
	}
	{
		sort(a+1,a+n+1);
		For(i,1,n)
			b[i]=(lower_bound(a+1,a+n+1,b[i])-a);
		For(i,1,n)
			id[b[i]].push_back(i);
	};
	For(i,1,(n-1))
	{
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	{
		dfs1(root,0);
		dfs2(root,root);
		calc_subtree_min(root,0);
		For(i,1,n)
			color[i]=1;
		For(i,1,n)
			tag[i]=1;
		For(i,1,n)
			cnt[i]=0;
		For(i,1,n)
			tim_e[subtree_min[i]].push_back(i);
	};
	{
		calc();
		build(1,1,n);
		For(min_val,1,n)
		{
			if(id[min_val].empty())
				continue;
			P[min_val]=g[root];
			for(auto j:id[min_val])
				change_color(j);
			for(auto j:tim_e[min_val])
				change(j);
		}
	};
	{
		For(i,1,(n-1))
			Sub(P[i],P[i+1]);
	};
	{
		ans=0;
		For(i,1,n)
		{
			ll E_i=P[i];
			(E_i*=a[i])%=mod;
			Add(ans,E_i);
		}
		printf("%lld\n",ans);
	};
	return;
}
int main()
{
	For(i,0,2)
	{
		For(j,0,2)
			I.a[i][j]=0;
	}
	For(i,0,2)
		I.a[i][i]=1;
	{
		inv[1]=1;
		For(i,2,(N-1))
		{
			inv[i]=mod;
			Sub(inv[i],(mod/i));
			(inv[i]*=inv[mod%i])%=mod;
		}
	};
	int T;
	scanf("%d",&T);
	while(T--)
		solve();
	return 0;
}

算法3

将算法 \(2\) 的线段树改为全局平衡二叉树,即可做到 \(O(n\ log\ n)\),需要略微精细实现,比如矩阵乘法有个 \(27\) 的常数,但是我们注意到矩阵中只有 \(4\) 个位置的值会变,其他都是固定的,那么我们只维护这 \(4\) 个位置即可,将矩阵乘法手动计算后,这 \(4\) 个位置的值都可以快速求出,这是一个很大的优化。

可以获得 \(100\) 分。

代码

#include<bits/stdc++.h>
#define For(i,l,r) for(int i=(l);i<=(r);++i)
typedef long long ll;
const int mod=998244353;
const int N=400010;
using namespace std;
int n,root,tot;
int a[N],ver[N<<1],nxt[N<<1],head[N],b[N],cnt[N],rt[N];
vector<int> id[N],tim_e[N];
bool color[N],tag[N];
ll ans;
ll P[N],f[N],g[N],f_light[N],g_light[N],inv[N],sum_f[N][2];
template<typename T1,typename T2>
void Add(T1 &a,T2 b){a+=b;if(a>=mod)a-=mod;return;}
template<typename T1,typename T2>
void Sub(T1 &a,T2 b){a-=b;if(a<0)a+=mod;return;}
void add(int x,int y)
{
	ver[++tot]=y;
	nxt[tot]=head[x];
	head[x]=tot;
}
int subtree_min[N];
void calc_subtree_min(int x,int fa)
{
	subtree_min[x]=b[x];
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if(y!=fa)
		{
			calc_subtree_min(y,x);
			subtree_min[x]=min(subtree_min[x],subtree_min[y]);
		}
	}
}
int dep[N],fa[N],siz_e[N],L_size[N],son[N];
void dfs1(int x,int Fa)
{
	fa[x]=Fa;
	dep[x]=(dep[Fa]+1);
	siz_e[x]=1;
	son[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if(y!=Fa)
		{
			dfs1(y,x);
			siz_e[x]+=siz_e[y];
			if((son[x]==0) || (siz_e[y]>siz_e[son[x]]))
				son[x]=y;
		}
	}
	L_size[x]=siz_e[x];
	if(son[x])
		L_size[x]-=siz_e[son[x]];
}
int cnt_;
int dfn[N],rnk[N],top[N],en_d[N];
void dfs2(int x,int Top)
{
	top[x]=Top;
	en_d[Top]=x;
	++cnt_;
	dfn[x]=cnt_;
	rnk[cnt_]=x;
	if(son[x])
		dfs2(son[x],Top);
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if((y!=fa[x]) && (y!=son[x]))
			dfs2(y,y);
	}
}
void get_f_light(int x)
{
	f_light[x]=1;
	ll tmp0=sum_f[x][0],tmp1=sum_f[x][1];
	(tmp0*=inv[cnt[x]])%=mod;
	(tmp1*=inv[cnt[x]+1])%=mod;
	Add(f_light[x],tmp0);
	Add(f_light[x],tmp1);
	(f_light[x]*=color[x])%=mod;
	return;
}
struct Matrix{ll a_0_0,a_0_2,a_1_0,a_1_2;}I;
Matrix operator * (Matrix mat1,Matrix mat2)
{
	Matrix res;
	res.a_0_0=((1ll*mat1.a_0_0*mat2.a_0_0)%mod);
	res.a_0_2=(((1ll*mat1.a_0_0*mat2.a_0_2)+mat1.a_0_2)%mod);
	res.a_1_0=(((1ll*mat1.a_1_0*mat2.a_0_0)+mat2.a_1_0)%mod);
	res.a_1_2=(((1ll*mat1.a_1_0*mat2.a_0_2)+mat1.a_1_2+mat2.a_1_2)%mod);
	return res;
}
Matrix get_mat(int x)
{
	Matrix res;
	if(color[x])
	{
		res.a_0_0=inv[cnt[x]+tag[son[x]]];
		res.a_1_0=inv[cnt[x]+tag[son[x]]];
	}
	else
	{
		res.a_0_0=0;
		res.a_1_0=0;
	}
	get_f_light(x);
	res.a_0_2=f_light[x];
	res.a_1_2=g_light[x];
	return res;
}
int seq[N],weight[N];
#define lson(x) (tree[x].lson)
#define rson(x) (tree[x].rson)
struct node{int lson,rson,anc;Matrix prod,mat;}tree[N];
void pushup(int x){tree[x].prod=(tree[lson(x)].prod*tree[x].mat*tree[rson(x)].prod);return;}
int build_heavy_chain(int L,int R)
{
	if(L>R)
		return 0;
	ll sum=0,sum_now=0;
	For(i,L,R)
		sum+=(1ll*L_size[i]);
	For(i,L,R)
	{
		sum_now+=(1ll*L_size[i]);
		if((sum_now*1ll*2)>sum)
		{
			int root=seq[i];
			tree[root].mat=get_mat(root);
			tree[root].lson=build_heavy_chain(L,(i-1));
			tree[lson(root)].anc=root;
			tree[root].rson=build_heavy_chain((i+1),R);
			tree[rson(root)].anc=root;
			pushup(root);
			return root;
		}
	}
	return 0;
}
void build(int Top)
{
	for(int x=Top;x;x=son[x])
	{
		for(int i=head[x];i;i=nxt[i])
		{
			int y=ver[i];
			if((y!=fa[x]) && (y!=son[x]))
				build(y);
		}
	}
	int num=0;
	for(int x=Top;x;x=son[x])
	{
		++num;
		seq[num]=x;
		weight[num]=L_size[x];
	}
	rt[Top]=build_heavy_chain(1,num);
	tree[rt[Top]].anc=0;
}
void update(int x)
{
	tree[x].mat=get_mat(x);
	for(;x;x=tree[x].anc)
		pushup(x);
	return;
}
void modify(int x)
{
	while(x)
	{	
		int top_x=top[x],fa_top_x=fa[top_x];
		if(fa_top_x)
		{
			Sub(g_light[fa_top_x],tree[rt[top_x]].prod.a_1_2);
			Sub(g_light[fa_top_x],f_light[fa_top_x]);
			Sub(sum_f[fa_top_x][tag[top_x]],tree[rt[top_x]].prod.a_0_2);
		}
		update(x);
		if(fa_top_x)
		{
			Add(sum_f[fa_top_x][tag[top_x]],tree[rt[top_x]].prod.a_0_2);
			get_f_light(fa_top_x);
			Add(g_light[fa_top_x],f_light[fa_top_x]);
			Add(g_light[fa_top_x],tree[rt[top_x]].prod.a_1_2);
		}
		x=fa[top[x]];
	}
	return;
}
void change_color(int x)
{
	Sub(g_light[x],f_light[x]);
	color[x]=0;
	f_light[x]=0;
	modify(x);
	return;
}
void change(int x)
{
	tag[x]=0;
	int fa_x=fa[x];
	if(fa_x==0)
		return;
	if(son[fa_x]!=x)
	{
		Sub(sum_f[fa_x][1],tree[rt[x]].prod.a_0_2);
		Add(sum_f[fa_x][0],tree[rt[x]].prod.a_0_2);
	}
	Sub(g_light[fa_x],f_light[fa_x]);
	++cnt[fa_x];
	get_f_light(fa_x);
	Add(g_light[fa_x],f_light[fa_x]);
	modify(fa_x);
	return;
}
void calc_init(int x)
{
	f[x]=0;
	g[x]=0;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=ver[i];
		if(y!=fa[x])
		{
			calc_init(y);
			ll tmp=inv[cnt[x]+tag[y]];
			(tmp*=f[y])%=mod;
			Add(f[x],tmp);
			Add(g[x],g[y]);
			if(y!=son[x])
			{
				ll tmp0=f[y],tmp1=f[y];
				(tmp0*=(1-tag[y]))%=mod;
				(tmp1*=tag[y])%=mod;
				Add(sum_f[x][0],tmp0);
				Add(sum_f[x][1],tmp1);
			}
		}
	}
	Add(f[x],1);
	(f[x]*=color[x])%=mod;
	Add(g[x],f[x]);
}
void calc_light()
{
	For(x,1,n)
	{
		get_f_light(x);
		g_light[x]=g[x];
		if(son[x])
		{
			ll tmp=inv[cnt[x]+tag[son[x]]];
			(tmp*=f[son[x]])%=mod;
			Sub(g_light[x],tmp);
			Sub(g_light[x],g[son[x]]);
		}
	}
	return;
}
void calc(int x)
{
	calc_init(x);
	calc_light();
	return;
}
void solve()
{
	int root;
	cin>>n>>root;
	{
		tot=0;
		cnt_=0;
		For(i,1,n)
			head[i]=0;
		For(i,1,n)
			id[i].clear();
		For(i,1,n)
			tim_e[i].clear();
	};
	For(i,1,n)
	{
		cin>>a[i];
		b[i]=a[i];
	}
	{
		sort(a+1,a+n+1);
		For(i,1,n)
			b[i]=(lower_bound(a+1,a+n+1,b[i])-a);
		For(i,1,n)
			id[b[i]].push_back(i);
	};
	For(i,1,(n-1))
	{
		int x,y;
		cin>>x>>y;
		add(x,y);
		add(y,x);
	}
	{
		dfs1(root,0);
		dfs2(root,root);
		calc_subtree_min(root,0);
		For(i,1,n)
			color[i]=1;
		For(i,1,n)
			tag[i]=1;
		For(i,1,n)
			cnt[i]=0;
		For(i,1,n)
			tim_e[subtree_min[i]].push_back(i);
	};
	{
		calc(root);
		build(root);
		For(min_val,1,n)
		{
			P[min_val]=tree[rt[root]].prod.a_1_2;
			for(auto j:id[min_val])
				change_color(j);
			for(auto j:tim_e[min_val])
				change(j);
		}
	};
	{
		For(i,1,(n-1))
			Sub(P[i],P[i+1]);
	};
	{
		ans=0;
		For(i,1,n)
		{
			ll E_i=P[i];
			(E_i*=a[i])%=mod;
			Add(ans,E_i);
		}
		cout<<ans<<"\n";
	};
	return;
}
int main()
{
	I.a_0_0=1;I.a_0_2=0;I.a_1_0=0;I.a_1_2=0;
	tree[0].mat=I;
	tree[0].prod=I;
	{
		inv[1]=1;
		For(i,2,(N-1))
		{
			inv[i]=mod;
			Sub(inv[i],(mod/i));
			(inv[i]*=inv[mod%i])%=mod;
		}
	};
	ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);int T;cin>>T;while(T--)solve();return 0;
}
posted @ 2023-06-15 19:15  llzer  阅读(40)  评论(0编辑  收藏  举报