Tree
给一个 $n$ 个点的树,点有标号,求有多少个点对 $i,j$ 满足 $i < j,|i - j| \ge \text{dist}(i,j)$ 其中 $\text{dist}(i,j)$ 表示 $i$ 到 $j$ 的距离。
#include<bits/stdc++.h> using namespace std; const int N=4e5+5; int sz[N*50],ch[N*50][2],cnt; void add(int &x,int y,int l,int r) { // 单点 y 加 1 int mid=(l+r)>>1; if(!x)x=++cnt; sz[x]++; if(l==r)return; if(y<=mid)return add(ch[x][0],y,l,mid); else return add(ch[x][1],y,mid+1,r); } int find(int x,int y,int l,int r) { if(y<l) return 0; if(y>r) y=r; int mid=(l+r)>>1; if(!x) return 0; if(l==r) return sz[x]; if(y<=mid) return find(ch[x][0],y,l,mid); else return sz[ch[x][0]]+find(ch[x][1],y,mid+1,r); } int merge(int x,int y,int l,int r) { if(!x||!y)return x|y; sz[x]+=sz[y]; if(l==r)return x; int mid=(l+r)>>1; ch[x][0]=merge(ch[x][0],ch[y][0],l,mid); ch[x][1]=merge(ch[x][1],ch[y][1],mid+1,r); return x; } const int L=-100000,R=200000; int t1[N],t2[N],dep[N]; int head[N],ver[N],nxt[N],tot; void add(int x,int y) { ver[++tot]=y; nxt[tot]=head[x]; head[x]=tot; } int tsmall,tbig; long long ans; void check(int x,int l,int r,int curdep) { if(l==r) { ans+=1ll*sz[x]*find(tbig,2*curdep-l,L,R); return; } int mid=(l+r)>>1; if(ch[x][0])check(ch[x][0],l,mid,curdep); if(ch[x][1])check(ch[x][1],mid+1,r,curdep); } void dfs(int x,int fa) { dep[x]=dep[fa]+1; add(t1[x],x+dep[x],L,R); add(t2[x],-x+dep[x],L,R); for(int i=head[x];i;i=nxt[i]) { int y=ver[i]; if(y==fa)continue; dfs(y,x); tsmall=t1[x];tbig=t2[y]; if(sz[tsmall]>sz[tbig])swap(tsmall,tbig); check(tsmall,L,R,dep[x]); tsmall=t2[x],tbig=t1[y]; if(sz[tsmall]>sz[tbig])swap(tsmall,tbig); check(tsmall,L,R,dep[x]); t1[x]=merge(t1[x],t1[y],L,R); t2[x]=merge(t2[x],t2[y],L,R); } } int n; int main() { scanf("%d",&n); for(int i=1;i<n;i++) { int x,y;scanf("%d%d",&x,&y); add(x,y);add(y,x); } dfs(1,0); printf("%lld\n",ans); return 0; }