2021“MINIEYE杯”中国大学生算法设计超级联赛(2)(1002 I love tree)(树状数组+树链剖分)

传送门

前置知识:树状数组 差分 树链剖分 LCA

对树上路径经过的点进行操作,实际上是对区间维护一个函数。

开三个树状数组维护函数的三个系数。

都是基本操作,具体看代码注释。

题外话:

上次写树剖还是两年前(? 这几天重新又学了遍  树状数组学习博客  我的树剖板子们

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

#define N 100002
#define ll long long

int n,m;

int sumedge,cnt;

int head[N];

int siz[N],dad[N],top[N],son[N],deep[N],tpos[N];

ll tr1[N],tr2[N],tr3[N];

struct Edge
{
    int x,y,nxt;
    Edge(int x=0,int y=0,int nxt=0):x(x),y(y),nxt(nxt){}
}edge[N<<1];

void add(int x,int y)
{
    edge[++sumedge]=Edge(x,y,head[x]);
    head[x]=sumedge;
}

void dfs(int x)
{
    siz[x]=1;deep[x]=deep[dad[x]]+1;
    for(int i=head[x];i;i=edge[i].nxt)
    {
        int v=edge[i].y;
        if(v==dad[x]) continue;
        dad[v]=x;
        dfs(v);
        siz[x]+=siz[v];
    }
}

void dfs_(int x)
{
    int s=0;tpos[x]=++cnt;
    if(!top[x]) top[x]=x;
    for(int i=head[x];i;i=edge[i].nxt)
    {
        int v=edge[i].y;
        if(v!=dad[x]&&siz[v]>siz[s]) s=v;
    }
    if(s)
    {
        top[s]=top[x];
        dfs_(s);
    }
    for(int i=head[x];i;i=edge[i].nxt)
    {
        int v=edge[i].y;
        if(v!=dad[x]&&v!=s) dfs_(v);
    }
}

int LCA(int x,int y)
{
    for(;top[x]!=top[y];)
    {
        if(deep[top[x]]>deep[top[y]]) swap(x,y);
        y=dad[top[y]];
    }
    if(deep[x]>deep[y]) swap(x,y);
    return x;
}

int lowbit(int x)
{
    return x&(-x);
}

void add_tree(ll d[],int x,ll v)
{
    for(int i=x;i<=n;i+=lowbit(i))
    {
        d[i]+=v;
    }
}

ll get_sum(ll d[],int x)
{
    ll res=0;
    for(int i=x;i>=1;i-=lowbit(i))
    {
        res+=d[i];
    }
    return res;
}

void change(ll d[],int stp,int edp,ll v)
{
    add_tree(d,stp,v);
    add_tree(d,edp+1,-v);
}

void update(int x,int y,int len)
{
    int p1=1,p2=len; //两个端点 假设区间[1,2,3,4,5]需要加[1^2,2^2,3^2,4^2,5^2],则p1=1,p2=5;
    for(;top[x]!=top[y];)
    {
        if(deep[top[x]]>deep[top[y]]) //跳x所在的链
        {
            int ed=tpos[x];
            int st=tpos[top[x]];
              /*对区间[st~ed]操作,对于区间中下标为i的,加上(ed-i+p1)^2=i^2-2*(ed+p1)*i+(ed+p1)^2;*/
            change(tr1,st,ed,1); //二次项系数
            change(tr2,st,ed,-1ll*(ed+p1));//一次项
            change(tr3,st,ed,1ll*(ed+p1)*(ed+p1));//常数项
            p1=p1+ed-st+1; //下一个区间开始加的平方数
            x=dad[top[x]];
        }else{               //跳y所在的链
            int ed=tpos[y];
            int st=tpos[top[y]];
            /*对区间[st~ed]操作,令gg=p2-(ed-st);
             对于区间中下标为i的,加上(i-st+gg)^2=i^2+2*(gg-st)+(gg-st)^2;*/
            int gg=p2-(ed-st);
            change(tr1,st,ed,1);//二次项
            change(tr2,st,ed,(gg-st));//一次项
            change(tr3,st,ed,1ll*(gg-st)*(gg-st));//常数项
            y=dad[top[y]];
            p2=gg-1;
        }
    }
    if(deep[x]<=deep[y]) //要从x跳到y
    {
        int st=tpos[x];    //对区间[st~ed]操作
        int ed=tpos[y];
         /*对区间[st~ed]操作,对于区间中下标为i的,加上(i-st+p1)^2=i^2+2*(p1-st)*i+(p1-st)*(p1-st);*/
        change(tr1,st,ed,1);//二次项
        change(tr2,st,ed,p1-st);//一次项
        change(tr3,st,ed,1ll*(p1-st)*(p1-st)); //常数项
    }else{
        int st=tpos[y];
        int ed=tpos[x];
        /*对区间[st~ed]操作,对于区间中下标为i的,加上(ed-i+p1)^2=i^2-2*(ed+p1)*i+(ed+p1)*(ed+p1);*/
        change(tr1,st,ed,1); //二次项
        change(tr2,st,ed,-1ll*(ed+p1));//一次项
        change(tr3,st,ed,1ll*(ed+p1)*(ed+p1));//常数项
    }
}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs(1);
    dfs_(1);
    scanf("%d",&m);
    for(int i=1;i<=m;i++)
    {
        int od;
        scanf("%d",&od);
        if(od==1)
        {
            int x,y,lca;
            scanf("%d%d",&x,&y);
            lca=LCA(x,y);
            int len=deep[x]+deep[y]-2*deep[lca]+1; //从x跳到y之间的点的数目
            update(x,y,len);
        }else
        {
            int x;
            scanf("%d",&x);
            x=tpos[x];
            ll aa=get_sum(tr1,x); //二次项系数
            ll bb=get_sum(tr2,x); //一次项
            ll cc=get_sum(tr3,x); //常数项
            ll ans=aa*x*x+2*bb*x+cc;
            printf("%lld\n",ans);
        }
    }
    return 0;
}

 

posted @ 2021-07-26 21:13  ANhour  阅读(73)  评论(0编辑  收藏  举报