[bzoj 4543] Hotel加强版

题意

给定一棵有n个点的树,有三元组$(a,b,c)$满足a,b,c两两距离相等,求这样的三元组的个数。

n<=100000

题解

初步思考

这种题目有个经典解法就是固定两点求中点。但很遗憾这种做法是$O(n^2)$的。

所以我们用树形dp解决这个问题。

考虑从三点lca处统计贡献,这样比较方便。

状态设计

我们考虑逐步满足法,先统计一个点,再用一个点的数据计算两个点都满足的数据,再用一个点和两个点的数据计算三个点都满足的数据。

于是,我们用$f[i][j]$表示在i的子树内与i距离为j的点的个数,

用$g[i][j]$表示在i的子树内有两个点,且他们再加上一个i子树外的与i距离为j的点就满足题目要求。

之所以这样设置是为了方便将g和f拼在一起,且不重不漏地统计答案。

转移方程

那么怎样转移计算g[i][j]的值呢?

因为第二项是j,由定义得现在还需要一条长度为j的边。

考虑一个个枚举i的儿子,统计贡献。

第一种情况,这两条边,一条在已经枚举过的儿子里找,一条在新加进来的找。

那么i就是这两条边的交点,可得两条边长度为i.

下图中,红色表示现在连接的边,黄色表示需要的边,绿色表示将要配对成一个三元组的边。

这种情况可以用$f[i][j]*f[son][j-1]$描述,其中son表示新加进来的儿子。

另外一种情况,两条边都在son里找。

那么就是在son原有的二元组的基础上再加上i--son这一条,因为现在还需要一条长度为j的边,那么在连接之前son则需要一条长度为j+1的边。

这部分可以用$g[son][j+1]$描述

那么,整个g的方程就呼之欲出了

$g[i][j]+=g[son][j+1]+f[i][j]*f[son][j-1]$

f的方程也很简单

$f[id][j]+=f[to][j-1]$

答案的统计也可以列出来了:

$ans+=g[to][j+1]*f[id][j]+g[id][j]*f[to][j-1]$

注意下图的情况要单独计算

 

 

 可以用$ans+=g[id][0]$来统计。

注意先转移g,再f,再ans。

优化转移

但是,这样仍然是$O(n^2)$

注意到第一个儿子对i的贡献是

$f[x][i]+=f[u][i-1]$

$ g[x][i]+=g[u][i+1]$

也就是说第一个儿子可以$O(1)$转移。

那么我们可以长链剖分,把最深的儿子放在第一位(每个儿子转移复杂度为他的链的长度)

这样,每条链只会在顶部被计算一次,则复杂度为$O(n)$

剩下的就是空间复杂度的问题,我们可以像重链剖分一样把每条链的节点放在一起,因为f和g数组是直接在深儿子的数组上位移而来。所以每条链可以共用一段空间,只不过每个点的起始位不同。对于

 

 

 存储如下

 

这样,空间也是$O(n)$的。这样就可以开始敲代码了。

代码

#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
#define N 1000001
#define int long long
vector<int> vec[N];
int maxd[N],dson[N],dep[N];
int g[N*5],f[N*5],dfn[N],cnt,top[N],pos[N],cnt2;
void get_maxd(int id,int from)
{
	maxd[id]=1;
	dep[id]=dep[from]+1;
	for(int i=0;i<vec[id].size();i++)
	{
		int to=vec[id][i];
		if(to==from) continue;
		get_maxd(to,id);
		if(maxd[to]+1>maxd[id]) dson[id]=to;
		maxd[id]=max(maxd[to]+1,maxd[id]);
	}
}
void get_dfn(int id,int from,int root)
{
	dfn[id]=++cnt;
	top[id]=root;
	if(dson[id]) get_dfn(dson[id],id,root);
	for(int i=0;i<vec[id].size();i++)
	{
		int to=vec[id][i];
		if(to==from||to==dson[id]) continue;
		pos[to]=cnt2+maxd[to];
		cnt2+=maxd[to]*2;
		get_dfn(to,id,to);
	}
}
int ans;
#define f(i,j) f[dfn[i]+j]
#define g(i,j) g[pos[top[i]]-dep[i]+dep[top[i]]+j]
void solve(int id,int from)
{
	f(id,0)=1;
	int tot=0;
	if(dson[id]) solve(dson[id],id);
	for(int i=0;i<vec[id].size();i++)
	{
		int to=vec[id][i];
		if(to==from||to==dson[id]) continue;
		solve(to,id);
		g(id,0)+=g(to,1);
		for(int j=1;j<=maxd[to];j++) 
		{
			tot+=(j<maxd[to])*g(to,j+1)*f(id,j)+g(id,j)*f(to,j-1);
			g(id,j)+=(j<maxd[to])*g(to,j+1)+f(id,j)*f(to,j-1);
			f(id,j)+=f(to,j-1);
		}
	}
	tot+=g(id,0);
	//cout<<id<<" find: "<<tot<<endl;
	ans+=tot;
}
signed main()
{
	int n;
	//freopen("data.txt","r",stdin);
	cin>>n;
	for(int i=1;i<n;i++) 
	{
		int a,b;
		scanf("%lld%lld",&a,&b);
		vec[a].push_back(b);
		vec[b].push_back(a);
	}
	get_maxd(1,0);
	pos[1]=maxd[1];
	cnt2=maxd[1]*2;
	get_dfn(1,0,1);
	solve(1,0);
	cout<<ans;
}

  

posted @ 2020-02-06 16:34  linzhuohang  阅读(179)  评论(0编辑  收藏  举报