[bzoj2243][SDOI2011]染色——树链剖分+线段树

Brief Description

给定一棵树,每个点有颜色,您需要支持两种操作:

  1. 把某两个点路径上的所有点的颜色设为某种颜色。
  2. 查询两个点路径上的颜色块的个数。
    例如,颜色块:22112233视为4个颜色块。

Algorithm Design

树链剖分裸题。
我们需要在线段树上记录区间最左边的颜色与区间最右边的颜色以方便合并:合并左右两个子区间时,若左子区间最右端的颜色与右子区间最左端的颜色相同,则这两段连成同一个颜色块,总块数要减一。

Notice

注意要多update:下传标记(pushdown)以及区间修改(change)递归返回之后,都要调用update重新合并节点信息,我因为漏掉update WA了3次。

Code

#include <algorithm>
#include <cstdio>
#include <cstdlib>
// Capacity of every per-node array; nodes are numbered from 1.
const int maxn = 100010;

// --- adjacency list -------------------------------------------------------
// head[u]: index in e[] of the first edge leaving u, 0 when u has no edges.
int head[maxn];
// cnt: number of directed edges stored in e[] so far (edge ids start at 1).
int cnt;

// --- per-node data ----------------------------------------------------------
// value[u]: initial colour of node u as read from the input.
int value[maxn];
// deep[u]: depth of u; the root (node 1) keeps depth 0.
int deep[maxn];
// vis[u]: set to 1 once dfs1 has entered u.
int vis[maxn];
// size[u]: number of nodes in the subtree rooted at u.
int size[maxn];
// fa[u][i]: the 2^i-th ancestor of u, 0 when there is none.
int fa[maxn][20];

// --- heavy-light decomposition ----------------------------------------------
// pl[u]: position of u in the linearised order used by the segment tree.
int pl[maxn];
// que[p]: the node placed at position p (inverse of pl).
int que[maxn];
// belong[u]: topmost node of the chain that contains u.
int belong[maxn];
// sz: number of positions handed out by dfs2 so far.
int sz;

