CF1458F Range Diameter Sum【树剖,分治】

传送门

题目描述:给定 \(n\) 个点的树,求所有编号区间的直径之和。

数据范围:\(n\le 10^5\)


定义 \(C(v,r)\) 表示与 \(v\) 距离 \(\le r\) 的点。

对于点集 \(S\),定义 \(\text{Cov}(S)\) 表示覆盖 \(S\)\(C(v,r)\)\(r\) 最小的,则有 \(r\) 是点集半径,\(v\) 是直径中点。

\(v\) 是边中点的情况较难处理,于是在原树上每条边中点建一个虚点。

然后就要考虑合并点集,给定 \(\text{Cov}(S)=C(v_S,r_S),\text{Cov}(T)=C(v_T,r_T)\)

  • \(\text{Cov(S)}\subseteq\text{Cov}(T)\) 时,也即 \(\text{dis}(v_S,v_T)\le r_T-r_S\),则有 \(\text{Cov}(S\cup T)=\text{Cov}(T)\)
  • \(\text{Cov}(T)\subseteq\text{Cov}(S)\) 时同理。
  • 否则 \(\text{Cov}(S\cup T)=(V,R)\),其中 \(R=(r_S+r_T+\text{dis}(v_S,v_T))/2\)\(V\)\(v_S\rightarrow v_T\) 路径上与 \(v_S\) 相距 \(R-r_S\) 的点。

证明不会,感性理解一下或者看官方题解吧

计算答案可以分治,设当前在做 \([L,R]\) 这段区间,计算 \(l\in[L,mid],r\in(mid,R]\) 的答案之和。

首先预处理出 \(\text{Cov}([l,mid])\)\(\text{Cov}((mid,r])\) 和半径的前缀和。

枚举 \(i=mid\rightarrow L\),三种情况(\(\supseteq\),严格相交,\(\subseteq\))分别占三段区间且分界点单调不降,使用 two-pointer 方法,前后两段都很好算,中间那一段需要求类似 \(\sum_{u\in S}\text{dis}(v,u)\) 的东西,支持 \(S\) 的加/删点,把 \(\text{dis}(v,u)\) 拆成 \(dep_u+dep_v-2dep_{\text{lca}(u,v)}\),再多维护个 \(dep\) 前缀和,\(\sum dep(lca)\) 是经典套路,用树剖 + BIT 维护链加/链求和即可。LCA 和 kth-anc 都可以用树剖求。时间复杂度 \(O(n\log^2n)\)

