【数据结构】树上启发式合并

https://codeforces.com/contest/1923/problem/E

这一道题有三种不一样的写法,有一种是用启发式合并(跟树上没啥关系)去优化dp的做法。

int n, k;
int c[200005];
vector<int> G[200005];

ll ans;
map<int, int> cnt[200005];

void dfs (int u, int p) {
    cnt[u][c[u]] = 1;  // cnt[u][x]表示子树u内部,颜色为x的有效节点的数量

    for (int v : G[u]) {
        if (v == p) {
            continue;
        }
        dfs (v, u);

        if (cnt[u].size() < cnt[v].size()) {
            // 启发式合并优化dp,保证一定是小的数组合并到大的里
            swap (cnt[u], cnt[v]);
        }
        for (auto [x, y] : cnt[v]) {
            ans += 1LL * cnt[u][x] * y;  // 相当于把同为x的答案,从v子树合并到u的已合并部分中
            cnt[u][x] += y;
        }
        cnt[u][c[u]] = 1;    // 这道题比较特殊,退出节点v的时候回到了节点u,要把节点u重新堵上
        cnt[v].clear();
    }

}


void solve() {
    RD (n);
    RDN (c, n);
    for (int i = 1; i <= n; ++i) {
        G[i].clear();
        cnt[i].clear();
    }
    for (int i = 1; i <= n - 1; ++i) {
        int x, y;
        RD (x, y);
        G[x].push_back (y);
        G[y].push_back (x);
    }
    ans = 0;
    dfs (1, 1);
    WT (ans);
}

https://codeforces.com/contest/375/problem/D

这道题也有用树上莫队的解法,这里再给一个树上启发式合并的解法。

树上启发式合并,类似轻重链剖分,先算出每个节点的重儿子,然后计算答案时先递归计算轻儿子的答案,标记clr为true,然后计算重儿子的答案,clr为false,然后把轻儿子的节点暴力插到重儿子里,最后把父节点自己加入。

int n;
vector<int> G[MAXN];
int siz[MAXN], mch[MAXN];

void dfs1(int u, int p) {
    // 初始化siz,找到重儿子
    siz[u] = 1, mch[u] = 0;
    for (int &v : G[u]) {
        if (v == p)
            continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[mch[u]] < siz[v])
            mch[u] = v;
    }
}

由于树上启发式合并并不关心深度,所以没有必要维护深度。

void calc(int u, int p, int skip, int d) {
    bit.Add(cnt[c[u]], -1);
    cnt[c[u]] += d;
    bit.Add(cnt[c[u]], 1);
    for (int v : G[u]) {
        if (v == p || v == skip)
            continue;
        calc(v, u, 0, d);
    }
}

void dfs2(int u, int p, bool keep) {
    // 启发式合并dfs(这个例子不好扩展)
    // 先遍历所有的轻儿子,计算轻儿子“内部”的答案
    for (int &v : G[u]) {
        if (v == p || v == mch[u])
            continue;
        dfs2(v, u, false);
    }
    // 如果存在重儿子,则计算重儿子“内部的答案”,将重儿子的答案数组继承给u
    if (mch[u])
        dfs2(mch[u], u, true);
    // 把节点u和其他所有的轻儿子,合并到重儿子的答案中,并计算贡献(合并的写法)
    calc(u, p, mch[u], 1);
    for (pii &q : Q[u]) {
        int id = q.first, k = q.second;
        ans[q.first] = bit.Sum(k, n);
    }
    // 如果不是重儿子,则消除当前节点的所有影响
    if (!keep)
        calc(u, p, 0, -1);
}

然后是主要的计算过程dfs2,dfs2优先进入所有的轻儿子,并且不keep轻儿子的答案,保持树状数组为空。然后进入重儿子计算并keep重儿子的结果。这里使用一个辅助函数calc,calc的修改值为1时表示向树状数组中添加,然后命令其在添加时skip掉重儿子。计算完毕后树状数组中存着这棵子树对应的状态,然后取出所有的询问进行回答。那之后,假如不keep树状数组,调用calc修改值为-1,并且不跳过重儿子,把整棵子树删除干净。

时间复杂度为 \(O(nlog^2n)\)

https://codeforces.com/gym/102832/problem/F

这里的查询要去重,所以要先计算再查询。而且要注意cache的命中。一次树遍历就统计出所有的信息,把常用的局部值放在数组的低维。

int n, k;
int a[MAXN];

vector<int> G[MAXN];
int siz[MAXN], mch[MAXN];

void dfs1(int u, int p) {
    // 维护siz,计算重儿子
    siz[u] = 1, mch[u] = 0;
    for (int &v : G[u]) {
        if (v == p)
            continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[mch[u]] < siz[v])
            mch[u] = v;
    }
}

int cnt[1 << 20][17][2];
ll ans;

void calc1(int u, int p, int LCA) {
    // 单纯计算u子树内的节点合并入cnt中时产生的贡献,并不真正把u节点中的子树合并入cnt中(因为会和“轻儿子内部的答案”重复计算)
    int val = a[u] ^ a[LCA];
    for (int k = 16; k >= 0; --k) {
        int uk = (u >> k) & 1;
        ans += (1LL << k) * cnt[val][k][uk ^ 1];
    }
    for (int &v : G[u]) {
        if (v == p)
            continue;
        calc1(v, u, LCA);
    }
}

void calc2(int u, int p) {
    // 只是把u节点的子树合并入cnt中,并不计算他们的贡献
    int val = a[u];
    for (int k = 16; k >= 0; --k) {
        int uk = (u >> k) & 1;
        ++cnt[val][k][uk];
    }
    for (int &v : G[u]) {
        if (v == p)
            continue;
        calc2(v, u);
    }
}

void calc3(int u, int p) {
    // 消除u子树的影响,注意这里的memset不是对整个数组进行的
    int val = a[u];
    memset(cnt[val], 0, sizeof(cnt[val]));
    for (int &v : G[u]) {
        if (v == p)
            continue;
        calc3(v, u);
    }
}

void dfs2(int u, int p, bool keep) {
    // 计算轻儿子内部的贡献
    for (int &v : G[u]) {
        if (v == p || v == mch[u])
            continue;
        dfs2(v, u, false);
    }
    // 计算重儿子内部的贡献
    if (mch[u])
        dfs2(mch[u], u, true);
    // 将u节点和重儿子合并,称为“已合并部分”
    int val = a[u];
    for (int k = 16; k >= 0; --k) {
        int uk = (u >> k) & 1;
        ans += (1LL << k) * cnt[0][k][uk ^ 1];
        ++cnt[val][k][uk];
    }
    // 将轻儿子逐个添加到已合并部分
    for (int &v : G[u]) {
        if (v == p || v == mch[u])
            continue;
        // 只计算轻儿子和已合并部分中间的贡献,不要把轻儿子加入(避免重复计算轻儿子内部的贡献)
        calc1(v, u, u);
        // 只把轻儿子加入,不计算他们内部和他们和已合并部分之间的贡献
        calc2(v, u);
    }
    if (!keep) {
        // 如果是轻儿子,就清空他的贡献
        memset(cnt[val], 0, sizeof(cnt[val]));
        for (int &v : G[u]) {
            if (v == p)
                continue;
            calc3(v, u);
        }
    }
}

void solve() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i)
        scanf("%d", &a[i]);
    for (int i = 1; i <= n; ++i)
        G[i].clear();
    for (int i = 1; i <= n - 1; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs1(1, 0);
    dfs2(1, 0, true);
    printf("%lld\n", ans);
}
posted @ 2021-03-16 18:33  purinliang  阅读(84)  评论(0编辑  收藏  举报