题解:P9437 一棵树

题目传送门

明显的换根 dp,感觉是道不错的换根 dp 练习题。

题意

一棵 \(n\) 个节点的树,点带权,定义 \(w(x,y)=\overline {a_x\dots a_y}\)

\(\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{n}w(i,j)\bmod 998244353\)

思路

我们不妨先求出 \(i=1\) 时的 \(\sum w(i,j)\),再求 \(\sum\limits_{i=1}^{n}\sum\limits_{j=1}^{n}w(i,j)\)

\(f_u\) 为其子树每个点走到 \(x\) 的总贡献,定义 \(\operatorname{get}(a_u)\)\(a_u\) 的位数,不难得出转移方程:

\[f_u=\sum\limits_{v\in son} f_v \times \operatorname{get}(a_u)+size_u\times a_x \]

这样我们就可以求出 \(i=1\) 时的 \(\sum w(i,j)\)。朴素的,我们可以直接求 \(n\) 次得到最终的答案,但是这样 \(\mathcal{O(n^2)}\) 的复杂度是不能接受的。

\(i=1\) 推到整棵树,考虑换根 dp,发现这几个条件都很好换根转移。

具体的,我们设 \(g_u\) 为整棵树走到 \(u\) 的贡献。考虑一个节点,假设我们已经得知了其父节点的 \(g\),并且也得知了其子树内点的贡献,所以剩余需要求出子树外的答案。

子树外节点到其父节点的贡献和为其父节点的 \(g\) 减去此节点及其子树对父节点的贡献,而子树外节点到其本身的贡献都比父节点多个本身,所以需要再乘一个 \(10^{\mid a_x \mid}\)。形式化的:

\[g_u=(g_{fa}-(f_u\times \operatorname{get}(a_{fa})+size_x\times a_{fa}))\times \operatorname{get}(a_{fa})+(n-size_u)\times a_u+f_u \]

答案即为 \(\sum g\),复杂度 \(\mathcal{O(n)}\),注意开 long long 和一步一模防止爆炸。

代码

#include<bits/stdc++.h>
#define int long long 
using namespace std;
const int N=1e6+5;
const int mod=998244353;
inline int read();
int n,a[N],cnt,head[N];
int f[N],size[N],g[N],ans;
struct E{
	int to,next,w;
}edge[N<<1];
void add(int u,int v)
{
	edge[++cnt].next=head[u];
	edge[cnt].to=v;
	head[u]=cnt;
}
int get(int x)
{
	if(x==0) return 10;
	int cnt=0;
	while(x>=1)
	{
		x/=10;
		cnt++;
	}
	return pow(10,cnt);
}
void dfs0(int x,int fa)
{
	size[x]=1;
	for(int i=head[x];i;i=edge[i].next)
	{
		int to=edge[i].to;
		if(to==fa) continue;
		dfs0(to,x);
		f[x]=(f[x]+f[to]*get(a[x]))%mod;
		size[x]+=size[to];
	}
	f[x]=(f[x]+size[x]*a[x])%mod;
}
void dfs(int x,int fa)
{
	for(int i=head[x];i;i=edge[i].next)
	{
		int to=edge[i].to;
		if(to==fa) continue;
		g[to]=(((g[x]-(f[to]*get(a[x])%mod+a[x]*size[to])%mod)*get(a[to]))%mod+((n-size[to])*a[to])%mod+f[to])%mod;
		dfs(to,x);
	}
}
signed main()
{
	n=read();
	for(int i=1;i<=n;i++)
	{
		a[i]=read(); 
	}
	for(int i=1;i<n;i++)
	{
		int x;
		x=read();
		add(x,i+1);
		add(i+1,x);
	}
	dfs0(1,0);
	g[1]=f[1];
	dfs(1,0);
	for(int i=1;i<=n;i++)
	{
		ans=(ans+g[i])%mod; 
	}
	printf("%lld",ans);
	return 0;
}

inline int read()
{
	int x=0,f=1;
	char ch;
	ch=getchar();
	while(ch>'9'||ch<'0'){if(ch=='-') f=-f;ch=getchar();}
	while(ch<='9'&&ch>='0')
	{
		x=(x<<1)+(x<<3)+(ch&15);
		ch=getchar();
	}
    return x*f;
}
posted @ 2024-07-23 10:21  一只小咕咕  阅读(42)  评论(0)    收藏  举报