[模板]BZOJ4756线段树合并

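题面：BZOJ 4756（[Usaco2017 Jan] Promotion Counting）—— 给定一棵以 1 为根的树，每个节点有权值 a[u]，对每个节点 u 求其子树中权值严格大于 a[u] 的节点个数。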

Solution:

线段树合并模板：先把权值离散化，为每个节点建一棵只含自身权值的权值线段树；DFS 回溯时把所有子节点的线段树合并进当前节点，再在合并后的树上查询区间 [a[u]+1, SIZE] 的计数即为答案。

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;

namespace io {
    // Buffered input: stdin is read in chunks of sizeof(buf) bytes; [pos, end)
    // is the not-yet-consumed part of the current chunk.
    char buf[1<<21], *pos = buf, *end = buf;

    // Returns the next input byte as a value in 0..255, or EOF once stdin is
    // exhausted. Returning int (rather than char) keeps EOF distinct from byte
    // 0xFF and makes the result safe to pass to isdigit().
    inline int getc() {
        if (pos == end) {
            pos = buf;
            end = buf + fread(buf, 1, sizeof buf, stdin);
            if (pos == end) return EOF;
        }
        return static_cast<unsigned char>(*pos++);
    }

    // Reads the next decimal integer from stdin. Any '-' seen before the first
    // digit makes the result negative. Returns 0 if the input ends before a
    // digit is found (instead of spinning forever on EOF).
    inline int rint() {
        int x = 0, f = 1, c;
        while (!isdigit(c = getc())) {
            if (c == EOF) return 0;
            if (c == '-') f = -1;
        }
        do {
            x = x * 10 + (c - '0');
        } while (isdigit(c = getc()));
        return x * f;
    }
}
using io::rint;

// Array capacity for per-vertex data; vertices are 1-based, so indices 1..N-1
// are usable (presumably n <= 1e5 — confirm against the problem limits).
const int N = 100001;

int n;        // number of tree vertices
int ans[N];   // ans[u]: count of values in u's subtree strictly greater than a[u]
int Ht[N];    // copy of the input values; sorted and de-duplicated into Ht[1..SIZE]
int a[N];     // a[u]: value of vertex u, replaced by its rank in Ht (1..SIZE)
int SIZE;     // number of distinct values = value range [1, SIZE] of every segment tree

// Node pool shared by all dynamically-allocated segment trees. Node 0 is the
// empty tree (all fields 0). Segment-tree-merge pools are conventionally sized
// N<<6; each insert allocates at most one node per tree level and merging
// allocates none, so this is generous.
// NOTE: with `using namespace std;` in C++17, std::size is also visible, so an
// unqualified `size` is ambiguous; refer to this array as `::size`.
int size[N<<6];   // size[x]: number of values stored under node x
int rt[N];        // rt[u]: root node of vertex u's segment tree (0 = empty)
int cnt;          // number of pool nodes allocated so far
int ls[N<<6];     // ls[x]: left child of node x (0 = none)
int rs[N<<6];     // rs[x]: right child of node x (0 = none)

// Linked edge lists: head[u] is the id of u's most recently added edge,
// nxt[e] the next edge out of the same vertex, ver[e] the edge's endpoint.
// Edge ids start at 1; id 0 terminates a list.
int head[N];
int nxt[N<<1];
int ver[N<<1];
int tot;   // number of edges added so far (also the id of the latest edge)

// Prepends the directed edge u -> v to u's edge list.
void addEdge(int u, int v) {
    const int e = ++tot;
    ver[e] = v;
    nxt[e] = head[u];
    head[u] = e;
}

void pushup(int x)
{ size[x] = size[ls[x]] + size[rs[x]]; }

void insert(int &x, int lval, int rval, int val) {
    x = ++cnt;
    if (lval == rval) { size[x] ++; return; }
    int mid = lval + rval >> 1;
    if (val <= mid) insert(ls[x], lval, mid, val);
    else insert(rs[x], mid+1, rval, val);
    pushup(x);
}

int query(int x, int lval, int rval, int Left, int Right) {
    if (!x) return 0;
    if (Left <= lval && rval <= Right) return size[x];
    int mid = lval + rval >> 1;
    int sum = 0;
    if (Left <= mid) sum += query(ls[x], lval, mid, Left, Right);
    if (mid < Right) sum += query(rs[x], mid + 1, rval, Left, Right);
    return sum;
}

int merge(int x, int y) {
    if ((!x) || (!y)) return x + y;
    ls[x] = merge(ls[x], ls[y]);
    rs[x] = merge(rs[x], rs[y]);
    pushup(x);
    return x;
}

// Post-order traversal from u (parent fa): builds u's segment tree holding
// every value in u's subtree, then records in ans[u] how many of those values
// exceed a[u]. Recursion depth equals the depth of the tree below u.
void DFS(int u, int fa) {
    insert(rt[u], 1, SIZE, a[u]);  // u's own value
    for (int e = head[u]; e != 0; e = nxt[e]) {
        const int v = ver[e];
        if (v == fa) continue;
        DFS(v, u);                      // finish the child's subtree first
        rt[u] = merge(rt[u], rt[v]);    // then absorb the child's tree
    }
    ans[u] = query(rt[u], 1, SIZE, a[u] + 1, SIZE);
}

int main() {
    // Redirect to the local test files only when the input file is present, so
    // the same source still reads stdin / writes stdout on a judge.
    if (FILE* probe = fopen("BZOJ4756.in", "r")) {
        fclose(probe);
        freopen("BZOJ4756.in", "r", stdin);
        freopen("BZOJ4756.out", "w", stdout);
    }

    n = rint();
    // Nothing to answer; also avoids building trees over an empty range [1, 0].
    if (n <= 0) return 0;
    for (int i = 1; i <= n; ++ i) a[i] = Ht[i] = rint();
    // The input lists the parent of each vertex 2..n; vertex 1 is the root.
    for (int i = 2; i <= n; ++ i) {
        int fa = rint();
        addEdge(fa, i);
        addEdge(i, fa);
    }

    // Discretize the values to ranks 1..SIZE.
    sort(Ht + 1, Ht + 1 + n);
    SIZE = unique(Ht + 1, Ht + 1 + n) - Ht - 1;
    // Search only the de-duplicated prefix Ht[1..SIZE]. std::unique leaves the
    // elements after the new end unspecified, so searching all n entries may
    // probe an unsorted tail and return a rank greater than SIZE.
    for (int i = 1; i <= n; ++ i) a[i] = lower_bound(Ht + 1, Ht + 1 + SIZE, a[i]) - Ht;

    DFS(1, 0);

    for (int i = 1; i <= n; ++ i) printf("%d\n", ans[i]);
    return 0;
}

posted @ 2019-02-20 20:47  茶Tea  阅读(160)  评论(0)  编辑  收藏  举报