HDU 5923 Prediction
这题是2016 CCPC 东北四省赛的B题, 其实很简单. 现场想到的就是正解, 只是在合并两个并查集这个问题上没想清楚.
做法
并查集合并 + 归并
- 对每个节点 \(u\), 将 \(u\) 到根的那些边添到一个初始为空的并查集中, 得到的并查集记作 \(a_u\).
 - 询问相当于将 \(k\) 个并查集合并. 
采用二路归并, 合并次数是 \(O(n \cdot \log(n))\).
$ n/2 + n/4 + n/8 + \dots + 1 = O(n \cdot \log(n)) $ 
合并两个并查集
详细讨论将并查集 \(B\) 合并到并查集 \(A\) 中这一问题.
这个问题与
给定两无向图 $A, B, V_B \subset V_A; \quad A(E_A, V_A) \to A'( E_A, E_A \cup E_B) $.
等价.
做法
$ \forall u \in E_B, \quad A.\mathrm{unite}(u, B.\mathrm{root}(u)) $
正确性
只要验证
在\(B\)中连通的任意两点 \(u, v\), 在$ A'$中也连通.
是否满足.
Implementation
#include <bits/stdc++.h>
using namespace std;
const int N{1<<9};
const int M=1e4+5;
int n, m;
struct DSU{
    int par[N];
    int cnt;
    int find(int x){
        return par[x]==x?x: par[x]=find(par[x]);
    }
    void unite(int x, int y){
        x=find(x);
        y=find(y);
        if(x!=y){
            par[x]=y;
            --cnt;
        }
    }
    void unite(DSU &a){
        for(int i=1; i<=n; i++){
            unite(find(i), a.find(i));  // ?
        }
    }
    void init(){
        for(int i=1; i<=n; i++){
            par[i]=i;
        }
        cnt=n;
    }
    void copy(const DSU &a){
        for(int i=1; i<=n; i++){
            par[i]=a.par[i];
        }
        cnt=a.cnt;
    }
};
DSU a[M], b[M];
vector<int> g[M];
struct Edge{
    int u, v;
    void read(){
        scanf("%d%d", &u, &v);
    }
}E[M];
void dfs(int u, int f){
    a[u].copy(a[f]);
    a[u].unite(E[u].u, E[u].v);
    for(auto v: g[u]){
        dfs(v, u);
    }
}
void solve(int n){
    for(int i=1; i<n; i<<=1){   // error-prone
        for(int j=0; j+i<n; j+=i<<1){
            b[j].unite(b[j+i]);
        }
    }
    printf("%d\n", b[0].cnt);
}
// int par[M];
int main(){
    int T, cas{};
    for(cin>>T; T--; ){
        printf("Case #%d:\n", ++cas);
        // int n, m;
        cin>>n>>m;
        for(int i=1; i<=m; ++i){
            g[i].clear();
        }
        for(int i=2; i<=m; i++){
            // scanf("%d", par+i);
            int fa;
            scanf("%d", &fa);
            g[fa].push_back(i);
        }
        for(int i=1; i<=m; ++i){
            E[i].read();
        }
        a[0].init();
        dfs(1, 0);
        int q;
        cin>>q;
        for(; q--; ){
            int k;
            scanf("%d", &k);
            for(int i=0; i<k; i++){
                int x;
                scanf("%d", &x);
                b[i].copy(a[x]);
            }
            solve(k);
        }
    }
    return 0;
}
Pitfalls
归并
for(int i=1; i<n; i<<=1){   // error-prone
    for(int j=0; j+i<n; j+=i<<1){
        b[j].unite(b[j+i]);
    }
}
容易写错.
我第一发是这样写的
for(int i=2; i<=n; i<<=1){
    for(int j=0; j+i/2<n; j+=i){
        b[j].unite(b[j+i/2]);
    }
}
当n==3时, 只做了1轮归并.
应采纳第一种写法, 很清楚.
UPD
太SB了.
- 根本不用归并, 直接逐个合并就好了.
 - 根本不用 
b[i].copy(a[x]);, 只要从一个边集为空的图 (以下简称"空图") 开始, 不断把\(k\)个并查集合并进去就好了. - 不从空图开始, 而从某个并查集开始, 会快很多.
 
#include <bits/stdc++.h>
using namespace std;
const int N{1<<9};
const int M=1e4+5;
int n, m;
struct DSU{
    int par[N];
    int cnt;
    int find(int x){
        return par[x]==x?x: par[x]=find(par[x]);
    }
    void unite(int x, int y){
        x=find(x);
        y=find(y);
        if(x!=y){
            par[x]=y;
            --cnt;
        }
    }
    void unite(DSU &a){
        for(int i=1; i<=n; i++){
            unite(find(i), a.find(i));  // ?
        }
    }
    void init(){
        for(int i=1; i<=n; i++){
            par[i]=i;
        }
        cnt=n;
    }
    void copy(const DSU &a){
        for(int i=1; i<=n; i++){
            par[i]=a.par[i];
        }
        cnt=a.cnt;
    }
};
DSU a[M], b[M];
vector<int> g[M];
struct Edge{
    int u, v;
    void read(){
        scanf("%d%d", &u, &v);
    }
}E[M];
void dfs(int u, int f){
    a[u].copy(a[f]);
    a[u].unite(E[u].u, E[u].v);
    for(auto v: g[u]){
        dfs(v, u);
    }
}
int solve(int n){
    if(k==0){
        return n;
    }
    int x;
    scanf("%d", &x);
    a[0].copy(a[x]);
    for(int i=1; i<n; i++){
        scanf("%d", &x);
        a[0].unite(a[x]);
    }
    return a[0].cnt;
}
int main(){
    int T, cas{};
    for(cin>>T; T--; ){
        printf("Case #%d:\n", ++cas);
        cin>>n>>m;
        for(int i=1; i<=m; ++i){
            g[i].clear();
        }
        for(int i=2; i<=m; i++){
            // scanf("%d", par+i);
            int fa;
            scanf("%d", &fa);
            g[fa].push_back(i);
        }
        for(int i=1; i<=m; ++i){
            E[i].read();
        }
        a[0].init();
        dfs(1, 0);
        int q;
        cin>>q;
        for(; q--; ){
            int k;
            scanf("%d", &k);        
            printf("%d\n", solve(k));
        }
    }
    return 0;
}
                    
                
                
            
        
浙公网安备 33010602011771号