Samsara
题解
原题:[北大集训2021] 小明的树
定义两个端点都未点亮的边为黑边,一个端点点亮一个端点为点亮的边为黑白边,两个端点都点亮的边为白边。
考虑进行转化,发现当树的的黑边构成一个连通块时这棵树一定合法。
然后考虑经典结论,连通块个数等于 \(点数\) \(-\) \(边数\) ,而黑点的个数又等于 \(n - 操作次数\),则当树合法时 \(n - 操作次数 - 黑边个数 = 1\)。
再考虑何时合法,此时连通块个数等于黑白边的个数。
然后考虑如何维护,这种问题肯定想到线段树。
我们以时间为轴来建立线段树,接下来我们假定一条边的两个端点的点亮时间分别为 \(t_x\) 和 \(t_y\) 且 \(t_x < t_y\)。
在一开始,黑边的贡献为 \(1\)。
当点 \(x\) 被点亮时,黑白边的个数加 \(1\)。
当点 y 被点亮时,白边的贡献为 \(1\)。
然后我们拿线段树维护下最小值,最小值的个数和为最小值的点的权值和就行了。
代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 5e5 + 10;
int n, q, b[N];
pair<int, int> e[N];
struct SegTree { int l, r, sum, min, tag[2], cnt; } t[N * 4];
void pushup(int x) {
t[x].sum = t[x * 2].sum + t[x * 2 + 1].sum; t[x].min = min(t[x * 2].min, t[x * 2 + 1].min);
t[x].cnt = 0, t[x].sum = 0;
if (t[x].min == t[x * 2].min) t[x].cnt += t[x * 2].cnt, t[x].sum += t[x * 2].sum;
if (t[x].min == t[x * 2 + 1].min) t[x].cnt += t[x * 2 + 1].cnt, t[x].sum += t[x * 2 + 1].sum;
}
void apply(int x, int k1, int k2) { t[x].tag[0] += k1, t[x].tag[1] += k2, t[x].min += k1, t[x].sum += t[x].cnt * k2; }
void pushdown(int x) {
if (t[x].tag[0] || t[x].tag[1]) {
apply(x * 2, t[x].tag[0], t[x].tag[1]), apply(x * 2 + 1, t[x].tag[0], t[x].tag[1]);
t[x].tag[0] = 0, t[x].tag[1] = 0;
}
}
void build(int x, int l, int r) {
t[x].l = l, t[x].r = r, t[x].tag[0] = 0, t[x].tag[1] = 0, t[x].sum = 0;
if (l == r) return t[x].cnt = 1, t[x].min = l == n ? n + 1 : n - l, void();
int mid = (l + r) >> 1;
build(x * 2, l, mid);
build(x * 2 + 1, mid + 1, r);
pushup(x);
}
void update(int x, int ql, int qr, int k1, int k2) {
if (ql > qr) return;
if (ql <= t[x].l && t[x].r <= qr) return apply(x, k1, k2);
pushdown(x);
int mid = (t[x].l + t[x].r) >> 1;
if (ql <= mid) update(x * 2, ql, qr, k1, k2);
if (mid < qr) update(x * 2 + 1, ql, qr, k1, k2);
pushup(x);
}
void modify(int x, int y, int k) {
if (x > y) swap(x, y);
// printf("x: %d y: %d\n", x, y);
update(1, 1, x - 1, - k, 0);
// printf("add val [%d, %d] k: %d\n", x, y - 1, k);
update(1, x, y - 1, 0, k);
}
int main() {
scanf("%d%d", &n, &q);
for (int i = 1; i < n; i ++) scanf("%d%d", &e[i].first, &e[i].second);
for (int i = 1, x; i < n; i ++) scanf("%d", &x), b[x] = i;
b[1] = n;
build(1, 1, n);
for (int i = 1; i < n; i ++) modify(b[e[i].first], b[e[i].second], 1);
printf("%d\n", (t[1].min == 1) * t[1].sum);
while (q --) {
int x_1, y_1, x_2, y_2; scanf("%d%d%d%d", &x_1, &y_1, &x_2, &y_2);
modify(b[x_1], b[y_1], -1), modify(b[x_2], b[y_2], 1);
printf("%d\n", (t[1].min == 1) * t[1].sum);
}
return 0;
}
浙公网安备 33010602011771号