题解 CF2063E Triangle Tree
年级里面有大佬用神奇启发式合并写的,还好我不会这么复杂的算法。
题意
给定一棵 \(n\) 个点的树。
\(\operatorname{dist}(u,v)\) 定义为从 \(u\) 到 \(v\) 的唯一简单路径上的边数,\(\operatorname{lca}(u,v)\) 表示 \(u\) 和 \(v\) 的最近公共祖先,
设函数 \(f(u,v)\) 表示:
- 当 \(u\) 不为 \(v\) 的祖先,且 \(v\) 不为 \(u\) 的祖先时,存在多少个整数 \(x\) 使得边长为 \(\operatorname{dist}(u,\operatorname{lca}(u,v))\) 、 \(\operatorname{dist}(v,\operatorname{lca}(u,v))\)、\(x\) 能成为一个三角形。
- 否则函数值为 \(0\)。
最后需要求出:
\[\sum_{i = 1}^{n-1} \sum_{j = i+1}^n f(i,j)
\]
分析
根据三角形边长的限制,可以得到:
\[\vert \operatorname{dist}(u,\operatorname{lca}(u,v)) - \operatorname{dist}(v,\operatorname{lca}(u,v)) \vert < x < \operatorname{dist}(u,\operatorname{lca}(u,v)) + \operatorname{dist}(v,\operatorname{lca}(u,v))
\]
简单来说两边之和大于第三边,两边之差小于第三边。
把距离函数拆成深度,设 \(d_u\) 为 \(u\) 的深度,\(lca = \operatorname{lca}(u,v)\),那么以上式子可以表示为:
\[\vert d_u - d_v \vert < x < d_u + d_v - 2d_{lca}
\]
然后就可以表示出 \(f(u,v)\):
\[f(u,v) = (d_u + d_v - 2d_{lca}) - \vert d_u -d_v \vert - 1
\]
后面绝对值这一部分是 [ABC186D] Sum of difference,可以把每个深度的个数加入到一个桶中,统计有多少个数比它小或大。
\(d_u\) 和 \(d_v\) 的求和也是简单的,对于每个点 \(u\) 会有 \(n-1\) 个询问和他有关,所以对答案的贡献即为 \((n-1)d_u\)。
现在考虑这个 \(-2d_{lca}\) 如何处理,需要求出有多少对点的最近公共祖先为 \(lca\),显然在 \(lca\) 两个不同子树内的任意点的最近公共祖先都是它。
设 \(sz_u\) 表示 \(u\) 的子树大小,\(son_u\) 表示 \(u\) 的儿子集合,存在点对数即为:
\[\dfrac 1 2 \sum\limits_{u \in son_{lca}} \sum\limits_{v \in son_{lca} \land v \neq u} sz_u \times sz_v
\]
这个东西可以前缀和优化做到线性。
最后对于每一个询问还需要减 \(1\),然后会发现对于\(u\) 为 \(v\) 的祖先或 \(v\) 为 \(u\) 的祖先的情况不应该多减,所以加上这一部分就好。
时间复杂度和空间复杂度均为 \(O(n)\)。
代码
//the code is from chenjh
#include<bits/stdc++.h>
#define MAXN 300003
using namespace std;
typedef long long LL;
int n;
vector<int> G[MAXN];
LL ans1=0,ans2=0,ans3=0;
int dep[MAXN],sz[MAXN],a[MAXN];
void dfs(const int u,const int FA){
++a[dep[u]=dep[FA]+1],sz[u]=1;
ans1+=(n-1ll)*dep[u];
int x=0;
LL y=0;//前缀和统计 LCA 为当前点的点对个数。
for(const int v:G[u])if(v!=FA){
dfs(v,u),sz[u]+=sz[v];
y+=(LL)sz[v]*x,x+=sz[v];
}
ans1-=2*(y+sz[u]-1)*dep[u],ans3+=sz[u]-1;//减去 2d_{lca},排除祖先的情况。
}
void solve(){
scanf("%d",&n);
for(int i=1;i<=n;i++) G[i].clear(),a[i]=0;
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
G[u].push_back(v),G[v].push_back(u);
}
ans1=ans2=ans3=0;
dfs(1,0);
for(int i=1,x=0;i<=n;x+=a[i++]) ans2+=(LL)i*a[i]*(x-(n-a[i]-x));//处理绝对值部分。
printf("%lld\n",ans1-ans2+ans3-n*(n-1ll)/2);
}
int main(){
int T;scanf("%d",&T);
while(T--) solve();
return 0;
}