#include<bits/stdc++.h>
#define PB emplace_back
#define MP make_pair
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef pair<int, int> pii;
const int N = 222222;
template<typename T>
void read(T &x){
	int ch = getchar(); x = 0; bool f = false;
	for(;ch < '0' || ch > '9';ch = getchar()) f |= ch == '-';
	for(;ch >= '0' && ch <= '9';ch = getchar()) x = x * 10 + ch - '0';
	if(f) x = -x;
} template<typename T>
bool chmax(T &a, const T &b){if(a < b) return a = b, 1; return 0;}
template<typename T>
bool chmin(T &a, const T &b){if(a > b) return a = b, 1; return 0;}
int n, cnt, tim, head[N], to[N<<1], nxt[N<<1], dfn[N], pre[N], siz[N], dep[N], fa[N], top[N], wson[N];
void add(int u, int v){
	to[++cnt] = v; nxt[cnt] = head[u]; head[u] = cnt;
	to[++cnt] = u; nxt[cnt] = head[v]; head[v] = cnt;
}
void dfs1(int x){ siz[x] = 1;
	for(int i = head[x];i;i = nxt[i]) if(to[i] != fa[x]){
		fa[to[i]] = x; dep[to[i]] = dep[x] + 1;
		dfs1(to[i]); siz[x] += siz[to[i]];
		if(siz[to[i]] > siz[wson[x]]) wson[x] = to[i];
	}
}
void dfs2(int x, int tp){
	top[x] = tp; dfn[x] = ++tim; pre[tim] = x;
	if(wson[x]){ dfs2(wson[x], tp);
		for(int i = head[x];i;i = nxt[i])
			if(to[i] != wson[x] && to[i] != fa[x])
				dfs2(to[i], to[i]);
	}
}
int lca(int u, int v){
	while(top[u] != top[v]){
		if(dep[top[u]] < dep[top[v]]) swap(u, v);
		u = fa[top[u]];
	} return dep[u] < dep[v] ? u : v;
} int dis(int u, int v){return dep[u] + dep[v] - (dep[lca(u,v)]<<1);}
int jump(int u, int k){
	while(dep[u] - dep[top[u]] < k){
		k -= dep[u] - dep[top[u]] + 1; u = fa[top[u]];
	} return pre[dfn[u] - k];
}
struct BIT {
	LL tr[N];
	void upd(int p, LL v){for(;p <= tim;p += p & -p) tr[p] += v;}
	LL qry(int p){LL res = 0; for(;p;p -= p & -p) res += tr[p]; return res;}
} t1, t2;
void upd(int p, LL v){t1.upd(p, v); t2.upd(p, v * p);}
void upd(int l, int r, LL v){upd(l, v); upd(r+1, -v);}
LL qry(int p){return (p+1) * t1.qry(p) - t2.qry(p);}
LL qry(int l, int r){return qry(r) - qry(l-1);}
void updt(int u, LL v){while(u){upd(dfn[top[u]], dfn[u], v); u = fa[top[u]];}}
LL qryt(int u){LL res = 0; while(u){res += qry(dfn[top[u]], dfn[u]); u = fa[top[u]];} return res;}
bool in(pii c1, pii c2){return c1.se <= c2.se && dis(c1.fi, c2.fi) <= c2.se - c1.se;}
pii merge(pii c1, pii c2){
	int u = c1.fi, v = c2.fi, r1 = c1.se, r2 = c2.se;
	int uho = lca(u, v), len = dep[u] + dep[v] - (dep[uho]<<1);
	if(len <= r1 - r2) return c1;
	if(len <= r2 - r1) return c2;
	int rad = r1 + r2 + len >> 1;
	if(dep[u] - dep[uho] >= rad - r1) return MP(jump(u, rad - r1), rad);
	return MP(jump(v, rad - r2), rad);
} pii circ[N]; LL sd[N], sr[N];
LL calc(int L, int R){
	if(L == R) return 0; int mid = L+R>>1;
	circ[mid+1] = MP(mid+1, 0);
	for(int i = mid+2;i <= R;++ i) circ[i] = merge(circ[i-1], MP(i, 0));
	circ[mid] = MP(mid, 0);
	for(int i = mid-1;i >= L;-- i) circ[i] = merge(circ[i+1], MP(i, 0));
	sd[mid] = sr[mid] = 0;
	for(int i = mid+1;i <= R;++ i){
		sd[i] = sd[i-1] + dep[circ[i].fi];
		sr[i] = sr[i-1] + circ[i].se;
	} LL ans = 0; int p1 = mid+1, p2 = mid; // (mid, p1) sup, [p1, p2] inter, (p2, R] sub
	for(int i = mid;i >= L;-- i){
		while(p1 <= R && in(circ[p1], circ[i])){
			updt(circ[p1].fi, -1); ++ p1;
		} while(p2 < p1-1 || (p2 < R && !in(circ[i], circ[p2+1]))){
			++ p2; updt(circ[p2].fi, 1);
		} ans += (LL)circ[i].se * (p1-mid-1) + sr[R] - sr[p2];
		if(p1 <= p2) ans += (((p2-p1+1ll)*(dep[circ[i].fi]+circ[i].se+2) + sd[p2] - sd[p1-1] + sr[p2] - sr[p1-1]) >> 1) - qryt(circ[i].fi);
	} while(p1 <= p2){updt(circ[p1].fi, -1); ++ p1;}
	return ans + calc(L, mid) + calc(mid+1, R);
}
int main(){ read(n);
	for(int i = 1, u, v;i < n;++ i){
		read(u); read(v); add(u, n+i); add(v, n+i);
	} dfs1(1); dfs2(1, 1); printf("%lld\n", calc(1, n));
}
posted @ 2021-03-31 17:13  mizu164  阅读(104)  评论(0编辑  收藏  举报