[BJOI2017]树的难题 点分治 线段树

题面

[BJOI2017]树的难题

题解

考虑点分治。
对于每个点,将所有边按照颜色排序。
那么只需要考虑如何合并2条链。
有2种情况。

  • 合并路径的接口处2条路径颜色不同
  • 合并路径的接口处2条路径颜色相同

我们分别考虑这2种情况。
维护2棵线段树,分别表示与当前接口颜色不同和颜色相同。
如果我们遍历完了一棵子树,就将这棵子树的答案加入到颜色相同的线段树里面。
如果我们遍历完了一段颜色,就将第2个线段树合并到第一个线段树里面。
当然更新答案要在上面2个操作之前。
只需要对于当前子树的每条路径,在2棵线段树上分别查询对应长度区间的答案最大值然后合并即可。
注意从颜色相同线段树上查询到的答案合并时需要减一。

// luogu-judger-enable-o2
#include<bits/stdc++.h>
using namespace std;
#define R register int
#define LL long long
#define AC 401000
#define ac 850000
#define inf 9187201950435737472LL

int n, m, rot, lim_l, lim_r, cnt, tinct, top, ss, all, id;
int Head[AC], date[ac], Next[ac], color[ac], tot;
int Size[AC];
LL power[AC], s[AC], f[AC], have[AC], ans = -inf;
bool z[AC];

struct road{
    int x, y, c;
}way[ac];

inline int read()
{
    int x = 0;char c = getchar();bool z_ = false;
    while(c > '9' || c < '0') {if(c == '-') z_ = true; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    if(!z_) return x;
    else return -x;
}

inline void upmin(int &a, int b) {if(b < a) a = b;}
inline void upmax(LL &a, LL b) {if(b > a) a = b;}
inline void add(int f, int w, int S){date[++ tot] = w, Next[tot] = Head[f], Head[f] = tot, color[tot] = S;}
inline bool cmp(road a, road b){return (a.c < b.c);}

struct seg_tree{
    
    LL tree[ac]; int ls[ac], rs[ac], cnt, root;
    
    void init() {cnt = root = 1, tree[1] = tree[0] = -inf, ls[1] = rs[1] = 0;}
    int make() {tree[++ cnt] = -inf, ls[cnt] = rs[cnt] = 0; return cnt;}
    void update(int x) {tree[x] = max(tree[ls[x]], tree[rs[x]]);}
    
    void ins(int &x, int l, int r, int go, LL w)//只有单点修改?
    {
     	if(!x) x = make();
        if(l == r){upmax(tree[x], w); return ;}
        int mid = (l + r) >> 1;
        if(go <= mid) ins(ls[x], l, mid, go, w);
        else ins(rs[x], mid + 1, r, go, w);
        update(x);
    }
    
    LL find(int x, int l, int r, int ll, int rr)
    {
    	if(!x) return -inf;
        if(l == ll && r == rr) return tree[x];
        int mid = (l + r) >> 1;
        if(rr <= mid) return find(ls[x], l, mid, ll, rr);
        else if(ll > mid) return find(rs[x], mid + 1, r, ll, rr);
        else return max(find(ls[x], l, mid, ll, mid), find(rs[x], mid + 1, r, mid + 1, rr));
    }

}T1, T2;

void merge(){while(id) T1.ins(T1.root, 1, n, have[id], have[id - 1]), id -= 2;}

void getrot(int x, int fa)
{
    f[x] = 0, Size[x] = 1;
    for(R i = Head[x]; i; i = Next[i])
    {
        int now = date[i];
        if(z[now] || now == fa) continue;
        getrot(now, x);
        upmax(f[x], Size[now]);
        Size[x] += Size[now];
    }
    upmax(f[x], ss - Size[x]);
    if(f[x] < f[rot]) rot = x;
}

void dfs(int x, int fa, int last, int num)//找到当前子树的每条线段并加入线段树
{
    //T2.ins(1, 1, n, num, f[x]);
    if(num >= lim_l && num <= lim_r) upmax(ans, f[x]);//不拐弯
    if(num > lim_r) return ; 
    s[++ top] = have[++ id] = f[x], s[++ top] = have[++id] = num;
    int l = max(lim_l - num, 1), r = min(n, lim_r - num);
    if(l <= r) 
    {
        upmax(ans, T2.find(1, 1, n, l, r) + f[x] - power[tinct]);
        upmax(ans, T1.find(1, 1, n, l, r) + f[x]);
    }
    for(R i = Head[x]; i; i = Next[i])
    {
        int now = date[i];
        if(z[now] || now == fa) continue;
        f[now] = f[x] + ((color[i] == last) ? 0 : power[color[i]]);
        dfs(now, x, color[i], num + 1);
    }
}

void cal(int x)
{
    z[x] = true;
    T1.init(), T2.init();
    for(R i = Head[x]; i; i = Next[i])
    {
        int now = date[i];
        if(z[now]) continue;
        tinct = color[i], f[now] = power[tinct], dfs(now, x, color[i], 1);
        while(top) T2.ins(T2.root, 1, n, s[top], s[top - 1]), top -= 2;//放到后面再加入防止用到同一棵子树的点
        if(color[Next[i]] != color[i]) merge(), T2.init();
    }
}

void solve(int x)
{
    //printf("%d\n", x);
    cal(x);
    for(R i = Head[x]; i; i = Next[i])
    {
        int now = date[i];
        if(z[now]) continue;
        rot = 0, f[0] = ss = Size[now];
        getrot(now, 0);
        solve(rot);
    }
}

void pre()
{
    n = read(), m = read(), lim_l = read(), lim_r = read();
    for(R i = 1; i <= m; i ++) power[i] = read();
    for(R i = 1; i < n; i ++) 
    {
    	way[++ all].x = read(), way[all].y = read(), way[all].c = read();
    	way[all + 1] = way[all], ++all, swap(way[all].x, way[all].y);
    }
    sort(way + 1, way + all + 1, cmp);
    for(R i = 1; i <= all; i ++) add(way[i].x, way[i].y, way[i].c);
}

int main()
{
//	freopen("in.in", "r", stdin);
    pre();
    f[rot] = ss = n;//f[x]表示x的子树中最重的那棵的重量
    getrot(1, 0);
    solve(rot);
    printf("%lld\n", ans);
//	fclose(stdin);
    return 0;
}
posted @ 2019-03-05 01:17  ww3113306  阅读(328)  评论(0编辑  收藏  举报
知识共享许可协议
本作品采用知识共享署名-非商业性使用-禁止演绎 3.0 未本地化版本许可协议进行许可。