JZOJ 6866. 【2020.11.16提高组模拟】路径大小差(点分治+树状数组)
JZOJ 6866. 【2020.11.16提高组模拟】路径大小差
题目大意
- 问树上有多少点对之间路径边权 m a x − m i n = k max-min=k max−min=k, k k k为定值。
- k ≤ n ≤ 2 ∗ 1 0 5 k\leq n\leq2*10^5 k≤n≤2∗105.
题解
- 其实这题比较套路,并不难想。
- 关于树上路径计数的问题,一般先考虑点分治能不能实现,发现是可以的。
- 按照一般点分治的套路,找到某个子树重心后,记录每个点到它的路径边权 m a x , m i n max,min max,min,有两种情况,一种是重心为路径的一端,直接枚举判断;另一种是重心在路径中间。
- 第二种情况,按 m a x max max从小到大排序,枚举一条路径和前面的另一条组合,
- 因为已经排好序了,所以 m a x max max一定在当前这条路径上,接着再分两种情况,一种是该路径的 m a x − m i n < k max-min<k max−min<k,那么查找前面 m i n = m a x − k min=max-k min=max−k的数量加入答案;一种是该路径的 m a x − m i n = k max-min=k max−min=k,则查找前面 m i n ≥ m a x − k min\geq max-k min≥max−k的数量加入答案。用树状数组维护。
- 但是会发现组合的两条路径可能出现在当前根的同一子树中,那么把每棵子树的路径单独求一遍,从答案中减去即可。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
#define N 200010
int n, K;
ll ans = 0;
int last[N], nxt[N * 2], to[N * 2], we[N * 2], len = 0;
int vi[N], si[N], sum[N], s, rt, mi;
int tot = 0, f[N];
struct node {
int mx, mi, r;
}a[N];
void add(int x, int y, int w) {
to[++len] = y;
we[len] = w;
nxt[len] = last[x];
last[x] = len;
}
void dfs(int k, int fa) {
si[k] = 1;
for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa && !vi[to[i]]) {
dfs(to[i], k);
si[k] += si[to[i]];
}
}
void find(int k, int fa) {
int mx = s - si[k];
for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa && !vi[to[i]]) {
find(to[i], k);
mx = max(mx, si[to[i]]);
}
if(mx < mi) mi = mx, rt = k;
}
void dfs1(int k, int fa, int t0, int t1, int r) {
if(t1) a[++tot].mx = t1, a[tot].mi = t0, a[tot].r = r;
for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa && !vi[to[i]]) {
dfs1(to[i], k, min(t0, we[i]), max(t1, we[i]), r == 0 ? to[i] : r);
}
}
int cmp(node x, node y) {
if(x.mx == y.mx) return x.mi < y.mi;
return x.mx < y.mx;
}
int cmp1(node x, node y) {
return x.r < y.r;
}
int low(int x) {
return x & (-x);
}
void ins(int k, int c) {
for(int i = k; i <= n; i += low(i)) f[i] += c;
}
int ct(int k) {
int s = 0;
for(int i = k; i; i -= low(i)) s += f[i];
return s;
}
void ds(int l, int r, int o) {
sort(a + l, a + r + 1, cmp);
for(int i = l; i <= r; i++) {
if(a[i].mx - a[i].mi == K) {
ans += (i - l - ct(a[i].mi - 1)) * o;
}
else if(a[i].mx - a[i].mi < K) ans += sum[a[i].mx - K] * o;
sum[a[i].mi]++;
ins(a[i].mi, 1);
}
for(int i = l; i <= r; i++) sum[a[i].mi]--, ins(a[i].mi, -1);
}
void calc(int k) {
tot = 0;
dfs1(k, 0, n + 1, 0, 0);
sort(a + 1, a + tot + 1, cmp);
for(int i = 1; i <= tot; i++) if(a[i].mx - a[i].mi == K) ans++;
ds(1, tot, 1);
sort(a + 1, a + tot + 1, cmp1);
int la = 1;
for(int i = 1; i <= tot; i++) {
if(i == tot || a[i].r != a[i + 1].r) {
ds(la, i, -1);
la = i + 1;
}
}
}
void solve(int k) {
dfs(k, 0);
s = si[k], mi = n + 1;
find(k, 0);
calc(rt);
vi[rt] = 1;
for(int i = last[rt]; i; i = nxt[i]) if(!vi[to[i]]) solve(to[i]);
}
int main() {
int i, x, y, w;
scanf("%d%d", &n, &K);
for(i = 1; i < n; i++) {
scanf("%d%d%d", &x, &y, &w);
add(x, y, w), add(y, x, w);
}
solve(1);
printf("%lld\n", ans);
return 0;
}
哈哈哈哈哈哈哈哈哈哈

浙公网安备 33010602011771号