题解P3714 [BJOI2017]树的难题

P3714 [BJOI2017]树的难题

首先吐槽一下自己刚开始理解错题意了,以为题目中的按顺序可以按任意顺序.

这道题是一道关于树上路径的问题,很明显可以想到点分治,考虑当前的分治中心为 \(x\).

那么答案可以分为下面四种情况

1.序列一端为 \(x\), 另一端在子树内.

2.序列两端在两个不同的子树内,且两个子树到分治中心的颜色相同

3.序列两端在两个不同的子树内,且两个子树到分治中心的颜色不同

4.序列两端在一个子树内

对于第4种情况,很好解决,直接递归就行,第1种情况也是,重点是第2种和第3种.

因为我们需要将所有可能的情况都存起来,所以考虑用线段树, \(x\) 的所有儿子按边的颜色排序,然后再来枚举,

那么就只需要开两棵线段树了,一棵存了之前的有儿子,一棵只存了当前颜色的儿子,在 \(dfs\) 的时候分别查询一下就是了,

只需要在查询存的是当前颜色的儿子时需要减去一个 \(val\).

代码

#include<iostream>
#include<cstdio>
#include<queue>
#define INF (0x3f3f3f3f)
using namespace std;
const int N = 2e5 + 5;
typedef pair<int, int> PII;

int n, m, ql, qr;
int idx, rt1, rt2, tot, val[N], ans = -INF;
int head[2 * N], net[2 * N], edge[2 * N], ver[2 * N];
struct tree
{
    int l, r, v, maxn;
    void init()
    {
        l = r = 0;
        v = maxn = -INF;
    }
} tr[80 * N];
bool st[N];

void pushup(int p)
{
    tr[p].maxn = max(tr[tr[p].l].maxn, tr[tr[p].r].maxn);
}

int merge(int p, int q, int l, int r)
{
    if (!p || !q)
        return p + q;
    if (l == r)
    {
        tr[p].maxn = tr[p].v = max(tr[p].v, tr[q].v);
        tr[q].init();
        return p;
    }
    int mid = (l + r) >> 1;
    tr[p].l = merge(tr[p].l, tr[q].l, l, mid);
    tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, r);
    tr[q].init();
    pushup(p);
    return p;
}

void insert(int p, int l, int r, int d, int x)
{
    if (l == r)
    {
        tr[p].maxn = tr[p].v = max(tr[p].v, x);
        return;
    }
    int mid = (l + r) >> 1;
    if (d <= mid)
    {
        if (!tr[p].l)
            tr[p].l = ++tot;
        insert(tr[p].l, l, mid, d, x);            
    }
    else
    {
        if (!tr[p].r)
            tr[p].r = ++tot;
        insert(tr[p].r, mid + 1, r, d, x);        
    }
    pushup(p);
}

int query(int p, int L, int R, int l, int r)
{
    if (L >= l && R <= r)
        return tr[p].maxn;
    int mid = (L + R) >> 1, res = -INF;
    if (l <= mid && tr[p].l)
        res = max(res, query(tr[p].l, L, mid, l, r));
    if (r > mid && tr[p].r)
        res = max(res, query(tr[p].r, mid + 1, R, l, r));
    return res;
}

void add(int a, int b, int c)
{
    net[++idx] = head[a], ver[idx] = b;
    edge[idx] = c, head[a] = idx;
}

int get_siz(int u, int fa)
{
    int res = 1;
    for (int i = head[u]; i; i = net[i])
        if (ver[i] != fa && !st[ver[i]])
            res += get_siz(ver[i], u);
    return res;
}

int get_wc(int u, int fa, int tot, int &wc)
{
    int sum = 1, ms = 0;
    for (int i = head[u]; i; i = net[i])
    {
        int v = ver[i];
        if (st[v] || v == fa)
            continue;
        int t = get_wc(v, u, tot, wc);
        sum += t;
        ms = max(ms, t);
    }
    ms = max(ms, tot - sum);
    if (ms <= tot / 2)
        wc = u;
    return sum;
}

void get_dist(int rt, int u, int fa, int vals, int len, int last, int ned)
{
    if (len > qr)
        return;
    ans = max(ans, query(rt1, 1, qr, max(1, ql - len), max(1, qr - len))  + vals - ned);
    ans = max(ans, query(rt2, 1, qr, max(1, ql - len), max(1, qr - len)) + vals);
    if (len >= ql)
        ans = max(ans, vals);
    insert(rt, 1, qr, len, vals);
    for (int i = head[u]; i; i = net[i])
    {
        int v = ver[i];
        if (v == fa || st[v])
            continue;
        if (edge[i] == last)
            get_dist(rt, v, u, vals, len + 1, last, ned);
        else
            get_dist(rt, v, u, vals + val[edge[i]], len + 1, edge[i], ned);
    }
}

void clear(int x)
{
    if (tr[x].l)
        clear(tr[x].l);
    if (tr[x].r)
        clear(tr[x].r);
    tr[x].init();
}

void calc(int u)
{
    get_wc(u, -1, get_siz(u, -1), u);
    priority_queue<PII> q;
    st[u] = true;
    for (int i = head[u]; i; i = net[i])
    {
        int v = ver[i];
        if (st[v])
            continue;
        q.push(make_pair(edge[i], v));
    }
    int last = 0;
    rt2 = ++tot;
    while (!q.empty())
    {
        int v = q.top().second, tp = q.top().first;
        q.pop();   
        if (tp != last)
        {
            merge(rt2, rt1, 1, qr);
            rt1 = ++tot;
        }
        int rt = ++tot;    
        get_dist(rt, v, u, val[tp], 1, tp, val[tp]);
        merge(rt1, rt, 1, qr);
        last = tp;
    }
    tot = 0, clear(rt1), clear(rt2);
    for (int i = head[u]; i; i = net[i])
        if (!st[ver[i]])
            calc(ver[i]);
}

int main()
{
    for (int i = 0; i < 10 * N; i++)
        tr[i].v = tr[i].maxn = -INF;
    scanf("%d%d%d%d", &n, &m, &ql, &qr);
    for (int i = 1; i <= m; i++)
        scanf("%d", &val[i]);
    for (int i = 1; i < n; i++)
    {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        add(a, b, c), add(b, a, c);
    }
    calc(1);
    printf("%d", ans);
    return 0;
}
posted @ 2021-03-28 14:39  DSHUAIB  阅读(92)  评论(0)    收藏  举报