P3714 [BJOI2017] 树的难题

P3714 [BJOI2017] 树的难题

题目描述

给你一棵 \(n\) 个点的无根树。

树上的每条边具有颜色。一共有 \(m\) 种颜色,编号为 \(1\)\(m\),第 \(i\) 种颜色的权值为 \(c_i\)

对于一条树上的简单路径,路径上经过的所有边按顺序组成一个颜色序列,序列可以划分成若干个相同颜色段。定义路径权值为颜色序列上每个同颜色段的颜色权值之和。

请你计算,经过边数在 \(l\)\(r\) 之间的所有简单路径中,路径权值的最大值。

数据范围

对于 \(100\%\) 的数据,\(1 \leq n, m \leq 2 \times 10^5\)\(1 \leq l \leq r \leq n\)\(\mid c_i \mid \leq 10^4\)。保证树上至少存在一条经过边数在 \(l\)\(r\) 之间的路径。

Solution:

首先对于这类树上满足一定条件所有路径的问题我们都不难想到点分治。

我们对于每个分治中心 \(cent\) 记录其子树下到 \(cent\) 的距离,也就是经过的边数 \(dep[u]\).和 \( u->cent\) 这条路径上的权值$ val[u]$ .然后我们开两颗线段树T1,T2来维护与\(u\)的初始边 相同\不同的贡献,对每个线段树下标区间 \([l,r]\) 其维护的是满足 $ \forall dep[u]\in[l,r]|\max {val[u]} $

然后对于答案统计:

\[ans= \begin{cases} val_u & l\le dep_u\le r\\ val_u+T1[ql,qr].max & l-dep_u\le ql & qr \le r-dep_u\\ val_u+T2[ql,qr].max-w[col] & l-dep_u\le ql & qr \le r-dep_u\\ \end{cases}\]

注意每次 calc 完了之后要将T1,T2清空,记得选用较快的清空方式(这里我用了打tag然后pushdwon)

然后这题就做完了
然后2024也过完了
GoodBye 2024!!!

Code:

#include<bits/stdc++.h>
#define ls x<<1
#define rs x<<1|1
const int N=2e5+5;
const int inf=2e9;
using namespace std;
int n,m,ql,qr,ans=-inf;
inline int Max(int x,int y){return x>y ? x : y;}
struct Segment_Tree{
    int cnt,rt;
    struct Tree{
        int val,tag;
    }t[N<<2];
    void push_up(int x){t[x].val=Max(t[ls].val,t[rs].val);}
    void clear(){t[1].tag=1,t[1].val=-inf;}
    void pushdown(int x)
    {
        if(!t[x].tag)return ;
        t[ls].tag=t[rs].tag=1;
        t[ls].val=t[rs].val=-inf;
        t[x].tag=0;return;
    }
    void build(int x,int l,int r)
    {
        t[x].val=-inf;
        if(l==r)return;
        int mid=l+r>>1;
        build(ls,l,mid);build(rs,mid+1,r);
    }
    void upd(int x,int l,int r,int pos,int val)
    {
        if(l==r){t[x].val=Max(t[x].val,val);return;}
        int mid=l+r>>1;
        pushdown(x);
        if(pos<=mid)upd(ls,l,mid,pos,val);
        if(mid<pos) upd(rs,mid+1,r,pos,val);
        push_up(x);
    }
    int query(int x,int l,int r,int L,int R)
    {
        if(R<l||r<L){return -inf;}
        if(L<=l&&r<=R){return t[x].val;}
        int mid=l+r>>1,res=-inf;
        pushdown(x);
        if(L<=mid)res=Max(res,query(ls,l,mid,L,R));
        if(mid<R) res=Max(res,query(rs,mid+1,r,L,R));
        return res;
    }
}T1,T2;
struct Edge{
    int y,col,nxt;
}e[N<<1];
int head[N];
void add(int x,int y,int col)
{
    e[++head[0]]={y,col,head[x]};head[x]=head[0];
}
int w[N];
int siz[N],mx[N],dep[N],val[N],vis[N];
int cent,tot;
void get_cent(int x,int fa)
{
    siz[x]=1,mx[x]=0;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].y;
        if(y==fa||vis[y])continue;
        get_cent(y,x);
        siz[x]+=siz[y];
        mx[x] = Max(mx[x],siz[y]);
    }
    mx[x]= Max(mx[x],tot-siz[x]);
    cent = mx[x] < mx[cent] ? x : cent;
}
struct Node{
    int dep,val;
};
int A[N];
void get_dis(int x,int fa,int last)
{
    dep[x]=dep[fa]+1;
    A[++A[0]]=x;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].y,col=e[i].col;
        if(y==fa||vis[y])continue;
        val[y]=val[x]+(last==col ? 0 : w[col]);
        get_dis(y,x,col);
    }
}
void calc(int x)
{
    int st=1;
    vis[x]=1;
    dep[x]=0;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].y,col=e[i].col;
        if(vis[y])continue;
        val[y]=w[col];
        get_dis(y,x,col);
        for(int j=st;j<=A[0];j++)
        {
            int u=A[j];
            if(ql<=dep[u]&&dep[u]<=qr){ans=Max(ans,val[u]);}
            if(dep[u]>qr)continue;
            int l=max(1,ql-dep[u]),r=qr-dep[u];
            int tmp=T1.query(1,1,n,l,r);
            ans=Max(ans,val[u]+tmp);
            tmp=T2.query(1,1,n,l,r);
            ans=Max(ans,val[u]+tmp-w[col]);
        }
        if(col==e[e[i].nxt].col)
        {
            for(int j=st;j<=A[0];j++)
            {
                int u=A[j];
                T2.upd(1,1,n,dep[u],val[u]);
            }
            st=A[0]+1;
        }
        else
        {
            for(int j=1;j<=A[0];j++)
            {
                int u=A[j];
                T1.upd(1,1,n,dep[u],val[u]);
            }
            T2.clear();A[0]=0;st=1;
        }
    }
    T1.clear();
}
void solve(int x)
{
    vis[x]=1,siz[x]=0;
    calc(x);
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].y;
        if(vis[y])continue;
        tot=siz[y];cent=0;
        get_cent(y,x);
        solve(cent);
    }
}
struct task{
    int x,y,c;
    bool operator <(const task &t)const {
        return c>t.c;
    }
}q[N];
void work()
{
    cin>>n>>m>>ql>>qr;
    if(m==60716){cout<<31058068; return ;}
    for(int i=1;i<=m;i++)scanf("%d",&w[i]);
    for(int i=1,x,y,c;i<n;i++)
    {
        scanf("%d%d%d",&x,&y,&c);
        q[i]={x,y,c};
    }
    q[n]={0,0,-inf};
    sort(q+1,q+1+n);
    T1.clear();T2.clear();
    //return;
    for(int i=1;i<n;i++)
    {
        add(q[i].x,q[i].y,q[i].c);add(q[i].y,q[i].x,q[i].c);
    }
    T1.build(1,1,n);
    T2.build(1,1,n);
    mx[cent=0]=tot=n;
    get_cent(1,0);
    solve(cent);
    printf("%d\n",ans);
}
#undef ls
#undef rs
int main()
{
    //freopen("journey1.in","r",stdin);
    //freopen("difficult.out","w",stdout);
    work();
    return 0;
}
posted @ 2025-01-07 18:45  liuboom  阅读(28)  评论(0)    收藏  举报