【BZOJ4817】树点涂色(SDOI2017)-LCT+LCA+线段树

测试地址:树点涂色
做法:本题需要用到LCT+LCA+线段树。
首先对于第一个操作,我们发现这个很像LCT中的access操作,那么因为每次涂的颜色不同,我们可以断定同种颜色的点一定是LCT中的一条重链,这样路径上不同的颜色段数就等于路径上轻边的数量+1
于是我们可以维护f(i),表示该点到根路径上的不同颜色段数。那么:
对于第二个操作,答案显然为f(u)+f(v)2f(lca(u,v))+1
对于第三个操作,因为树的结构不变,所以我们维护DFS序,这样询问子树中f(i)的最大值就等同于询问区间最大值,用线段树维护即可。
那么我们怎么修改f(i)呢?注意到,我们access时,若将一条重边变成轻边,那么这条边下面的所有点的f(i)都会+1(因为轻边数量增加了1),反之,若将一条轻边变成重边,那么这条边下面的所有点的f(i)都会1。操作涉及整棵子树的修改,扔到线段树上做即可。根据LCT的均摊复杂度和线段树的复杂度,以上算法的总时间复杂度应该是O(nlog2n)左右的,我也不知道为什么能过得去……
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
int n,m,first[100010]={0},tot=0,tim=0;
int fa[100010][22],dep[100010],in[100010],out[100010],pos[100010];
int pre[100010],ch[100010][2],seg[400010],p[400010]={0};
bool rt[100010];
struct edge
{
    int v,next;
}e[200010];

void insert(int a,int b)
{
    e[++tot].v=b;
    e[tot].next=first[a];
    first[a]=tot;
}

void dfs(int v)
{
    in[v]=++tim;
    pos[tim]=v;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=fa[v][0])
        {
            fa[e[i].v][0]=v;
            dep[e[i].v]=dep[v]+1;
            dfs(e[i].v);
        }
    out[v]=tim;
}

void pushdown(int no)
{
    if (p[no])
    {
        seg[no<<1]+=p[no],seg[no<<1|1]+=p[no];
        p[no<<1]+=p[no],p[no<<1|1]+=p[no];
        p[no]=0;
    }
}

void pushup(int no)
{
    seg[no]=max(seg[no<<1],seg[no<<1|1]);
}

void buildtree(int no,int l,int r)
{
    if (l==r) {seg[no]=dep[pos[l]];return;}
    int mid=(l+r)>>1;
    buildtree(no<<1,l,mid);
    buildtree(no<<1|1,mid+1,r);
    pushup(no);
}

void modify(int no,int l,int r,int s,int t,int c)
{
    if (l>=s&&r<=t) {seg[no]+=c,p[no]+=c;return;}
    int mid=(l+r)>>1;
    pushdown(no);
    if (s<=mid) modify(no<<1,l,mid,s,t,c);
    if (t>mid) modify(no<<1|1,mid+1,r,s,t,c);
    pushup(no);
}

int query(int no,int l,int r,int s,int t)
{
    if (l>=s&&r<=t) return seg[no];
    int mid=(l+r)>>1,mx=0;
    pushdown(no);
    if (s<=mid) mx=max(mx,query(no<<1,l,mid,s,t));
    if (t>mid) mx=max(mx,query(no<<1|1,mid+1,r,s,t));
    return mx;
}

void rotate(int x,bool f)
{
    int y=pre[x];
    ch[y][!f]=ch[x][f];
    pre[ch[x][f]]=y;
    ch[x][f]=y;
    if (!rt[y]) ch[pre[y]][ch[pre[y]][1]==y]=x;
    else rt[x]=1,rt[y]=0;
    pre[x]=pre[y];
    pre[y]=x;
}

void Splay(int x)
{
    while(!rt[x])
    {
        if (rt[pre[x]]) rotate(x,ch[pre[x]][0]==x);
        else
        {
            int y=pre[x],z=pre[pre[x]];
            bool f=(ch[y][1]==x);
            if (ch[z][f]==y) rotate(y,!f),rotate(x,!f);
            else rotate(x,!f),rotate(x,f);
        }
    }
}

int find_top(int x)
{
    if (ch[x][0]) return find_top(ch[x][0]);
    else return x;
}

void access(int x)
{
    int y,top;
    Splay(x);
    if (ch[x][1])
    {
        rt[ch[x][1]]=1;
        top=find_top(ch[x][1]);
        modify(1,1,n,in[top],out[top],1);
        ch[x][1]=0;
    }
    while(pre[x])
    {
        y=pre[x];
        Splay(y);
        if (ch[y][1])
        {
            rt[ch[y][1]]=1;
            top=find_top(ch[y][1]);
            modify(1,1,n,in[top],out[top],1);
        }
        rt[x]=0;
        top=find_top(x);
        modify(1,1,n,in[top],out[top],-1);
        ch[y][1]=x;
        Splay(x);
    }
}

void init()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        insert(a,b),insert(b,a);
    }

    fa[1][0]=0;dep[1]=1;
    dfs(1);
    for(int i=1;i<=20;i++)
        for(int j=1;j<=n;j++)
            fa[j][i]=fa[fa[j][i-1]][i-1];
    for(int i=1;i<=n;i++)
        pre[i]=fa[i][0],ch[i][0]=ch[i][1]=0,rt[i]=1;
}

int lca(int x,int y)
{
    if (dep[x]<dep[y]) swap(x,y);
    for(int i=20;i>=0;i--)
        if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
    if (x==y) return x;
    for(int i=20;i>=0;i--)
        if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}

void work()
{
    buildtree(1,1,n);
    for(int i=1;i<=m;i++)
    {
        int op,x,y;
        scanf("%d",&op);
        if (op==1)
        {
            scanf("%d",&x);
            access(x);
        }
        if (op==2)
        {
            scanf("%d%d",&x,&y);
            int g=lca(x,y),ans=0;
            ans+=query(1,1,n,in[x],in[x]);
            ans+=query(1,1,n,in[y],in[y]);
            ans-=query(1,1,n,in[g],in[g])<<1;
            ans++;
            printf("%d\n",ans);
        }
        if (op==3)
        {
            scanf("%d",&x);
            printf("%d\n",query(1,1,n,in[x],out[x]));
        }
    }
}

int main()
{
    init();
    work();

    return 0;
}
posted @ 2018-04-13 09:06  Maxwei_wzj  阅读(151)  评论(0编辑  收藏  举报