BZOJ2243. [SDOI2011]染色

 

[传送门]

 

树链剖分。线段树的每个结点维护区间最左端和最右端的颜色，以及区间内的颜色段数。合并两个子区间时：颜色段数 = 左区间颜色段数 + 右区间颜色段数 − [左区间右端点颜色 = 右区间左端点颜色]（方括号为艾佛森括号，条件成立取 1，否则取 0）。在树上查询路径时，相邻两条重链的连接处（链顶与其父结点）也需要做同样的判断。

#include <bits/stdc++.h>

namespace IO
{
    // Base case of the variadic reader: nothing left to read.
    inline void read() {}

    // Reads one signed integer from stdin into each argument, left to right.
    // Non-digit characters before a number are skipped. A '-' seen while skipping
    // makes that number negative (same rule as before).
    // If stdin is exhausted before any digit appears, the target becomes 0 instead of
    // the reader spinning forever on EOF.
    template <typename T, typename... T2>
    inline void read(T &x, T2 &... oth) {
        T f = 1;
        x = 0;
        // int, not char: keeps EOF distinguishable from data and keeps the value
        // passed to isdigit within the range it is defined for.
        int ch = std::getchar();
        while (ch != EOF && !std::isdigit(ch)) {
            if (ch == '-') f = -1;
            ch = std::getchar();
        }
        while (ch != EOF && std::isdigit(ch)) {
            x = x * 10 + (ch - '0');
            ch = std::getchar();
        }
        x *= f;
        read(oth...);
    }
} // namespace IO
// Every later `read(...)` token in this file expands to IO::read.
#define read IO::read

// Capacity for per-vertex arrays; vertices are read as 1..n, so presumably n < N.
const int N = 1e5 + 7;

int n;           // number of vertices
int m;           // number of operations
int col[N];      // input colour of each vertex

// Filled by dfs1.
int fa[N];       // parent vertex
int dep[N];      // depth (dep[0] stays 0, so the root gets depth 1)
int sz[N];       // subtree size
int son[N];      // heavy child, 0 when there is none

// Filled by dfs2.
int top[N];      // head of the heavy chain containing the vertex
int dfn[N];      // position in heavy-child-first DFS order (1-based)
int tol;         // running counter used to hand out dfn values
int wt[N];       // wt[dfn[u]] = col[u], the array the segment tree is built on

std::vector<int> vec[N];  // undirected adjacency lists

// Segment tree over the dfn-ordered colour array wt[1..n].
// Per node p:
//   sum[p]   - number of maximal runs of equal colour inside the node's range
//   lc[p]    - colour at the leftmost position of the range
//   rc[p]    - colour at the rightmost position of the range
//   lazy[p]  - pending "paint the whole range" colour, -1 when none
// Colours are assumed non-negative, because -1 is reserved as "no tag".
struct Seg {
    static const int NN = N * 4;
    int lazy[NN], sum[NN], lc[NN], rc[NN];

    // Child indices in the implicit heap layout. These replace the former
    // unparenthesised file-scope macros `lp` / `rp`.
    static inline int ls(int p) { return p << 1; }
    static inline int rs(int p) { return (p << 1) | 1; }

