树上莫队 学习笔记

树上莫队本质上是把树上的结点转化为区间信息,从而使用莫队求解。但是不能直接使用树链剖分的 \(\text{dfs}\) 序,因为树上任意一条路径所对应的区间不是连续的。此处需要用到欧拉序。欧拉序即为一个结点入队时将其加到序列里,出队时再加入一次(所以序列的总长度是结点数 \(\times 2\),每个结点恰好出现 \(2\) 次)。

如:

图1

\(1\) 为根节点,该树的欧拉序为 \((1, 3, 5, 5, 4, 4, 3, 2, 2, 1)\)

\(\text{first[u]}\)\(u\) 在欧拉序里第一次出现的位置,\(\text{last[u]}\) 为第二次出现的位置,此时如果我们要查询 \((u, v)\) 之间的路径(假设 \(\text{first[u]} \leq \text{first[v]}\)),若 \(u\) 已是 \(v\) 的祖先,则欧拉序的区间 \([\text{first[u]}, \text{first[v]}]\) 中所有出现次数等于 \(1\) 的结点即为 \((u, v)\) 路径上的点;若 \(u\) 不是 \(v\) 的祖先,则 \([\text{last[u]}, \text{first[v]}]\) 中所有出现次数等于 \(1\) 的结点再加上 \(\operatorname{lca}(u, v)\) 即为 \((u, v)\) 之间的路径。

上述预处理可以直接在树剖的过程中完成。通过欧拉序可以直接将路径转化为连续的区间。然后便可以方便地用莫队求解。

例题: SP10707

板子题,看代码即可。

#include <bits/stdc++.h>

using namespace std;

int g[80005];

struct Q {
    int l, r, id, lca;

    const bool operator<(const Q &rhs) const {
        if (g[l] != g[rhs.l])
            return g[l] < g[rhs.l];
        else
            return (g[l] & 1? r > rhs.r: r < rhs.r);
    }
};

int n, m, u, v, nowAns, cnt, l = 1, r, tot, lca, len;
int euler[80005], head[40005], nxt[80005], to[80005], w[40005], ans[100005], dep[40005], first[40005], last[40005], hson[40005], fa[40005], size[40005], cntc[40005], cntid[40005], top[40005];
unordered_map<int, int> cc;
Q q[100005];
char ch;

void read(int &num) {
    while (!isdigit(ch = getchar()));
    
    num = ch - '0';

    while (isdigit(ch = getchar()))
        num = (num << 3) + (num << 1) + ch - '0';
}

void write(int num) {
    if (num >= 10) write(num / 10);

    putchar(num % 10 + '0');
}

void dfs1(int g, int _dep) {
    dep[g] = _dep;
    size[g] = 1;

    for (int i = head[g]; i; i = nxt[i])
        if (to[i] != fa[g]) {
            fa[to[i]] = g;
            dfs1(to[i], _dep + 1);
            size[g] += size[to[i]];

            if (size[to[i]] > size[hson[g]])
                hson[g] = to[i];
        }
}

void dfs2(int g, int _top) {
    euler[++tot] = g;
    first[g] = tot;
    top[g] = _top;

    if (hson[g]) {
        dfs2(hson[g], _top);

        for (int i = head[g]; i; i = nxt[i])
            if (to[i] !=  fa[g] && to[i] != hson[g])
                dfs2(to[i], to[i]);
    }

    euler[++tot] = g;
    last[g] = tot;
}

int LCA(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]])
            swap(u, v);
        
        u = fa[top[u]];
    }

    return dep[u] < dep[v]? u: v;
}

void move(int id, bool g) {
    if (g) {
        ++cntc[w[id]];

        if (!(cntc[w[id]] ^ 1))
            ++nowAns;
    } else {
        --cntc[w[id]];

        if (!cntc[w[id]])
            --nowAns;
    }
}

int main() {
    read(n), read(m);
    len = (n << 1) / sqrt(m * 2 / 3);

    for (int i = 1; i <= n; ++i) {
        read(w[i]);
        w[i] = (cc[w[i]]? cc[w[i]]: (cc[w[i]] = ++tot));
    }

    for (int i = 1; i <= (n << 1); ++i)
        g[i] = (i - 1) / len;

    for (int i = 0; i < n - 1; ++i) {
        read(u), read(v);
        nxt[++cnt] = head[u];
        head[u] = cnt;
        to[cnt] = v;
        nxt[++cnt] = head[v];
        head[v] = cnt;
        to[cnt] = u;
    }

    tot = 0;
    dfs1(1, 1);
    dfs2(1, 1);

    for (int i = 1; i <= m; ++i) {
        read(u), read(v);
        q[i].id = i;
        lca = LCA(u, v);

        if (first[u] > first[v]) swap(u, v);

        if (lca == u) q[i].l = first[u];
        else q[i].l = last[u], q[i].lca = lca;

        q[i].r = first[v];
    }

    sort(q + 1, q + m + 1);

    for (int i = 1; i <= m; ++i) {
        while (l > q[i].l) {
            ++cntid[euler[--l]];
            move(euler[l], 2 - cntid[euler[l]]);
        }

        while (r < q[i].r) {
            ++cntid[euler[++r]];
            move(euler[r], 2 - cntid[euler[r]]);
        }

        while (r > q[i].r) {
            --cntid[euler[r]];
            move(euler[r], cntid[euler[r]]);
            --r;
        }

        while (l < q[i].l) {
            --cntid[euler[l]];
            move(euler[l], cntid[euler[l]]);
            ++l;
        }

        if (q[i].lca) move(q[i].lca, 1);

        ans[q[i].id] = nowAns;

        if (q[i].lca) move(q[i].lca, 0);
    }

    for (int i = 1; i <= m; ++i)
        write(ans[i]), putchar('\n');

    return 0;
}
posted @ 2022-10-16 23:47  wf715  阅读(74)  评论(0)    收藏  举报