// --- input sizes --------------------------------------------------------------
// n: number of nodes, m: number of operations.
int n, m;
// One segment-tree node. Nodes are stored in t[]; the children of node k are
// nodes 2k and 2k+1.
struct seg {
  int l;    // first position covered by this node
  int r;    // last position covered by this node
  int lc;   // colour of the leftmost position in [l, r]
  int rc;   // colour of the rightmost position in [l, r]
  int s;    // number of colour blocks inside [l, r]
  int tag;  // pending "paint the whole range" colour; -1 means no pending paint
} t[maxn << 2];
// One directed edge of the adjacency list. Edges are stored in e[] starting at
// index 1, so index 0 doubles as the end-of-list marker.
struct edge {
  int to;    // node this edge points at
  int next;  // index of the next edge leaving the same node, 0 at the end
} e[maxn << 1];
// Adds the undirected edge {u, v} by pushing two directed edges, u->v first
// and then v->u, onto the front of the respective adjacency lists.
void insert(int u, int v) {
  const int tail[2] = {u, v};
  const int tip[2] = {v, u};
  for (int dir = 0; dir < 2; ++dir) {
    const int id = ++cnt;
    e[id].to = tip[dir];
    e[id].next = head[tail[dir]];
    head[tail[dir]] = id;
  }
}
// First traversal: marks x as visited, completes x's binary-lifting ancestor
// row, and computes deep[], fa[][0] and size[] for the subtree below x.
// Expects deep[x] and fa[x][0] to have been set by the caller (the root keeps
// the zero-initialised values).
void dfs1(int x) {
  vis[x] = 1;
  size[x] = 1;

  // fa[x][j] exists only while x is at least 2^j levels deep.
  for (int j = 1; j <= 17 && deep[x] >= (1 << j); ++j)
    fa[x][j] = fa[fa[x][j - 1]][j - 1];

  for (int id = head[x]; id != 0; id = e[id].next) {
    const int child = e[id].to;
    if (vis[child])
      continue;  // already visited: this is the edge back to the parent
    deep[child] = deep[x] + 1;
    fa[child][0] = x;
    dfs1(child);
    size[x] += size[child];
  }
}
// Second traversal: assigns x the next free position, records that x belongs
// to the chain whose top node is `chain`, then continues the same chain
// through the child with the largest subtree and starts a new chain at every
// other child.
void dfs2(int x, int chain) {
  ++sz;
  pl[x] = sz;
  que[sz] = x;
  belong[x] = chain;

  // Pick the heavy child; size[0] is 0, so any real child beats "none".
  int heavy = 0;
  for (int id = head[x]; id != 0; id = e[id].next) {
    const int nxt = e[id].to;
    if (deep[nxt] > deep[x] && size[nxt] > size[heavy])
      heavy = nxt;
  }
  if (heavy == 0)
    return;  // x has no children

  // The heavy child goes first so each chain occupies consecutive positions.
  dfs2(heavy, chain);

  for (int id = head[x]; id != 0; id = e[id].next) {
    const int nxt = e[id].to;
    if (nxt == heavy || deep[nxt] <= deep[x])
      continue;
    dfs2(nxt, nxt);
  }
}
void update(int k) {
  int tmp = !(t[k << 1].rc ^ t[k << 1 | 1].lc);
  t[k].s = t[k << 1].s + t[k << 1 | 1].s - tmp;
  t[k].lc = t[k << 1].lc;
  t[k].rc = t[k << 1 | 1].rc;
}
void build(int k, int l, int r) {
  t[k].l = l, t[k].r = r, t[k].s = 1, t[k].tag = -1;
  if (l == r) {
    t[k].lc = t[k].rc = value[que[l]];
    return;
  }
  int mid = (l + r) >> 1;
  build(k << 1, l, mid);
  build(k << 1 | 1, mid + 1, r);
  update(k);
}
// Returns the lowest common ancestor of x and y using the binary-lifting
// table fa[][] and the depths in deep[].
int lca(int x, int y) {
  if (deep[x] < deep[y])
    std::swap(x, y);

  // Lift the deeper node until both are on the same level.
  const int diff = deep[x] - deep[y];
  for (int bit = 0; bit <= 17; ++bit)
    if ((diff >> bit) & 1)
      x = fa[x][bit];

  // Jump both nodes upwards as long as their ancestors still differ.
  for (int bit = 17; bit >= 0; --bit) {
    if (fa[x][bit] == fa[y][bit])
      continue;
    x = fa[x][bit];
    y = fa[y][bit];
  }

  return (x == y) ? x : fa[x][0];
}

