Luogu「StOI-2」简单的树 树链剖分+线段树+倍增

考场的时候智障了,写了 6k+ 的树链剖分.   

如果题目带修改的话可以用树链剖分来维护,但由于没有修改用一个前缀和其实就够了.  

求 $\sum_{i=l}^{r} f(a,i)$ 可以写成两个前缀相减的形式.  

然后我们就要求 $\sum_{i=0}^{r} f(a,i)$.    

求这个的话用倍增讨论 $a$ 的初始值的影响范围,因为在影响范围内刚开始都是由子树中次大值来贡献.  

然后这个次大值显然单调,我们就可以找到贡献会比次大值大的临界点,贡献是一个等差数列的形式,维护平方和以及区间和即可.   

然后对于 $a$ 的初始值贡献不到的地方也这么讨论一下即可.   

如果加上一个带修改还真的挺毒瘤的,不过反正考场上总共花了 60 多分钟就过掉了. 

代码: 

#include <cstdio>
#include <cstring>
#include <algorithm> 
#define N 500009   
#define ll long long  
#define mod 998244353   
#define lson now<<1 
#define rson now<<1|1  
#define setIO(s) freopen(s".in","r",stdin)  
using namespace std;  
ll lastans;  
int edges,n,Q,OPT,tim;  
int hd[N],to[N<<1],nex[N<<1],val[N];  
int fa[N],size[N],son[N],dfn[N],bu[N],f[20][N],dep[N],top[N]; 
void add(int u,int v) { 
    nex[++edges]=hd[u]; 
    hd[u]=edges,to[edges]=v; 
}     
int DECODE(int x) {  
    ll y=1ll*x+1ll*OPT*lastans; 
    y%=n;        
    ++y;  
    return (int)y;  
}  
struct data { 
    ll sqr,sum;  
    data() { sqr=sum=0; }
    data operator+(const data b) const { 
        data c;  
        c.sqr=sqr+b.sqr; 
        c.sum=sum+b.sum; 
        return c; 
    }
}; 
struct node { 
    data se,mx;           
    node operator+(const node b) const { 
        node c; 
        c.se=se+b.se; 
        c.mx=mx+b.mx;  
        return c;  
    }
}s[N<<2];  
struct Tree { 
    int mx,se;  
    Tree(int mx=0,int se=0):mx(mx),se(se){}      
    Tree operator+(const Tree b) const {        
        Tree c;  
        c.mx=c.se=0;  
        if(mx<b.mx) {   
            c.mx=b.mx;  
            c.se=max(mx,b.se);   
        } 
        if(mx>b.mx) {    
            c.mx=mx;  
            c.se=max(se,b.mx);  
        }     
        if(mx==b.mx) {
            c.se=c.mx=mx;  
        }
        return c;  
    }
}tree[N];  
void dfs0(int x,int ff) {   
    fa[x]=ff,size[x]=1; 
    dep[x]=dep[ff]+1;  
    f[0][x]=fa[x];  
    tree[x]=Tree(val[x],0);   
    for(int i=hd[x];i;i=nex[i]) {       
        int y=to[i]; 
        if(y==ff) continue;  
        dfs0(y,x);  
        size[x]+=size[y]; 
        if(size[y]>size[son[x]]) son[x]=y;  
        tree[x]=tree[x]+tree[y];  
    }
}
void dfs1(int x,int tp) {   
    top[x]=tp; 
    dfn[x]=++tim; 
    bu[tim]=x;   
    if(son[x]) {    
        dfs1(son[x],tp);
    } 
    for(int i=hd[x];i;i=nex[i]) { 
        int y=to[i];
        if(y==fa[x]||y==son[x]) continue;  
        dfs1(y,y);   
    }
}  
void build(int l,int r,int now) { 
    if(l==r) {  
        int cur=bu[l]; 
        s[now].mx.sum=tree[cur].mx;  
        s[now].mx.sqr=1ll*tree[cur].mx*tree[cur].mx;  
        s[now].se.sum=tree[cur].se;  
        s[now].se.sqr=1ll*tree[cur].se*tree[cur].se; 
        return; 
    }  
    int mid=(l+r)>>1;  
    build(l,mid,lson),build(mid+1,r,rson); 
    s[now]=s[lson]+s[rson];  
}    
node query(int l,int r,int now,int L,int R) {   
    if(l>=L&&r<=R) {     
        return s[now];
    } 
    int mid=(l+r)>>1;  
    if(L<=mid&&R>mid) return query(l,mid,lson,L,R)+query(mid+1,r,rson,L,R);  
    else if(L<=mid)   return query(l,mid,lson,L,R); 
    else return query(mid+1,r,rson,L,R);  
}
node Query(int x,int y) { 
    node re;  
    while(top[x]!=top[y]) {  
        if(dep[top[x]]>dep[top[y]]) {       
            re=re+query(1,n,1,dfn[top[x]],dfn[x]);  
            x=fa[top[x]]; 
        }
        else { 
            re=re+query(1,n,1,dfn[top[y]],dfn[y]); 
            y=fa[top[y]];  
        }
    }  
    if(dep[x]>dep[y]) { 
        swap(x,y); 
    }  
    re=re+query(1,n,1,dfn[x],dfn[y]); 
    return re;  
}
ll solve(int x,int r) { 
    if(r<0) return 0;  
    // 极长最大值小于等于 val[x] 的     
    int tar=x; 
    for(int i=19;i>=0;--i) {   
        if(!f[i][tar]) continue;    
        // tree[f[i][tar]].mx<=val[x] 
        if(tree[f[i][tar]].mx<=val[x]) { 
            tar=f[i][tar];  
        }
    }  
    ll ans=0;    
    if(tree[tar].mx<=val[x]) {    
        // 存在这么一段   
        // x -> tar 这一段   
        // 先变成 0,故这一段的贡献先是 
        int pr=x;    
        for(int i=19;i>=0;--i) { 
            if(!f[i][pr]||dep[f[i][pr]]<dep[tar]) continue;     
            if(tree[f[i][pr]].se<r) pr=f[i][pr];   
        }          
        if(tree[pr].se<r) {  
            // 等差数列求和   
            node e=Query(x,pr);  
            ans+=e.se.sum*1ll*(r+1)%mod;        // 共 r+1 个时刻    
            int num=dep[x]-dep[pr]+1;   
            ll tm=1ll*r*r*num-2ll*r*e.se.sum+1ll*r*num+e.se.sqr-e.se.sum;   
            ans+=tm/2;      
            pr=fa[pr];  
        }   // 这部分算好了        
        if(dep[pr]>=dep[tar]) { 
            // 永远都不变的大哥     
            node e=Query(pr,tar);   
            ans+=e.se.sum*1ll*(r+1)%mod;  
            ans%=mod; 
        }     
        tar=fa[tar];  
    }    
    if(tar) {       
        // 其余是要依靠 r 来改变的    
        node e=Query(tar,1);      
        ans+=e.mx.sum*1ll*(r+1);     
        int pr=tar;           
        for(int i=19;i>=0;--i) { 
            if(!f[i][pr]) continue;     
            if(tree[f[i][pr]].mx<r) pr=f[i][pr];   
        }     
        if(tree[pr].mx<r) {    
            e=Query(tar,pr);        
            int num=dep[tar]-dep[pr]+1;   
            ll tm=1ll*r*r*num-2ll*r*e.mx.sum+1ll*r*num+e.mx.sqr-e.mx.sum;   
            ans+=tm/2;  
            ans%=mod;  
        }
    }
    return ans%mod;   
}
char buf[100000],*p1,*p2;
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
int rd()
{
    int x=0; char s=nc();
    while(s<'0') s=nc();
    while(s>='0') x=(((x<<2)+x)<<1)+s-'0',s=nc();
    return x;
}  
int main() {
    /// setIO("input");     
    int x,y,z;   
    n=rd(),Q=rd(),OPT=rd();  
    for(int i=1;i<=n;++i) { 
        val[i]=rd();    
    }   
    for(int i=1;i<n;++i) { 
        x=rd(),y=rd();  
        add(x,y),add(y,x); 
    }  
    dfs0(1,0); 
    dfs1(1,1);   
    build(1,n,1); 
    for(int i=1;i<19;++i) {     
        for(int j=1;j<=n;++j) { 
            f[i][j]=f[i-1][f[i-1][j]];  
        }
    }      
    ll fin=s[1].mx.sum;  
    for(int i=1;i<=Q;++i) { 
        int l=rd(),r=rd(),a=rd();     
        l=DECODE(l),r=DECODE(r),a=DECODE(a);  
        if(l>r) { 
            swap(l,r);  
        }        
        node e=Query(1,a);    
        ll cur=(fin-1ll*e.mx.sum%mod)*(r-l+1)%mod;     
        printf("%lld\n",lastans=(ll)(cur+solve(a,r)-solve(a,l-1)+mod)%mod);  
    }  
    return 0;
}

  

posted @ 2020-08-14 09:37  EM-LGH  阅读(62)  评论(0编辑  收藏