HDU-4916 Count on the path

Problem

给定一棵树和 \(m\) 个询问,每个询问要求回答不在 \(x\)\(y\) 两节点所形成的路径上的点的最小标号。

Input

多组数据,EOF结束。

第一行两个整数 \(n\)\(q\)。(\(n,q \le 10^6\))

接下来 \(n-1\) 行,每行两整数,表示树上的一条边。

接下来 \(q\) 行,每行两个整数 \(x\)\(y\)

询问强制在线,\(x, y\) 要与上一次的结果进行异或操作(第一次询问则不操作)

Output

对于每个询问,输出最小不在 \(x\)\(y\) 的路径上的最小编号。

若所有点都在路径上,则输出 \(n\)

Sample

Input 1

4 1
1 2
1 3
1 4
2 3
5 2
1 2
1 3
2 4
2 5
1 2
7 6

Output 1

4 3 1

Solution

首先容易发现一个性质,如果两点间的简单路径不经过 \(1\) 号点,答案就是 \(1\)。我们将 \(1\) 号点设为整棵树的根,然后分情况讨论。

  1. 如果两点不经过 \(1\),说明两点的 \(Lca \neq 1\),由于数据范围较大,直接求 \(Lca\) 可能会 TLE。不难发现两点一定属于 \(1\) 的同一个子树内,于是遍历整棵树的时候记录下每个节点属于 \(1\) 的哪一个子树即可。

  2. 如果两点经过 \(1\),可以考虑将 \(x\)\(y\) 的路径拆分成 \(1\)\(x\)\(1\)\(y\),也就是说我们要求不在这两条路径的点的编号最小值。我们用 \(mn_{x,0/1/2}\) 表示 \(x\) 的子树内,不包括 \(x\) 的最小的三个编号,这三个编号必须来自不同的子树。接着我们将 \(mn\) 数组的含义改为除去从 \(1\)\(x\) 路径上的点的最小的三个值。这在原来的基础上还加上了 \(1\)\(x\) 路径的侧链部分,可以考虑从父亲节点 \(fa\) 转移得到(\(fa\) 节点的答案包含了 \(x\) 点需要加上的侧链)。于是要在 \(x\) 子树内和 \(fa\) 的结果总共 \(6\) 个数中取较小的 \(3\) 个。此外,注意到 \(fa\) 的答案中可能包含 \(x\)子树内的答案,求取前 \(3\) 小时要注意去重。最后的答案只可能出现在 \(x\)\(3\) 个答案和 \(y\) 的三个答案中选出不在 \(x\)\(y\) 路径上的最小值,就判断是否均不为两点的祖先节点(或等于两点)即可。

时间复杂度 \(O(n)\),带常数。

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int kmax = 1e6 + 1;
const int kmaxM = 21;

struct E {
    int p, y;
} e[kmax << 1];

int n, m;
int h[kmax], ec;
int mn[kmax][3];
int pr[kmax];
int d[kmax], f[kmax][kmaxM];
int l[kmax], r[kmax], idx;
int lres;

void Getmin(int x, int v) {
    if(v == mn[x][0] || v == mn[x][1] || v == mn[x][2] || v == x) return; // 去重
    if(v < mn[x][0]) {
        mn[x][2] = mn[x][1];
        mn[x][1] = mn[x][0];
        mn[x][0] = v;
    } else if(v < mn[x][1]) {
        mn[x][2] = mn[x][1];
        mn[x][1] = v;
    } else {
        mn[x][2] = min(mn[x][2], v);
    }
}

void Dfs(int x, int fa) {
    l[x] = ++idx;
    d[x] = d[fa] + 1, f[x][0] = fa;
    pr[x] = (fa == 0 ? -1 : (fa == 1 ? x : pr[fa])); // 记录属于1的哪一棵子树
    for(int i = 0; i < 3; i++) {
        mn[x][i] = n;
    }
    for(int i = 1; i < kmaxM; i++) {
        f[x][i] = f[f[x][i - 1]][i - 1];
    }
    for(int i = h[x]; i; i = e[i].p) {
        int y = e[i].y;
        if(y == fa) continue;
        Dfs(y, x);
        Getmin(x, min(y, mn[y][0]));
    }
    r[x] = idx;
}

void Dfss(int x, int fa) {
    for(int i = 0; i < 3; i++) {
        Getmin(x, mn[fa][i]); // 从父亲节点转移
    }
    for(int i = h[x]; i; i = e[i].p) {
        int y = e[i].y;
        if(y == fa) continue;
        Dfss(y, x);
    }
}

int Lca(int x, int y) {
    if(d[x] > d[y]) swap(x, y);
    for(int i = kmaxM - 1; ~i; i--) {
        if(d[f[y][i]] >= d[x]) {
            y = f[y][i];
        }
    }
    for(int i = kmaxM - 1; ~i; i--) {
        if(f[x][i] != f[y][i]) {
            x = f[x][i];
            y = f[y][i];
        }
    }
    return x == y ? x : f[x][0];
}

void Addedge(int x, int y) {
    e[++ec] = {h[x], y};
    h[x] = ec;
}

bool Son(int z, int x) {
    if(z > n) return 0;
    return l[z] <= l[x] && l[x] <= r[z]; // 判是否是祖先节点
}

void Solve() {
    for(int i = 1, x, y; i < n; i++) {
        cin >> x >> y;
        Addedge(x, y);
        Addedge(y, x);
    }
    Dfs(1, 0);
    Dfss(1, 1);
//  for(int i = 1; i <= n; i++) {
//      cout << "i = " << i << ":\n";
//      for(int j = 0; j < 3; j++) {
//          cout << mn[i][j] << ' ';
//      }
//      cout << '\n';
//  }
    for(int i = 1, x, y, z, l; i <= m; i++) {
        cin >> x >> y;
        x ^= lres, y ^= lres; // 强制在线
        if(pr[x] == pr[y]) {
            lres = (x == 1) + 1;
        } else {
            lres = 1e9;
            for(int j = 0; j < 3; j++) {
                z = mn[x][j];
                if(!Son(z, x) && !Son(z, y)) {
                    lres = min(lres, z);
                }
//              cout << z << '\n';
                z = mn[y][j];
                if(!Son(z, x) && !Son(z, y)) {
                    lres = min(lres, z);
                }
//              cout << z << '\n';
            }
        }
        cout << lres << '\n';
    }
}

void Init() {
    memset(h, 0, sizeof(h));
    ec = idx = lres = 0;
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    while(cin >> n >> m) {
        Init();
        Solve();
    }
    return 0;
}
posted @ 2023-08-07 16:53  ereoth  阅读(24)  评论(0)    收藏  举报