// Clears node k's pending paint and, when there was one and k is not a leaf,
// applies it to both children. Because -1 is the "nothing pending" marker, a
// paint with colour -1 would be dropped here.
void pushdown(int k) {
  const int colour = t[k].tag;
  t[k].tag = -1;
  if (colour == -1)
    return;
  if (t[k].l == t[k].r)
    return;  // a leaf has no children to push to

  for (int child = k << 1; child <= (k << 1 | 1); ++child) {
    t[child].s = 1;
    t[child].tag = colour;
    t[child].lc = colour;
    t[child].rc = colour;
  }
  update(k);
}
// Returns the number of colour blocks among positions [x, y], restricted to
// the range covered by node k.
int query(int k, int x, int y) {
  pushdown(k);
  if (x <= t[k].l && t[k].r <= y)
    return t[k].s;

  const int mid = (t[k].l + t[k].r) >> 1;
  const bool goLeft = (x <= mid);
  const bool goRight = (y > mid);

  // If the query spans the split and the colours on both sides of it match,
  // the two partial answers count the same block twice.
  const bool merged =
      goLeft && goRight && (t[k << 1].rc == t[k << 1 | 1].lc);

  int total = 0;
  if (goLeft)
    total += query(k << 1, x, y);
  if (goRight)
    total += query(k << 1 | 1, x, y);
  return merged ? total - 1 : total;
}
// Returns the current colour stored at position pos, walking down from node k
// and pushing pending paints along the way.
int getcolor(int k, int pos) {
  for (;;) {
    pushdown(k);
    if (t[k].l == t[k].r)
      return t[k].lc;
    const int mid = (t[k].l + t[k].r) >> 1;
    k = (pos <= mid) ? (k << 1) : (k << 1 | 1);
  }
}
// Counts the colour blocks on the path from x up to f, both ends included.
// The callers in this file pass the LCA of the queried pair as f.
int solvesum(int x, int f) {
  int blocks = 0;
  while (belong[x] != belong[f]) {
    const int top = belong[x];
    const int above = fa[top][0];

    // The part of the path inside x's chain: positions pl[top] .. pl[x].
    blocks += query(1, pl[top], pl[x]);

    // If the chain top has the same colour as its parent, the block continues
    // into the next chain and would otherwise be counted twice.
    const int topColour = getcolor(1, pl[top]);
    const int aboveColour = getcolor(1, pl[above]);
    if (topColour == aboveColour)
      --blocks;

    x = above;
  }
  // x and f are now on the same chain.
  return blocks + query(1, pl[f], pl[x]);
}
void change(int k, int x, int y, int c) {
  pushdown(k);
  int l = t[k].l, r = t[k].r, mid = (l + r) >> 1;
  if (x <= l && r <= y) {
    t[k].lc = t[k].rc = c;
    t[k].s = 1;
    t[k].tag = c;
    return;
  }
  if (x <= mid)
    change(k << 1, x, y, c);
  if (y > mid)
    change(k << 1 | 1, x, y, c);
  update(k);
}
// Paints every node on the path from x up to f (both ends included) with
// colour c, one chain at a time. The callers in this file pass the LCA of the
// updated pair as f.
void solvechange(int x, int f, int c) {
  for (;;) {
    const int top = belong[x];
    if (top == belong[f])
      break;
    change(1, pl[top], pl[x], c);
    x = fa[top][0];
  }
  // Same chain now: paint the remaining stretch between f and x.
  change(1, pl[f], pl[x], c);
}
// Preprocesses the tree rooted at node 1 (depths, chains, segment tree) and
// then processes the m operations read from stdin:
//   an operation word starting with 'Q', then "a b":
//       prints the number of colour blocks on the path a..b;
//   any other operation word, then "a b c":
//       paints the path a..b with colour c.
// Processing stops at the first operation that cannot be read or that names a
// node outside [1, n], so no uninitialised or out-of-range value is ever used
// as an array index.
void solve() {
  // With no nodes, build(1, 1, n) would never reach a leaf.
  if (n < 1)
    return;

  dfs1(1);
  dfs2(1, 1);
  build(1, 1, n);

  for (int i = 1; i <= m; i++) {
    char ch[10];
    // The width keeps an over-long token from overflowing ch.
    if (scanf("%9s", ch) != 1)
      break;

    if (ch[0] == 'Q') {
      int a, b;
      if (scanf("%d %d", &a, &b) != 2)
        break;
      if (a < 1 || a > n || b < 1 || b > n)
        break;
      const int top = lca(a, b);
      // Both halves of the path include the LCA, so its block is counted
      // twice.
      printf("%d\n", solvesum(a, top) + solvesum(b, top) - 1);
    } else {
      int a, b, c;
      if (scanf("%d %d %d", &a, &b, &c) != 3)
        break;
      if (a < 1 || a > n || b < 1 || b > n)
        break;
      const int top = lca(a, b);
      solvechange(a, top, c);
      solvechange(b, top, c);
    }
  }
}
void init() {
  scanf("%d %d", &n, &m);
  for (int i = 1; i <= n; i++)
    scanf("%d", &value[i]);
  for (int i = 1; i < n; i++) {
    int x, y;
    scanf("%d %d", &x, &y);
    insert(x, y);
  }
}
// Entry point. Outside the online judge the input is taken from a file named
// "input" in the working directory; on the judge it is read from stdin.
int main() {
#ifndef ONLINE_JUDGE
  // A failed freopen leaves stdin closed, so stop here instead of reading
  // from it.
  if (freopen("input", "r", stdin) == NULL) {
    perror("input");
    return 1;
  }
#endif
  init();
  solve();
  return 0;
}

posted on 2017-03-07 17:34  蒟蒻konjac  阅读(193)  评论(0)  编辑  收藏  举报