bzoj 4543: [POI2014]Hotel加强版

Description

给出一棵树求三元组 \((x,y,z)\,,x<y<z\) 满足三个点两两之间距离相等,求三元组的数量

Solution

考虑暴力 \(DP\)
\(f[i][j]\) 表示距离点 \(i\) 的子树内距离为 \(j\) 的点的数量
\(g[i][j]\) 表示 \(i\) 子树内的一个二元组 \((x,y)\) 满足 \(dis(x,lca)=dis(y,lca)=dis(i,lca)+j\) 的二元组的数量,可以视为等待与子树外合并的二元组的数量
显然有转移:
\(ans+=g[x][0]\)
\(ans+=f[son][j]*g[x][j+1]+f[x][j-1]*g[son][j]\)

\(f[x][j]=f[son][j-1]\)
\(g[x][j]=g[son][j+1]\)
\(g[x][j]=f[x][j]*f[son][j-1]\)

这样转移是 \(O(n^2)\)
对于深度为下标的树形 \(DP\),考虑长链剖分优化:
\(DP\) 转移之前, \(f[x],g[x]\) 是没有值的,初值要设为某个儿子的 \(DP\)
并且这个时候的转移仅仅是 \(f[x][j]=f[son][j-1]\),\(g[x][j]=g[son][j+1]\)
相当于一个数组位移,直接用指针优化,可以做到 \(O(1)\),所以用所在链最长的儿子来赋初值复杂度最优,其余儿子暴力转移
这样做复杂度均摊就是 \(O(n)\) 的了,一条链只会在链顶被枚举到,且链不相交,所以每个点只会被枚一次
空间对于每一个链动态开空间,空间复杂度也是 \(O(n)\)

#include<bits/stdc++.h>
using namespace std;
template<class T>void gi(T &x){
	int f;char c;
	for(f=1,c=getchar();c<'0'||c>'9';c=getchar())if(c=='-')f=-1;
	for(x=0;c<='9'&&c>='0';c=getchar())x=x*10+(c&15);x*=f;
}
typedef long long ll;
const int N=1e5+10;
int n,head[N],nxt[N*2],to[N*2],num=0,dep[N],mx[N];
inline void link(int x,int y){nxt[++num]=head[x];to[num]=y;head[x]=num;}
ll lis[N*5],*f[N],*g[N],*st=lis+5,ans=0;
inline void dfs(int x,int fa){
	mx[x]=x;
	for(int i=head[x],u;i;i=nxt[i]){
		if((u=to[i])==fa)continue;
		dep[u]=dep[x]+1;dfs(u,x);
		if(dep[mx[u]]>dep[mx[x]])mx[x]=mx[u];
	}
	for(int i=head[x],u;i;i=nxt[i]){
		if((u=to[i])==fa || (mx[u]==mx[x] && x!=1))continue;
		st+=dep[mx[u]]-dep[x]+1;
		f[mx[u]]=st;
		g[mx[u]]=++st;
		st+=(dep[mx[u]]-dep[x])*2+1;
	}
}
inline void dfs2(int x,int fa){
	for(int i=head[x],u;i;i=nxt[i]){
		if((u=to[i])==fa)continue;
		dfs2(u,x);
		if(mx[u]==mx[x])f[x]=f[u]-1,g[x]=g[u]+1;
	}
	f[x][0]=1;ans+=g[x][0];
	for(int i=head[x],u;i;i=nxt[i]){
		if((u=to[i])==fa || mx[u]==mx[x])continue;
		for(int j=dep[mx[u]]-dep[x];j>=0;j--)
			ans+=g[x][j+1]*f[u][j]+g[u][j]*(j?f[x][j-1]:0);
		for(int j=dep[mx[u]]-dep[x];j>=0;j--){
			f[x][j+1]+=f[u][j];
			g[x][j]+=g[u][j+1];
			g[x][j]+=f[x][j]*(j?f[u][j-1]:0);
		}
	}
}
int main(){
  freopen("pp.in","r",stdin);
  freopen("pp.out","w",stdout);
  cin>>n;
  int x,y;
  for(int i=1;i<n;i++)gi(x),gi(y),link(x,y),link(y,x);
  dfs(1,1);dfs2(1,1);
  cout<<ans<<endl;
  return 0;
}

posted @ 2018-05-02 22:48  PIPIBoss  阅读(419)  评论(0编辑  收藏  举报