P5327 [ZJOI2019] 语言 题解

题意:

有一个 \(n\) 个节点的树,每个点上有一个集合,有 \(m\) 次操作,第 \(i\) 次操作将 \(u_i\)\(v_i\) 上所有集合加入 \(i\)。求最后有多少 \(x < y\) 满足 \(x\)\(y\) 路径上的所有集合的交不为空。

\(n, m \le 10^5\)

思路:

考虑固定 \(x\),求所有的 \(y\),先不考虑大小限制,最后统一除以二即可。

我们观察可以得到结论:\(x\) 能到达的所有点 \(y\) 组成了一个连通块。

进一步,我们考虑 \(x\) 对于包含的某种元素来说,能到的最远点,这些最远点的最小生成树等价于到达所有点的连通块。

直觉上都比较显然。

然后沿用 P3320 [SDOI2015] 寻宝游戏 的做法,我们知道,\(a_1 \sim a_m\) 假设是 dfn 排完序的若干个点,他们的最小生成树的边的数量就是 \(d(a_1, a_2) + d(a_2, a_3) + \dots + d(a_{k - 1}, a_k) + d(a_k, a_1)\) 再除以二。

由于最小生成树一定包含 \(x\)\((x, x)\) 是不合法的,有贡献的大小就是点数减一,也就是边数。

现在考虑如何求这个东西。

我们将所有包含 \(x\) 的路径 \((u,v)\)\(x\) 的点集中加入 \(u\)\(v\),最终只用求这个点集中的生成树大小即可。

我们可以将 dfn 序拍平成线段树,这个式子显然可以合并,记录一个区间最小和最大的 dfn 以及这个区间的答案即可 pushup。

现在考虑优化。

将树上路径转化成树上差分,只需要将儿子的继承给父节点即可。这让我们联想到线段树合并,于是直接线段树合并即可。用 dfn 序配合 st 表 \(O(1)\) 求 lca, 时间复杂度 \(O(n \log n)\)

点击查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 1e5 + 5;

int n, m;
vector<int> e[N];
int dfn[N] = {0}, rnk[N] = {0}, p[N] = {0};
int st[N][22] = {{0}}, cur = 0;
int logN[N] = {0}; 
int dth[N] = {0};

void dfs(int x, int pr, int d) {
	st[(dfn[x] = ++cur)][0] = p[x] = pr, rnk[cur] = x, dth[x] = d;
	for (auto i: e[x])
		if (i != pr)
			dfs(i, x, d + 1);
}
int get(int x, int y) {
	return dfn[x] < dfn[y] ? x : y;
}
void init() {
	dfs(1, 0, 0);
	logN[1] = 0;
	for (int i = 2; i <= n; i++)
		logN[i] = logN[i / 2] + 1;
	for (int j = 1; (1 << j) <= n; j++)
		for (int i = 1; i + (1 << j) - 1 <= n; i++)
			st[i][j] = get(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
}
int qry(int x, int y) {
	if (x == y)
		return x;
	x = dfn[x], y = dfn[y];
	if (x > y)
		swap(x, y);
	int k = logN[y - x];
	return get(st[x + 1][k], st[y - (1 << k) + 1][k]);
}
int dis(int x, int y) {
	if (x == 0 || y == 0)	
		return 0;
	return dth[x] + dth[y] - 2 * dth[qry(x, y)];
}

int tot = 0;
struct Node {
	int ls, rs;
	int cnt, mn, mx, ans;//累计次数,最小的和最大的 
	Node () {
		ls = rs = 0;
		mn = mx = cnt = ans = 0;
	}
} a[N << 6];
int rt[N] = {0};
struct SegTree {
	#define mid ((lx + rx) >> 1)
	void pushup(int x) {
		a[x].cnt = a[a[x].ls].cnt + a[a[x].rs].cnt;
		a[x].ans = a[a[x].ls].ans + a[a[x].rs].ans + dis(a[a[x].ls].mx, a[a[x].rs].mn);
		a[x].mn = (a[a[x].ls].mn == 0 ? a[a[x].rs].mn : a[a[x].ls].mn);
		a[x].mx = (a[a[x].rs].mx == 0 ? a[a[x].ls].mx : a[a[x].rs].mx);
	}
	void mdf(int &x, int lx, int rx, int pos, int v) {
		if (x == 0)
			x = ++tot;
		if (lx + 1 == rx) {
			a[x].cnt += v;
			if (a[x].cnt == 0) 
				a[x].ans = a[x].mn = a[x].mx = 0;
			else
				a[x].ans = 0, a[x].mn = a[x].mx = rnk[lx];
			return;
		}
		(pos < mid) ? mdf(a[x].ls, lx, mid, pos, v) : mdf(a[x].rs, mid, rx, pos, v);
		pushup(x);
	}
	int mrg(int lx, int rx, int x, int y) {
		if (x == 0 || y == 0)
			return x + y;
		if (lx + 1 == rx) {
			a[x].cnt += a[y].cnt;
			if (a[x].cnt == 0)	
				a[x].ans = a[x].mn = a[x].mx = 0;
			else
				a[x].ans = 0, a[x].mn = a[x].mx = rnk[lx];
			return x;
		}
		a[x].ls = mrg(lx, mid, a[x].ls, a[y].ls);
		a[x].rs = mrg(mid, rx, a[x].rs, a[y].rs);
		pushup(x);
		return x;
	}
	SegTree () {}
	#undef mid
} sgt;

long long ans = 0ll;

void getans(int x) {
	for (auto i: e[x])
		if (i != p[x]) {
			getans(i);
			rt[x] = sgt.mrg(1, n + 1, rt[x], rt[i]);
		}
	sgt.mdf(rt[x], 1, n + 1, dfn[x], 1);
//	printf("the size of %d: %d  Ans: %d  Mx:%d Mn:%d\n", x, (a[rt[x]].ans + dis(a[rt[x]].mn, a[rt[x]].mx)) / 2, a[rt[x]].ans, a[rt[x]].mn, a[rt[x]].mx);
	ans += (a[rt[x]].ans + dis(a[rt[x]].mn, a[rt[x]].mx)) / 2;
	sgt.mdf(rt[x], 1, n + 1, dfn[x], -1);
} 

int main() {
//	freopen("test.in", "r", stdin);
//	freopen("my.out", "w", stdout);
	cin >> n >> m;
	for (int i = 1, u, v; i < n; i++) {
		cin >> u >> v;
		e[u].push_back(v);
		e[v].push_back(u); 
	}
	init();
		
	for (int i = 1, x, y; i <= m; i++) {
		cin >> x >> y;
		int lca = qry(x, y);
		sgt.mdf(rt[x], 1, n + 1, dfn[x], 1);
		sgt.mdf(rt[x], 1, n + 1, dfn[y], 1);
		
		sgt.mdf(rt[y], 1, n + 1, dfn[x], 1);
		sgt.mdf(rt[y], 1, n + 1, dfn[y], 1);
		
		sgt.mdf(rt[lca], 1, n + 1, dfn[x], -1);
		sgt.mdf(rt[lca], 1, n + 1, dfn[y], -1);
		
		if (p[lca] == 0)
			continue;
		
		sgt.mdf(rt[p[lca]], 1, n + 1, dfn[x], -1);
		sgt.mdf(rt[p[lca]], 1, n + 1, dfn[y], -1);
	}
	getans(1);
	cout << ans / 2 << endl;
	return 0;
}

*/
posted @ 2024-03-12 22:44  rlc202204  阅读(31)  评论(0)    收藏  举报