    // Rebuild node p from its children. Runs add up, but if the colours meeting
    // at the split point are equal, two runs join into one.
    inline void pushup(int p) {
        sum[p] = sum[ls(p)] + sum[rs(p)];
        lc[p] = lc[ls(p)];
        rc[p] = rc[rs(p)];
        if (rc[ls(p)] == lc[rs(p)]) sum[p]--;
    }
    // Paint node p's whole range with `color`: a single run, both ends = color.
    inline void tag(int p, int color) {
        lc[p] = rc[p] = color;
        sum[p] = 1;
        lazy[p] = color;
    }
    // Hand a pending paint down to both children.
    inline void pushdown(int p) {
        if (lazy[p] >= 0) {
            tag(ls(p), lazy[p]);
            tag(rs(p), lazy[p]);
            lazy[p] = -1;
        }
    }
    // Build node p over [l, r] from wt[].
    void build(int p, int l, int r) {
        lazy[p] = -1;
        if (l == r) {
            sum[p] = 1;
            lc[p] = rc[p] = wt[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(ls(p), l, mid);
        build(rs(p), mid + 1, r);
        pushup(p);
    }
    // Paint positions [x, y] with colour c. Node p covers [l, r].
    void update(int p, int l, int r, int x, int y, int c) {
        if (x <= l && y >= r) {
            tag(p, c);
            return;
        }
        pushdown(p);
        int mid = (l + r) >> 1;
        if (x <= mid) update(ls(p), l, mid, x, y, c);
        if (y > mid) update(rs(p), mid + 1, r, x, y, c);
        pushup(p);
    }
    // Number of colour runs inside [x, y]. Callers keep [x, y] overlapping [l, r].
    int query(int p, int l, int r, int x, int y) {
        if (x <= l && y >= r) return sum[p];
        pushdown(p);
        int mid = (l + r) >> 1;
        if (x > mid) return query(rs(p), mid + 1, r, x, y);
        if (y <= mid) return query(ls(p), l, mid, x, y);
        int ans = query(ls(p), l, mid, x, y) + query(rs(p), mid + 1, r, x, y);
        // Both partial ranges reach the split (mid / mid + 1); equal colours there
        // were counted as two runs but form one.
        if (rc[ls(p)] == lc[rs(p)]) ans--;
        return ans;
    }
    // Colour at single position pos.
    int query(int p, int l, int r, int pos) {
        if (l == r) return lc[p];
        // A pending tag means every position under p has that colour.
        if (lazy[p] >= 0) return lazy[p];
        int mid = (l + r) >> 1;
        pushdown(p);
        if (pos <= mid) return query(ls(p), l, mid, pos);
        return query(rs(p), mid + 1, r, pos);
    }
    // Debug helper: print all leaf colours left to right.
    void print(int p, int l, int r) {
        if (l == r) return (void)(printf("%d ", lc[p]));
        pushdown(p);
        int mid = (l + r) >> 1;
        print(ls(p), l, mid);
        print(rs(p), mid + 1, r);
    }
} seg;

// First pass of heavy-light decomposition over the subtree rooted at u, whose parent is pre.
// Fills:
//   fa[]  - parent
//   dep[] - depth, with dep[u] = dep[pre] + 1
//   sz[]  - subtree size
//   son[] - heavy child: the first child in adjacency order with a strictly larger
//           subtree than any earlier child; stays 0 for leaves
// Written iteratively so a path-shaped tree with ~1e5 vertices cannot overflow the call
// stack. The results are identical to the former recursive version.
void dfs1(int u, int pre) {
    fa[u] = pre;
    dep[u] = dep[pre] + 1;
    // Preorder with an explicit stack: each vertex lands in `order` after its parent.
    std::vector<int> order;
    std::vector<int> stk(1, u);
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        sz[x] = 1;
        for (int v : vec[x]) {
            if (v == fa[x]) continue;
            fa[v] = x;
            dep[v] = dep[x] + 1;
            stk.push_back(v);
        }
    }
    // Walking the preorder backwards finishes children before parents, so each size is
    // final when added. Index 0 is u itself; its parent `pre` is left untouched.
    for (std::size_t i = order.size(); i-- > 1; )
        sz[fa[order[i]]] += sz[order[i]];
    // Pick heavy children with the same strict-'>' scan in adjacency order as before.
    for (int x : order)
        for (int v : vec[x])
            if (v != fa[x] && sz[v] > sz[son[x]]) son[x] = v;
}

// Second pass of heavy-light decomposition, starting at u whose chain head is tp.
// Fills:
//   dfn[] - numbers from the running counter `tol`, in heavy-child-first preorder,
//           so every heavy chain occupies a contiguous dfn range
//   top[] - chain head: a heavy child inherits its parent's head, a light child
//           starts its own chain
//   wt[]  - wt[dfn[x]] = col[x], the array the segment tree is built on
// Written iteratively to avoid depth-n recursion. The visiting order is the same as
// the former recursive version.
void dfs2(int u, int tp) {
    std::vector<std::pair<int, int> > stk(1, std::make_pair(u, tp));
    while (!stk.empty()) {
        int x = stk.back().first;
        int head = stk.back().second;
        stk.pop_back();
        top[x] = head;
        dfn[x] = ++tol;
        wt[tol] = col[x];
        if (!son[x]) continue;
        // Light children are pushed in reverse adjacency order, so they pop in adjacency order.
        // The heavy child goes on last, so its entire subtree is numbered first.
        for (std::size_t i = vec[x].size(); i-- > 0; ) {
            int v = vec[x][i];
            if (v != fa[x] && v != son[x]) stk.push_back(std::make_pair(v, v));
        }
        stk.push_back(std::make_pair(son[x], head));
    }
}

// Paint every vertex on the tree path u..v with colour c.
// Repeatedly take the endpoint whose chain head is deeper and paint the dfn range
// from that head to the endpoint. Then jump above the head, until both endpoints
// share a chain.
void solve(int u, int v, int c) {
    for (;;) {
        int headU = top[u];
        int headV = top[v];
        if (headU == headV) break;
        if (dep[headV] > dep[headU]) {
            std::swap(u, v);
            std::swap(headU, headV);
        }
        seg.update(1, 1, n, dfn[headU], dfn[u], c);
        u = fa[headU];
    }
    // Same chain now: the shallower endpoint has the smaller dfn.
    int upper = dep[u] > dep[v] ? v : u;
    int lower = dep[u] > dep[v] ? u : v;
    seg.update(1, 1, n, dfn[upper], dfn[lower], c);
}

// Number of maximal same-colour segments on the tree path u..v.
// Sums the segment counts of each heavy-chain piece. Wherever a chain head and its
// parent (adjacent on the path) share a colour, one is subtracted because those two
// pieces form a single segment.
int solve(int u, int v) {
    int segments = 0;
    for (;;) {
        int headU = top[u];
        int headV = top[v];
        if (headU == headV) break;
        if (dep[headV] > dep[headU]) {
            std::swap(u, v);
            std::swap(headU, headV);
        }
        segments += seg.query(1, 1, n, dfn[headU], dfn[u]);
        int headColor = seg.query(1, 1, n, dfn[headU]);
        int aboveColor = seg.query(1, 1, n, dfn[fa[headU]]);
        if (headColor == aboveColor) segments--;
        u = fa[headU];
    }
    // Same chain now: query from the shallower endpoint down to the deeper one.
    int upper = dep[u] > dep[v] ? v : u;
    int lower = dep[u] > dep[v] ? u : v;
    return segments + seg.query(1, 1, n, dfn[upper], dfn[lower]);
}

// Reads n, m, the vertex colours and n - 1 edges, then builds the decomposition
// rooted at vertex 1. It then processes m operations:
//   "C u v c" - paint the path u..v with colour c
//   "Q u v"   - print the number of colour segments on the path u..v
// Only the first letter of the operation word is inspected.
int main() {
    read(n, m);
    for (int i = 1; i <= n; i++)
        read(col[i]);
    for (int i = 1, u, v; i < n; i++) {
        read(u, v);
        vec[u].push_back(v);
        vec[v].push_back(u);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    seg.build(1, 1, n);
    for (int u, v, c; m--; ) {
        char op[10];
        // The width limit keeps an over-long token from overflowing op. Stopping on a
        // failed read avoids inspecting an uninitialised buffer when input is truncated.
        if (scanf("%9s", op) != 1) break;
        read(u, v);
        if (op[0] == 'C') {
            read(c);
            solve(u, v, c);
        } else {
            printf("%d\n", solve(u, v));
        }
    }
    return 0;
}
View Code

 

posted @ 2019-10-18 23:40  Mrzdtz220  阅读(101)  评论(0编辑  收藏  举报