SPOJ COT2 - Count on a tree II(LCA+离散化+树上莫队)

COT2 - Count on a tree II


You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.

We will ask you to perform the following operation:

  • u v : ask how many distinct integer weights occur among the nodes on the path from u to v.


In the first line there are two integers N and M. (N <= 40000, M <= 100000)

In the second line there are N integers. The i-th integer denotes the weight of the i-th node.

In the next N-1 lines, each line contains two integers u and v, which describe an edge (u, v).

In the next M lines, each line contains two integers u and v, which denote an operation asking how many distinct integer weights occur among the nodes on the path from u to v.


For each operation, print its result.


8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5
7 8


题目链接:SPOJ COT2





#include <stdio.h>
#include <bits/stdc++.h>
using namespace std;
// Sentinel "infinity"; twice this value still fits in a 32-bit int.
#define INF 0x3f3f3f3f
// Heap-style child indices and interval midpoint. Every argument is
// parenthesized so that compound expressions such as LC(a | b) expand correctly.
#define LC(x) ((x) << 1)
#define RC(x) (((x) << 1) + 1)
#define MID(x, y) (((x) + (y)) >> 1)
// Byte-fills a whole array; `arr` must be a real array (not a pointer) because
// the size comes from sizeof. For int arrays only 0 and -1 give meaningful values.
#define CLR(arr, val) memset(arr, val, sizeof(arr))
// Detaches iostreams from C stdio; do not combine with scanf/printf in the same program.
#define FAST_IO std::ios::sync_with_stdio(false);std::cin.tie(0);
typedef std::pair<int, int> pii;
typedef long long LL;
const double PI = std::acos(-1.0);
const int N = 40010;   // capacity of the per-node arrays
const int M = 200010;  // capacity of the per-query arrays
// One directed arc of the forward-star adjacency list.
struct edge
{
    int to;   // node this arc points to
    int nxt;  // index of the next arc leaving the same source node, -1 when there is none

    // Default-constructed arcs are well defined: target 0, no successor.
    edge(): to(0), nxt(-1) {}
    edge(int _to, int _nxt): to(_to), nxt(_nxt) {}
};
// One offline path query, together with the array range it maps to.
struct query
{
    int u, v;         // path endpoints
    int lca, x;       // LCA of u and v; x != 0 means the LCA must be added separately when answering
    int id, l, r, b;  // input position, inclusive range [l, r] in the flattened tree, block index of l

    // Mo ordering: primarily by block of the left end, then by right end.
    bool operator<(const query &rhs) const
    {
        if (b != rhs.b)
            return b < rhs.b;
        return r < rhs.r;
    }
};
// Forward-star adjacency storage (two arcs per undirected edge).
edge E[N << 1];
int head[N];  // head[s]: index in E of the most recently added arc leaving s, -1 if none
int tot;      // number of arcs stored in E so far

// Offline queries and their answers (indexed by input order).
query Q[M];
int ans[M];

// Value stored in each node.
int arr[N];

// Euler tour and sparse table used for LCA.
int ver[N << 1];     // node visited at each tour position
int F[N];            // first tour position of each node
int D[N << 1];       // depth at each tour position
int dp[N << 1][18];  // dp[i][j]: tour position of minimum depth in [i, i + 2^j - 1]
int ts;              // current length of the tour

// Flattened tree used by Mo's algorithm.
int ST[N << 1];  // position where a node is entered
int EN[N << 1];  // position where a node is left
int A[N << 1];   // node found at each position
int sz;          // current length of A
int unit;        // block size

// Sliding-window state.
int cnt[N];    // occurrences of each value among the active nodes
int cnode[N];  // how many positions of each node lie inside the window

void init()
    CLR(head, -1);
    tot = 0;
    ts = 0;
    sz = 0;
    CLR(ans, 0);
    CLR(cnt, 0);
    CLR(cnode, 0);
inline void add(int s, int t)
    E[tot] = edge(t, head[s]);
    head[s] = tot++;
void dfs(int u, int pre, int d)
    ver[++ts] = u;
    D[ts] = d;
    F[u] = ts;

    ST[u] = ++sz;
    A[sz] = u;

    for (int i = head[u]; ~i; i = E[i].nxt)
        int v = E[i].to;
        if (v != pre)
            dfs(v, u, d + 1);

            ver[++ts] = u;
            D[ts] = d;

    EN[u] = ++sz;
    A[sz] = u;
void RMQ_init(int l, int r)
    int i, j;
    for (i = l; i <= r; ++i)
        dp[i][0] = i;
    for (j = 1; l + (1 << j) - 1 <= r; ++j)
        for (i = l; i + (1 << j) - 1 <= r; ++i)
            int a = dp[i][j - 1], b = dp[i + (1 << (j - 1))][j - 1];
            dp[i][j] = D[a] < D[b] ? a : b;
int LCA(int u, int v)
    int l = F[u], r = F[v];
    if (l > r)
        swap(l, r);
    int k = log2(r - l + 1);
    int a = dp[l][k], b = dp[r - (1 << k) + 1][k];
    return D[a] < D[b] ? ver[a] : ver[b];
inline void Add(const int &u,int &Ans)
    if (cnode[u] == 1)
        if (++cnt[arr[u]] == 1)
    else if (cnode[u] == 2)
        if (--cnt[arr[u]] == 0)
inline void Del(const int &u, int &Ans)
    if (cnode[u] == 0)
        if (--cnt[arr[u]] == 0)
    else if (cnode[u] == 1)
        if (++cnt[arr[u]] == 1)
int main(void)
    int n, m, i;
    while (~scanf("%d%d", &n, &m))
        for (i = 1; i <= n; ++i)
            scanf("%d", &arr[i]);

        sort(vec.begin(), vec.end());
        vec.erase(unique(vec.begin(), vec.end()), vec.end());
        for (i = 1; i <= n; ++i)
            arr[i] = lower_bound(vec.begin(), vec.end(), arr[i]) - vec.begin() + 1;

        for (i = 1; i < n; ++i)
            int u, v;
            scanf("%d%d", &u, &v);
            add(u, v);
            add(v, u);
        dfs(1, -1, 0);
        unit = sqrt(sz);
        RMQ_init(1, ts);
        for (i = 0; i < m; ++i)
            scanf("%d%d", &Q[i].u, &Q[i].v);
            Q[i].id = i;
            Q[i].lca = LCA(Q[i].u, Q[i].v);

            if (ST[Q[i].u] > ST[Q[i].v])
                swap(Q[i].u, Q[i].v);
            if (Q[i].lca == Q[i].u)
                Q[i].l = ST[Q[i].u];
                Q[i].r = ST[Q[i].v];
                Q[i].x = 0;
                Q[i].l = EN[Q[i].u];
                Q[i].r = ST[Q[i].v];
                Q[i].x = 1;
            Q[i].b = Q[i].l / unit;
        sort(Q, Q + m);
        int L = Q[0].l, R = L - 1;
        int Ans = 0;
        for (i = 0; i < m; ++i)
            while (L > Q[i].l)
                Add(A[--L], Ans);

            while (L < Q[i].l)
                Del(A[L++], Ans);

            while (R > Q[i].r)
                Del(A[R--], Ans);

            while (R < Q[i].r)
                Add(A[++R], Ans);

            if (Q[i].x)
                Add(Q[i].lca, Ans);

            ans[Q[i].id] = Ans;

            if (Q[i].x)
                Del(Q[i].lca, Ans);
        for (i = 0; i < m; ++i)
            printf("%d\n", ans[i]);
    return 0;
