QOJ7206 Triple

QOJ 传送门

大分讨恶心题。

首先施容斥,变成求 \(\sum |AB| > \max(|AC|, |BC|)\)

遇到这种三个点的路径问题,可以找出一个点 \(X\),使得 \(A, B, C\)\(X\) 的不同子树内,也就是 \(A \to B, A \to C, B \to C\) 的路径的唯一一个交点 \(X\)。那么:

\[[|AB| > \max(|AC|, |BC|)] = [|AX| + |BX| > \max(|AX| + |CX|, |BX| + |CX|)] = [\min(|AX|, |BX|) > |CX|] \]

考虑先点分治,考虑对于一个分治中心 \(R\),计算 \(A, B, C\) 不是都在同一棵子树的方案。为了方便认为 \(R\) 是单独的一棵子树。

\(A, B, C\) 都不在同一棵子树,考虑直接枚举 \(C\),对子树每个点的深度预处理一个后缀和 \(F_d\) 和在每棵子树选两个深度 \(\ge d\) 的点的方案数 \(G_d\),为了容斥算出 \(A, B\) 不在同一棵子树的方案数。需要去除 \(C\) 所在子树对 \(F, G\) 的贡献。

\(A, C\) 在同一棵子树(\(B, C\) 一样,贡献 \(\times 2\)),考虑枚举 \(X\)\(|CX|\),那么要求 \(|AX| > |CX|\),可以用长剖计算这样的对数。还要求 \(|BX| > |CX|\),即 \(|BR| + |RX| > |CX|\),即 \(|BR| \ge |CX| - |RX| + 1\)。这部分贡献可以用上部分处理的 \(F\) 数组。

\(A, B\) 在同一棵子树,考虑直接枚举 \(X\)\(\min(|AX|, |BX|)\),也可以用长剖计算 \(X\) 子树内 \(A, B\) 的对数使得 \(\min(|AX|, |BX|) = k\)。还要求 \(|CX| < k\),即 \(|CR| \le k - |RX| - 1\)。这部分贡献也可以用第一部分处理的 \(F\) 数组。

于是总时间复杂度为 \(O(n \log n)\)

注意讨论一些 corner case。

code
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 100100;

int n;
vector<int> gg[maxn];

int f[maxn], sz[maxn], rt;
bool vis[maxn];

void dfs2(int u, int fa, int t) {
	f[u] = 0;
	sz[u] = 1;
	for (int v : gg[u]) {
		if (v == fa || vis[v]) {
			continue;
		}
		dfs2(v, u, t);
		sz[u] += sz[v];
		f[u] = max(f[u], sz[v]);
	}
	f[u] = max(f[u], t - sz[u]);
	if (!rt || f[u] < f[rt]) {
		rt = u;
	}
}

int D, mxd[maxn], son[maxn], dep[maxn];
vector<int> g[maxn];
// F[d]: 深度 >= d 的点数
// G[d]: 深度 >= d 且在同一棵子树的点对,用于容斥
ll F[maxn], G[maxn], ans;

void dfs3(int u, int fa, int rt) {
	D = max(D, dep[u]);
	if (dep[u] >= (int)g[rt].size()) {
		g[rt].pb(1);
	} else {
		++g[rt][dep[u]];
	}
	mxd[u] = dep[u];
	son[u] = 0;
	for (int v : gg[u]) {
		if (vis[v] || v == fa) {
			continue;
		}
		dep[v] = dep[u] + 1;
		dfs3(v, u, rt);
		if (mxd[v] > mxd[u]) {
			son[u] = v;
			mxd[u] = mxd[v];
		}
	}
}

int tim, h[maxn], id[maxn], p[maxn];

void dfs4(int u, int fa) {
	id[u] = ++tim;
	h[id[u]] = 1;
	p[u] = 1;
	if (son[u]) {
		dfs4(son[u], u);
		p[u] += p[son[u]];
	}
	for (int v : gg[u]) {
		if (v == fa || v == son[u] || vis[v]) {
			continue;
		}
		dfs4(v, u);
		ll s1 = p[u], s2 = 0;
		for (int j = mxd[v] - dep[u]; ~j; --j) {
			s1 -= h[id[u] + j];
		}
		for (int j = mxd[v] - dep[u]; j; --j) {
			// A, C 同子树
			ans += F[max(0, j - dep[u] + 1)] * 2 * (s1 * h[id[v] + j - 1] + s2 * h[id[u] + j]);
			s1 += h[id[u] + j];
			// A, B 同子树
			if (j > dep[u]) {
				ans += (F[0] - F[j - dep[u]]) * 2 * (s1 * h[id[v] + j - 1] + s2 * h[id[u] + j]);
			}
			s2 += h[id[v] + j - 1];
			h[id[u] + j] += h[id[v] + j - 1];
		}
		p[u] += p[v];
	}
	// X = C = u
	ans += 2LL * (p[u] - 1) * F[0];
	// X = R, C = u, A, B 不同子树
	ans += F[dep[u] + 1] * (F[dep[u] + 1] - 1) - G[dep[u] + 1];
}

void dfs(int u) {
	D = 0;
	vis[u] = 1;
	for (int v : gg[u]) {
		if (vis[v]) {
			continue;
		}
		g[v].pb(0);
		dep[v] = 1;
		dfs3(v, u, v);
		for (int i = mxd[v], s = 0; ~i; --i) {
			s += g[v][i];
			F[i] += s;
			G[i] += 1LL * s * (s - 1);
		}
	}
	// X = C = R 且 A, B 不同子树
	ans += F[1] * (F[1] - 1) - G[1];
	++F[0];
	for (int v : gg[u]) {
		if (vis[v]) {
			continue;
		}
		for (int i = mxd[v], s = 0; ~i; --i) {
			s += g[v][i];
			F[i] -= s;
			G[i] -= 1LL * s * (s - 1);
		}
		tim = 0;
		dfs4(v, u);
		for (int i = mxd[v], s = 0; ~i; --i) {
			s += g[v][i];
			F[i] += s;
			G[i] += 1LL * s * (s - 1);
		}
	}
	for (int i = 0; i <= D; ++i) {
		F[i] = G[i] = 0;
	}
	for (int v : gg[u]) {
		if (vis[v]) {
			continue;
		}
		vector<int>().swap(g[v]);
	}
	for (int v : gg[u]) {
		if (vis[v]) {
			continue;
		}
		rt = 0;
		dfs2(v, -1, sz[v]);
		dfs2(rt, -1, sz[v]);
		dfs(rt);
	}
}

void solve() {
	for (int i = 1; i <= n; ++i) {
		vector<int>().swap(gg[i]);
		vis[i] = 0;
	}
	for (int i = 1, u, v; i < n; ++i) {
		scanf("%d%d", &u, &v);
		gg[u].pb(v);
		gg[v].pb(u);
	}
	ans = 0;
	rt = 0;
	dfs2(1, -1, n);
	dfs2(rt, -1, n);
	dfs(rt);
	printf("%lld\n", 1LL * n * (n - 1) * (n - 2) - ans);
}

int main() {
	while (scanf("%d", &n) == 1) {
		solve();
	}
	return 0;
}

posted @ 2024-01-23 17:52  zltzlt  阅读(41)  评论(0)    收藏  举报