[题解]CF1990E2 Catch the Mole(Hard Version)

思路

我们先随便选择一个叶子结点,查询 \(B\) 次。如果是返回的结果是 \(1\),说明鼹鼠就在这个叶子结点;否则它将向上跳 \(B\) 次。

此时,我们得到一个关键结论,如果一棵子树最大深度小于等于 \(B\),那么鼹鼠一定不在这棵子树中,因为鼹鼠无论如何都跳了 \(B\) 次。

我们希望找到鼹鼠行动的链,找到了就可以直接二分找到其位置。

令当前到达的节点为 \(u\),查询所有 \(u\) 的子节点 \(v\),并把最大深度小于等于 \(B\) 的子树跳掉。当 \(v\) 查询出来过后结果为 \(1\),说明此时鼹鼠在 \(v\) 子树中,向下递归即可。

但是,这样 naive 的想法显然是有问题的,可以被链给创飞。

观察到,当我们即将查询 \(u\) 的最后一个满足最大深度大于 \(B\) 的子树时,前面所有的子树都不行,那么鼹鼠一定在 \(v\) 子树。

这样做的好处就是,每一次递归都会减少一个深度大于 \(B\) 的子树,那么最多只会查询 \(\frac{n}{B}\) 次。

接下来就只需要在我们找到的这条链上二分即可。

查询次数是 \(B + \frac{n}{B} + \log n\) 的,当 \(B = \sqrt{n}\) 时取得最小值 \(2\sqrt{n} + \log n\)

Code

#include <bits/stdc++.h>
#define re register

using namespace std;

const int N = 5010;
int n,B;
int d[N];
vector<int> v,g[N];

inline int read(){
    int r = 0,w = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') w = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        r = (r << 3) + (r << 1) + (c ^ 48);
        c = getchar();
    }
    return r * w;
}

inline bool ask(int u){
    printf("? %d\n",u); fflush(stdout);
    return read();
}

inline void print(int u){
    printf("! %d\n",u); fflush(stdout);
}

inline void dfs1(int u,int fa){
    d[u] = 1;
    for (int v:g[u]){
        if (v == fa) continue;
        dfs1(v,u); d[u] = max(d[u],d[v] + 1);
    }
}

inline void dfs2(int u,int fa){
    v.push_back(u);
    vector<int> s;
    for (int v:g[u]){
        if (v == fa || d[v] <= B) continue;
        s.push_back(v);
    }
    for (re int i = 1;i < s.size();i++){
        int v = s[i];
        if (ask(v)) return dfs2(v,u);
    }
    if (!s.empty()) dfs2(s.front(),u);
}

inline void solve(){
    v.clear();
    n = read(); B = sqrt(n);
    fill(d + 1,d + n + 1,0);
    for (re int i = 1;i <= n;i++) g[i].clear();
    for (re int i = 1,a,b;i < n;i++){
        a = read(),b = read();
        g[a].push_back(b); g[b].push_back(a);
    }
    for (re int i = 2;i <= n;i++){
        if (g[i].size() == 1){
            for (re int j = 1;j <= B;j++){
                if (ask(i)) return print(i);
            }
            break;
        }
    }
    dfs1(1,0); dfs2(1,0);
    int l = 0,r = v.size() - 1;
    while (l < r){
        int mid = l + r + 1 >> 1;
        if (ask(v[mid])) l = mid;
        else{
            r = max(0,mid - 2);
            if (l) l--;
        }
    }
    print(v[l]);
}

int main(){
    int T; T = read();
    while (T--) solve();
    return 0;
}
posted @ 2024-07-25 22:19  WBIKPS  阅读(23)  评论(0)    收藏  举报