欢迎来到endl的博客hhh☀☾☽♡♥

浏览器标题切换
把博客园图标替换成自己的图标
把博客园图标替换成自己的图标end

Splay大法好

新手推荐阅读:splay详解(一)Splay入门解析【保证让你看不懂(滑稽)】


 打算记点关于 Splay 的笔记

splay嘛,本质上是一棵BST(即二叉查找树)。这棵树上的每一个节点的左孩子都比它小,右孩子都比它大,也就是说这棵树需要维护中序遍历。

【核心操作】

splay(x,y):把点 x 旋转到点 y 下面。

注意:当 x, y, z 在一条直线上时,先转 y 再转 x,否则先转 x 然后再转一遍 x。

void splay(int x,int o) {
    if(!o) root=x;//更新根节点 
    while(fa(x)!=o)    {
        int y=fa(x),z=fa(y);
        if(z!=o) {
            if(chk(x)^chk(y)) rotate(x);//x,y,z不在一条直线上 
            else rotate(y);    //x,y,z在一条直线上 
        }
        rotate(x);
    }
}

rotate (x):改变三对节点的父子关系,具体看图

void rotate(int x) {
    int k=chk(x),y=fa(x),z=fa(y),w=a[x].ch[k^1];
    a[y].ch[k]=w;fa(w)=y;
    a[z].ch[chk(y)]=x;fa(x)=z;
    a[x].ch[k^1]=y;fa(y)=x;//顺序不可以随意改变! 
    pushup(y);pushup(x);
}

树上的每一个节点不仅代表了原序列的一个值,还记录了一段序列(即它的子树)的相关信息,因此操作中要维护每个节点的 size,tag 等等。

struct hh{
    int ch[2],fa,rev,val,siz;//用ch[0]表示左孩子,ch[1]表示右孩子 
}a[N];

建树前先在首尾加两个值分别为 inf(极大) 和 -inf(极小)的节点以免出现莫名错误。

b[1]=-inf;b[n+2]=inf;
for(R i=1;i<=n;++i) b[i+1]=i;
root=build(1,n+2,0);

chk (x):查询 x 是 a[x].fa 的左孩子还是右孩子

int chk(int x) {return x==rs(fa(x));}//如果是右孩子就返回1,否则返回0

向上和向下维护信息

void pushup(int x) {a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;}
void pushdown(int x) {
    if(!a[x].rev)    return ;
    swap(ls(x),rs(x));
    a[ls(x)].rev^=1;a[rs(x)].rev^=1;
    a[x].rev=0;
}

建树

int build(int l,int r,int f) {
    if(l>r) return 0;
    int mid=l+r>>1,id=++num;
    fa(id)=f;a[id].siz=1;a[id].val=b[mid];
    ls(id)=build(l,mid-1,id);
    rs(id)=build(mid+1,r,id);
    pushup(id);
    return id;
}
//a[id]不仅代表原序列中的b[mid],还代表[l,r]这段区间

找区间对应的节点

int find(int k,int x) {
    pushdown(k);
    int cnt=a[ls(k)].siz;
    if(cnt+1==x) return k;
    if(cnt>=x) return find(ls(k),x);
    else return find(rs(k),x-cnt-1);
}
void work(int l,int r) {
    int x=find(root,l-1),y=find(root,r+1);
    splay(x,0);splay(y,x);
    a[ls(y)].rev^=1;
}//work函数是翻转[l,r]区间,该区间对应节点即ls(y)

中序遍历输出整段序列

void print(int x) {
    pushdown(x);
    if(ls(x)) print(ls(x));    
    if(a[x].val!=inf&&a[x].val!=-inf) printf("%d ",a[x].val);
    if(rs(x)) print(rs(x));
}

洛谷P3391【模板】文艺平衡树为例:

这道题只有翻转操作,每个节点只要再维护一个 rev(代表是否翻转)

 1 #include<bits/stdc++.h>
 2 #define fa(x) a[x].fa
 3 #define ls(x) a[x].ch[0]
 4 #define rs(x) a[x].ch[1]
 5 #define R register int
 6 
 7 using namespace std;
 8 const int mod=10000,N=1e5+5,inf=0x3f3f3f3f;
 9 
10 int read() {
11     int f=1;char ch;
12     while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
13     int res=ch-'0';
14     while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
15     return f*res;
16 }
17 
18 struct hh{
19     int ch[2],fa,rev,val,siz;
20 }a[N];
21 int b[N],n,m,num,root;
22 
23 int chk(int x) {return x==rs(fa(x));}
24 void pushup(int x) {a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;}
25 void pushdown(int x) {
26     if(!a[x].rev)    return ;
27     swap(ls(x),rs(x));
28     a[ls(x)].rev^=1;a[rs(x)].rev^=1;
29     a[x].rev=0;
30 }
31 
32 int build(int l,int r,int f) {
33     if(l>r) return 0;
34     int mid=l+r>>1,id=++num;
35     fa(id)=f;a[id].siz=1;a[id].val=b[mid];
36     ls(id)=build(l,mid-1,id);
37     rs(id)=build(mid+1,r,id);
38     pushup(id);
39     return id;
40 }
41 
42 int find(int k,int x) {
43     pushdown(k);
44     int cnt=a[ls(k)].siz;
45     if(cnt+1==x) return k;
46     if(cnt>=x) return find(ls(k),x);
47     else return find(rs(k),x-cnt-1);
48 }
49 
50 void rotate(int x) {
51     int k=chk(x),y=fa(x),z=fa(y),w=a[x].ch[k^1];
52     a[y].ch[k]=w;fa(w)=y;
53     a[z].ch[chk(y)]=x;fa(x)=z;
54     a[x].ch[k^1]=y;fa(y)=x;
55     pushup(y);pushup(x);
56 }
57 
58 void splay(int x,int o) {
59     if(!o) root=x;
60     while(fa(x)!=o)    {
61         int y=fa(x),z=fa(y);
62         if(z!=o) {
63             if(chk(x)^chk(y)) rotate(x);
64             else rotate(y);    
65         }
66         rotate(x);
67     }
68 }
69 
70 void work(int l,int r) {
71     int x=find(root,l-1),y=find(root,r+1);
72     splay(x,0);splay(y,x);
73     a[ls(y)].rev^=1;
74 }
75 
76 void print(int x) {
77     pushdown(x);
78     if(ls(x)) print(ls(x));    
79     if(a[x].val!=inf&&a[x].val!=-inf) printf("%d ",a[x].val);
80     if(rs(x)) print(rs(x));
81 }
82 
83 int main() { 
84     n=read();m=read();
85     b[1]=-inf;b[n+2]=inf;
86     for(R i=1;i<=n;++i) b[i+1]=i;
87     root=build(1,n+2,0);
88     while(m--) {
89         int l=read()+1,r=read()+1;
90         work(l,r);
91     }
92     print(root);
93   return 0;
94 }
代码在这里

 (去年 noip 之后码风变了很多,个人觉得更美观了,还加上了宏定义什么的方便嵌套

 


 

接下来我尝试了一道相当毒瘤的题:洛谷P2042 [NOI2005] 维护数列

操作一:在序列的第 pos 和 pos+1 个数字之间插入 tot 个数字

注意,题目中给出的 pos 在我们的操作中对应的其实是 pos+1,因为序列首端多加入了一个 -inf

根据 splay 的性质,我们先把 pos 移至根节点,再把 pos+1 移到 pos 下面,这样 pos+1 的左孩子就是我们插入序列的位置(因为这个位置是 pos+1 的左孩子,同时也是 pos 的右孩子的子树中的一个,即它比 pos+1 小,但比 pos 大)

新插入的序列也先建成一棵 splay 再插入

void insert(int pos,int tot) {
    for(R i=1;i<=tot;++i) b[i]=read();
    int id=build(1,tot,0);
    int x=find(root,pos),y=find(root,pos+1);
    splay(x,0);splay(y,x);
    ls(y)=id;fa(id)=y;
    pushup(y);pushup(x);
}

 

操作二:删除 [pos, pos+tot-1] 这个区间

如果我们把 pos-1 移至根节点,再把 pos+tot 移到 pos 下面,这样 pos+1 的左孩子就是序列中的 [pos, pos+tot-1],那么要将这个序列删除,把 pos+1 的左孩子记为空即可。但因为本题直接这么做会爆空间,那就把已经删除了的节点先存储起来,以备后用,避免开太多空间。我在代码里采用的是压栈的方法。

void recycle(int x) {
    if(!x) return ;
    st[++top]=x;
    recycle(ls(x)),recycle(rs(x));
}//回收节点x

void del(int pos,int tot) {
    int x=find(root,pos-1),y=find(root,pos+tot);
    splay(x,0);splay(y,x);
    recycle(ls(y));ls(y)=0;
    pushup(y);pushup(x);
}

 

操作三:将 [pos, pos+tot-1] 这个区间的值全部改为 c

像上面那样先找到 [pos, pos+tot-1],然后给代表这个区间的节点都加上和赋值有关的懒标记即可(pushup 别忘了!

void Tag(int x,int c) {
    if(!x) return;
    a[x].val=c;
    a[x].sum=c*a[x].siz;
    a[x].mx=max(a[x].sum,c);
    a[x].lm=a[x].rm=max(0,a[x].mx);//a[x].lm和a[x].rm可以为0,因为a[x].mx不一定要包含ls(x)或rs(x) 
    a[x].tag=1;
}

void make_same(int pos,int tot,int c) {
    int x=find(root,pos-1),y=find(root,pos+tot);
    splay(x,0);splay(y,x);
    Tag(ls(y),c);
    pushup(y);pushup(x);
}    

 

操作四:翻转 [pos, pos+tot-1] 这个区间

同样先找到该区间,给代表节点打上和翻转有关的懒标记

void Rev(int x) {
    if(!x) return ;
    swap(a[x].lm,a[x].rm);//注意此处交换lm和rm! 
    swap(ls(x),rs(x));
    a[x].rev^=1;
}

void rev(int pos,int tot) {
    int x=find(root,pos-1),y=find(root,pos+tot);
    splay(x,0);splay(y,x);
    Rev(ls(y));
    pushup(y);pushup(x);
}

 

操作五:求 [pos, pos+tot-1] 这个区间的和

建树时让每个节点维护所对应区间的和,查询时只要找到 [pos, pos+tot-1] 这个区间的代表节点就可以 O(1) 输出

int query(int pos,int tot) {
    int x=find(root,pos-1),y=find(root,pos+tot);
    splay(x,0);splay(y,x);
    return a[ls(y)].sum;
}

 

操作六:求整个序列的最大子段和

这个子问题的做法类似于洛谷P4513 小白逛公园。建树时让每个节点维护所对应区间的最大子段和,推出转移方程即可。

int max_sum() {
    int x=find(root,1),y=find(root,n+2);
    splay(x,0);splay(y,x);
    return a[ls(y)].mx;
}    
void pushup(int x) {
    a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;
    a[x].sum=a[ls(x)].sum+a[rs(x)].sum+a[x].val;
    a[x].lm=max(a[ls(x)].lm,a[ls(x)].sum+a[x].val+a[rs(x)].lm);
    a[x].rm=max(a[rs(x)].rm,a[rs(x)].sum+a[x].val+a[ls(x)].rm);
    a[x].mx=max(max(a[ls(x)].mx,a[rs(x)].mx),a[ls(x)].rm+a[x].val+a[rs(x)].lm);
}
int build(int l,int r,int f) {
   if(l>r) return 0;
   int mid=l+r>>1,id=get();
   a[id].val=b[mid];fa(id)=f;a[id].siz=1;
   a[id].lm=a[id].rm=max(0,a[id].val);
   a[id].mx=a[id].sum=a[id].val;
   a[id].rev=a[id].tag=0;
   ls(id)=build(l,mid-1,id);
   rs(id)=build(mid+1,r,id);
   pushup(id);
   return id;
}

这道题毒瘤之处在于有很多坑点,洛谷讨论帖一堆“告诫后人”,我也调了好久,感谢喻队帮忙看代码orz

详情见代码吧

  1 #include<bits/stdc++.h>
  2 #define R register int
  3 #define ls(x) a[x].ch[0]
  4 #define rs(x) a[x].ch[1]
  5 #define fa(x) a[x].fa
  6 
  7 using namespace std;
  8 const int N=5e5+5,inf=1e9;
  9 
 10 int read() {
 11     int f=1;char ch;
 12     while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
 13     int res=ch-'0';
 14     while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';    
 15     return res*f;
 16 }
 17 
 18 int n,m,st[N],top,root,b[N];
 19 struct hh{
 20     int ch[2],lm,rm,mx,sum,siz,val,fa,tag,rev;
 21 }a[N];
 22 
 23 int chk(int x) {return x==rs(fa(x));}
 24 int get() {return st[top--];}
 25 
 26 void Rev(int x) {
 27     if(!x) return ;
 28     swap(a[x].lm,a[x].rm);//注意此处交换lm和rm! 
 29     swap(ls(x),rs(x));
 30     a[x].rev^=1;
 31 }
 32 
 33 void Tag(int x,int c) {
 34     if(!x) return;
 35     a[x].val=c;
 36     a[x].sum=c*a[x].siz;
 37     a[x].mx=max(a[x].sum,c);
 38     a[x].lm=a[x].rm=max(0,a[x].mx);//a[x].lm和a[x].rm可以为0,因为a[x].mx不一定要包含ls(x)或rs(x) 
 39     a[x].tag=1;
 40 }
 41 
 42 void pushup(int x) {
 43     a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;
 44     a[x].sum=a[ls(x)].sum+a[rs(x)].sum+a[x].val;
 45     a[x].lm=max(a[ls(x)].lm,a[ls(x)].sum+a[x].val+a[rs(x)].lm);
 46     a[x].rm=max(a[rs(x)].rm,a[rs(x)].sum+a[x].val+a[ls(x)].rm);
 47     a[x].mx=max(max(a[ls(x)].mx,a[rs(x)].mx),a[ls(x)].rm+a[x].val+a[rs(x)].lm);
 48 }
 49 
 50 void pushdown(int x) {
 51     if(!x) return ;
 52     if(a[x].tag) {
 53         Tag(ls(x),a[x].val),Tag(rs(x),a[x].val);//也可以在这里先判断一下ls(x)、rs(x)是否为空 
 54         a[x].tag=a[x].rev=0;
 55     }
 56     if(a[x].rev) {
 57         Rev(ls(x)),Rev(rs(x));//同上 
 58         a[x].rev=0;
 59     }
 60 }
 61 
 62 void rotate(int x) {
 63     int y=fa(x),z=fa(y),k=chk(x),w=a[x].ch[k^1];
 64     a[y].ch[k]=w;fa(w)=y;
 65     a[z].ch[chk(y)]=x;fa(x)=z; 
 66     a[x].ch[k^1]=y;fa(y)=x;
 67     pushup(y);pushup(x);
 68 }
 69 
 70 void splay(int x,int o) {
 71     if(!o) root=x;
 72     while(fa(x)!=o)    {
 73         int y=fa(x),z=fa(y);
 74         if(z!=o) {
 75             if(chk(x)^chk(y)) rotate(x);
 76             else rotate(y);    
 77         }
 78         rotate(x);
 79     }
 80 }
 81 
 82 int find(int x,int k) {
 83     pushdown(x);//记得标记下传! 
 84     int cnt=a[ls(x)].siz;
 85     if(cnt==k-1) return x;
 86     if(cnt>=k) return find(ls(x),k);
 87     else return find(rs(x),k-cnt-1);
 88 }
 89 
 90 int build(int l,int r,int f) {
 91     if(l>r) return 0;
 92     int mid=l+r>>1,id=get();
 93     a[id].val=b[mid];fa(id)=f;a[id].siz=1;
 94     a[id].lm=a[id].rm=max(0,a[id].val);
 95     a[id].mx=a[id].sum=a[id].val;
 96     a[id].rev=a[id].tag=0;
 97     ls(id)=build(l,mid-1,id);
 98     rs(id)=build(mid+1,r,id);
 99     pushup(id);
100     return id;
101 }
102 
103 int max_sum() {
104     int x=find(root,1),y=find(root,n+2);
105     splay(x,0);splay(y,x);
106     return a[ls(y)].mx;
107 }    
108 
109 void make_same(int pos,int tot,int c) {
110     int x=find(root,pos-1),y=find(root,pos+tot);
111     splay(x,0);splay(y,x);
112     Tag(ls(y),c);
113     pushup(y);pushup(x);
114 }    
115 
116 void insert(int pos,int tot) {
117     for(R i=1;i<=tot;++i) b[i]=read();
118     int id=build(1,tot,0);
119     int x=find(root,pos),y=find(root,pos+1);
120     splay(x,0);splay(y,x);
121     ls(y)=id;fa(id)=y;
122     pushup(y);pushup(x);
123 }
124 
125 void recycle(int x) {
126     if(!x) return ;
127     st[++top]=x;
128     recycle(ls(x)),recycle(rs(x));
129 }
130 
131 void del(int pos,int tot) {
132     int x=find(root,pos-1),y=find(root,pos+tot);
133     splay(x,0);splay(y,x);
134     recycle(ls(y));ls(y)=0;
135     pushup(y);pushup(x);
136 }
137 
138 void rev(int pos,int tot) {
139     int x=find(root,pos-1),y=find(root,pos+tot);
140     splay(x,0);splay(y,x);
141     Rev(ls(y));
142     pushup(y);pushup(x);
143 }
144 
145 int query(int pos,int tot) {
146     int x=find(root,pos-1),y=find(root,pos+tot);
147     splay(x,0);splay(y,x);
148     return a[ls(y)].sum;
149 }
150 
151 int main() {
152     a[0].mx=-inf;//attention!
153     n=read(),m=read();
154     for(R i=1;i<N;i++) st[i]=i;top=N-1;//st用来回收节点 
155     
156     for(R i=1;i<=n;++i) b[i+1]=read();
157     b[1]=-inf,b[n+2]=inf;//attention!
158     root=build(1,n+2,0);
159     while(m--) {
160         char s[12];scanf("%s",s);
161         if(s[0]=='M') {
162             if(s[2]=='X') printf("%d\n",max_sum());
163             else {
164                 int pos=read()+1,tot=read(),c=read();//pos需要+1 
165                 make_same(pos,tot,c);
166             }
167         }
168         else {
169             int pos=read()+1,tot=read();//pos需要+1 
170             if(s[0]=='I') n+=tot,insert(pos,tot);//n需要更新,因为max_sum函数中会用到新的n
171             else if(s[0]=='D') n-=tot,del(pos,tot);//同上 
172             else if(s[0]=='R') rev(pos,tot);
173             else if(s[0]=='G') printf("%d\n",query(pos,tot));
174         }
175     }
176     return 0;    
177 }
完整代码在这里

 洛谷P4008 [NOI2003] 文本编辑器

  1 #include<bits/stdc++.h>
  2 #define R register int
  3 #define ls(x) a[x].ch[0]
  4 #define rs(x) a[x].ch[1]
  5 #define fa(x) a[x].fa
  6 
  7 using namespace std;
  8 const int N=2100000,inf=0x3f3f3f3f;
  9 
 10 int read() {
 11     int f=1;char ch;
 12     while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
 13     int res=ch-'0';
 14     while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';    
 15     return res*f;
 16 }
 17 
 18 int root,pos=1,num,n;
 19 char b[N];
 20 struct hh{
 21     int ch[2],fa,siz;
 22     char s;
 23 }a[N];
 24 
 25 int chk(int x) {return x==rs(fa(x));}
 26 void pushup(int x) {a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;}
 27 
 28 void print(int x) {
 29     if(ls(x)) print(ls(x));
 30     printf("%c",a[x].s);
 31     if(rs(x)) print(rs(x));
 32     pushup(x);
 33 }
 34 
 35 int build(int l,int r,int f) {
 36     if(l>r) return 0;
 37     int mid=l+r>>1,id=++num;
 38     a[id].s=b[mid];a[id].fa=f;a[id].siz=1;
 39     ls(id)=build(l,mid-1,id);
 40     rs(id)=build(mid+1,r,id);
 41     pushup(id);
 42     return id;
 43 }
 44 
 45 int find(int x,int k) {
 46     int cnt=a[ls(x)].siz;
 47     if(cnt+1==k) return x;
 48     if(cnt>=k) return find(ls(x),k);
 49     else return find(rs(x),k-cnt-1);
 50 }
 51 
 52 void rotate(int x) {
 53     int y=fa(x),z=fa(y),k=chk(x),w=a[x].ch[k^1];
 54     a[z].ch[chk(y)]=x;fa(x)=z;
 55     a[y].ch[k]=w;fa(w)=y;
 56     a[x].ch[k^1]=y;fa(y)=x;
 57     pushup(y);pushup(x);
 58 }    
 59 
 60 void splay(int x,int o) {
 61     if(!o) root=x;
 62     while(fa(x)!=o) {
 63         int y=fa(x),z=fa(y);
 64         if(z!=o) {
 65             if(chk(x)^chk(y)) rotate(x);
 66             else rotate(y);
 67         }
 68         rotate(x);
 69     }
 70 }
 71 
 72 void insert(int cnt) {
 73     for(R i=1;i<=cnt;++i) {
 74         b[i]=getchar();
 75         if(b[i]<32||b[i]>126) i--;
 76     }
 77     int id=build(1,cnt,0);
 78     int x=find(root,pos),y=find(root,pos+1);
 79     splay(x,0);splay(y,x);
 80     ls(y)=id;fa(id)=y;
 81     pushup(y);pushup(x);
 82 }
 83 
 84 void del(int cnt) {
 85     int x=find(root,pos),y=find(root,pos+cnt+1);
 86     splay(x,0);splay(y,x);
 87     ls(y)=0;
 88     pushup(y);pushup(x);
 89 }
 90 
 91 void get(int cnt) {
 92     int x=find(root,pos),y=find(root,pos+cnt+1);
 93     splay(x,0);splay(y,x);
 94     print(ls(y));putchar('\n');
 95 }
 96 
 97 int main() {
 98     b[0]=b[1]=b[2]=' ';
 99     root=build(1,2,0);
100     n=2;
101     int t=read();
102     while(t--) {
103         char s[10];scanf("%s",s);
104         if(s[0]=='P') {if(pos) pos--;}
105         else if(s[0]=='N') pos++;
106         else {
107             int cnt=read();
108             if(s[0]=='M') pos=cnt+1;
109             else if(s[0]=='I') n+=cnt,insert(cnt);
110             else if(s[0]=='D') cnt=min(n-pos,cnt),n-=cnt,del(cnt);
111             else if(s[0]=='G') cnt=min(n-pos,cnt),get(cnt);
112         }
113     }
114     return 0;    
115 }
ac code

洛谷P4567 [AHOI2006]文本编辑器

  1 #include<bits/stdc++.h>
  2 #define R register int
  3 #define ls(x) a[x].ch[0]
  4 #define rs(x) a[x].ch[1]
  5 #define fa(x) a[x].fa
  6 
  7 using namespace std;
  8 const int N=2100000,inf=0x3f3f3f3f;
  9 
 10 int read() {
 11     int f=1;char ch;
 12     while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
 13     int res=ch-'0';
 14     while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';    
 15     return res*f;
 16 }
 17 
 18 int root,pos=1,num,n;
 19 char b[N];
 20 struct hh{
 21     int ch[2],fa,siz,rev;
 22     char s;
 23 }a[N];
 24 
 25 int chk(int x) {return x==rs(fa(x));}
 26 void pushup(int x) {a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;}
 27 
 28 void Rev(int x) {
 29     a[x].rev^=1,swap(ls(x),rs(x));
 30 }
 31 
 32 void pushdown(int x) {
 33     if(!a[x].rev) return;
 34     if(ls(x)) Rev(ls(x));
 35     if(rs(x)) Rev(rs(x));
 36     a[x].rev=0;
 37 }
 38 
 39 void print(int x) {
 40     pushdown(x);
 41     if(ls(x)) print(ls(x));
 42     printf("%c",a[x].s);
 43     if(rs(x)) print(rs(x));
 44     pushup(x);
 45 }
 46 
 47 int build(int l,int r,int f) {
 48     if(l>r) return 0;
 49     int mid=l+r>>1,id=++num;
 50     a[id].s=b[mid];a[id].fa=f;a[id].siz=1,a[id].rev=0;
 51     ls(id)=build(l,mid-1,id);
 52     rs(id)=build(mid+1,r,id);
 53     pushup(id);
 54     return id;
 55 }
 56 
 57 int find(int x,int k) {
 58     pushdown(x);
 59     int cnt=a[ls(x)].siz;
 60     if(cnt+1==k) return x;
 61     if(cnt>=k) return find(ls(x),k);
 62     else return find(rs(x),k-cnt-1);
 63 }
 64 
 65 void rotate(int x) {
 66     int y=fa(x),z=fa(y),k=chk(x),w=a[x].ch[k^1];
 67     a[z].ch[chk(y)]=x;fa(x)=z;
 68     a[y].ch[k]=w;fa(w)=y;
 69     a[x].ch[k^1]=y;fa(y)=x;
 70     pushup(y);pushup(x);
 71 }    
 72 
 73 void splay(int x,int o) {
 74     if(!o) root=x;
 75     while(fa(x)!=o) {
 76         int y=fa(x),z=fa(y);
 77         if(z!=o) {
 78             if(chk(x)^chk(y)) rotate(x);
 79             else rotate(y);
 80         }
 81         rotate(x);
 82     }
 83 }
 84 
 85 void insert(int cnt) {
 86     for(R i=1;i<=cnt;++i) {
 87         b[i]=getchar();
 88         if(b[i]<32||b[i]>126) {
 89             if(i==cnt) b[i]=' ';
 90             else i--;
 91         }
 92     }
 93     int id=build(1,cnt,0);
 94     int x=find(root,pos),y=find(root,pos+1);
 95     splay(x,0);splay(y,x);
 96     ls(y)=id;fa(id)=y;
 97     pushup(y);pushup(x);
 98 }
 99 
100 void del(int cnt) {
101     int x=find(root,pos),y=find(root,pos+cnt+1);
102     splay(x,0);splay(y,x);
103     ls(y)=0;
104     pushup(y);pushup(x);
105 }
106 
107 void get(int cnt) {
108     int x=find(root,pos),y=find(root,pos+cnt+1);
109     splay(x,0);splay(y,x);
110     print(ls(y));putchar('\n');
111 }
112 
113 void reverse(int cnt) {
114     int x=find(root,pos),y=find(root,pos+cnt+1);
115     splay(x,0);splay(y,x);
116     Rev(ls(y));
117     pushup(y);pushup(x);
118 }
119 
120 int main() {
121     b[1]=b[2]=' ';
122     root=build(1,2,0);
123     n=2;
124     int t=read();
125     while(t--) {
126         char s[10];scanf("%s",s);
127         if(s[0]=='P') {if(pos) pos--;}
128         else if(s[0]=='N') pos++;
129         else if(s[0]=='G') get(1);
130         else {
131             int cnt=read();
132             if(s[0]=='M') pos=cnt+1;
133             else if(s[0]=='I') n+=cnt,insert(cnt);
134             else if(s[0]=='D') cnt=min(n-pos,cnt),n-=cnt,del(cnt);
135             else if(s[0]=='R') reverse(cnt);
136         }
137     }
138     return 0;
139 }
Ac Code

 


 

洛谷P3215 [HNOI2011]括号修复 / [JSOI2011]括号序列

这道题调了好几天。。几个明显的错误都没看出来/扶额

最大的问题是区间赋值时要直接用所赋的值c更新a[x].tag,而不能直接在懒标记下传时用节点值a[x].val来更新x的左右儿子的值,因为之前的Inv操作会影响到a[x].val,进而影响下传的标记。

  1 #include<bits/stdc++.h>
  2 #define IL inline
  3 #define R register int
  4 #define ls(x) a[x].ch[0]
  5 #define rs(x) a[x].ch[1]
  6 #define fa(x) a[x].fa
  7 
  8 using namespace std;
  9 const int N=5e5+5,inf=0x3f3f3f3f;
 10 
 11 int read() {
 12     int f=1;char ch;
 13     while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
 14     int res=ch-'0';
 15     while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';    
 16     return res*f;
 17 }
 18 
 19 char cha[N];
 20 int n,q,root,num,b[N];
 21 struct hh {
 22     int qma,hma,qmi,hmi,fa,ch[2],siz,sum,val,rev,tag,inv;
 23 }a[N];
 24 
 25 IL int min(int x,int y){return x<y?x:y;}
 26 IL int max(int x,int y){return x>y?x:y;}
 27 int chk(int x) {return x==rs(fa(x));}
 28 
 29 void Rev(int x) {
 30     a[x].rev^=1;
 31     swap(a[x].qma,a[x].hma);
 32     swap(a[x].qmi,a[x].hmi);
 33     swap(ls(x),rs(x));
 34 }
 35 
 36 void Tag(int x,int c) {
 37     a[x].val=c;
 38     a[x].sum=c*a[x].siz;
 39     a[x].tag=c;a[x].rev=a[x].inv=0;
 40     a[x].qma=a[x].hma=max(0,a[x].sum);
 41     a[x].qmi=a[x].hmi=min(0,a[x].sum);
 42 }
 43 
 44 void Inv(int x) {
 45     a[x].inv^=1;
 46     a[x].sum=-a[x].sum;a[x].val=-a[x].val;
 47     swap(a[x].qmi,a[x].qma);
 48     a[x].qmi=-a[x].qmi;a[x].qma=-a[x].qma;
 49     swap(a[x].hma,a[x].hmi);
 50     a[x].hma=-a[x].hma;a[x].hmi=-a[x].hmi;
 51 }
 52 
 53 void pushup(int x) {
 54     a[x].siz=1+a[ls(x)].siz+a[rs(x)].siz;
 55     a[x].sum=a[x].val+a[ls(x)].sum+a[rs(x)].sum;
 56     a[x].qma=max(a[ls(x)].qma,a[ls(x)].sum+a[x].val+a[rs(x)].qma);
 57     a[x].qmi=min(a[ls(x)].qmi,a[ls(x)].sum+a[x].val+a[rs(x)].qmi);
 58     a[x].hma=max(a[rs(x)].hma,a[rs(x)].sum+a[x].val+a[ls(x)].hma);
 59     a[x].hmi=min(a[rs(x)].hmi,a[rs(x)].sum+a[x].val+a[ls(x)].hmi);
 60 }
 61 
 62 void pushdown(int x) {
 63     if(a[x].tag) {
 64         if(ls(x)) Tag(ls(x),a[x].tag);
 65         if(rs(x)) Tag(rs(x),a[x].tag);
 66         a[x].tag=0;
 67     }
 68     if(a[x].rev) {
 69         if(ls(x)) Rev(ls(x));
 70         if(rs(x)) Rev(rs(x));
 71         a[x].rev=0;
 72     }    
 73     if(a[x].inv) {
 74         if(ls(x)) Inv(ls(x));
 75         if(rs(x)) Inv(rs(x));
 76         a[x].inv=0;
 77     }
 78 }
 79 
 80 int build(int l,int r,int f) {
 81     if(l>r) return 0;
 82     int mid=l+r>>1,id=++num;
 83     a[id].val=b[mid];
 84     a[id].siz=1;fa(id)=f;a[id].sum=a[id].val;
 85     a[id].rev=0;a[id].tag=0;a[id].inv=0;
 86     a[id].qma=a[id].hma=max(0,a[id].sum);
 87     a[id].qmi=a[id].hmi=min(0,a[id].sum);
 88     ls(id)=build(l,mid-1,id);
 89     rs(id)=build(mid+1,r,id);
 90     pushup(id);
 91     return id;
 92 }
 93 
 94 int find(int x,int k) {
 95     pushdown(x);
 96     int cnt=a[ls(x)].siz;
 97     if(cnt==k-1) return x;
 98     if(cnt>=k) return find(ls(x),k);
 99     else return find(rs(x),k-cnt-1);
100 }
101 
102 void rotate(int x) {
103     int y=fa(x),z=fa(y),k=chk(x),w=a[x].ch[k^1];
104     a[z].ch[chk(y)]=x;fa(x)=z;
105     a[x].ch[k^1]=y;fa(y)=x;
106     a[y].ch[k]=w;fa(w)=y;
107     pushup(y);pushup(x);
108 }
109 
110 void splay(int x,int o) {
111     if(!o) root=x;
112     while(fa(x)!=o) {
113         int y=fa(x),z=fa(y);
114         if(z!=o) {
115             if(chk(x)^chk(y)) rotate(x);
116             else rotate(y);
117         }
118         rotate(x);
119     }
120 }
121 
122 void replace(int l,int r,int c) {
123     int x=find(root,l-1),y=find(root,r+1);
124     splay(x,0);splay(y,x);
125     Tag(ls(y),c);
126     pushup(y);pushup(x);
127 }
128 
129 void reverse(int l,int r) {
130     int x=find(root,l-1),y=find(root,r+1);
131     splay(x,0);splay(y,x);
132     Rev(ls(y));
133     pushup(y);pushup(x);
134 }
135 
136 void invert(int l,int r) {
137     int x=find(root,l-1),y=find(root,r+1);
138     splay(x,0);splay(y,x);
139     Inv(ls(y));
140     pushup(y);pushup(x);
141 }
142 
143 void query(int l,int r) {
144     int x=find(root,l-1),y=find(root,r+1);
145     splay(x,0);splay(y,x);
146     int ans=((-a[ls(y)].qmi+1)>>1)+((a[ls(y)].hma+1)>>1);
147     printf("%d\n",ans);
148 }
149 
150 void print(int x) {
151     pushdown(x);
152     if(ls(x)) print(ls(x));
153     if(a[x].val==1) cout<<"(";
154     else if(a[x].val==-1) cout<<")";
155     if(rs(x)) print(rs(x));
156 }
157 
158 //'(':1,')':-1
159 //hma/2+qmi/2
160 int main() {
161     n=read();q=read();
162     scanf("%s",cha+2);
163     for(R i=2;i<=n+1;++i) 
164         if(cha[i]=='(') b[i]=1;
165         else b[i]=-1;
166     root=build(1,n+2,0);
167     while(q--) {
168         char s[8];int l,r;
169         scanf("%s",s);
170         l=read()+1,r=read()+1;
171         if(s[0]=='R') {
172             char c[2];scanf("%s",c);
173             int val=(c[0]=='('?1:-1);
174             replace(l,r,val);
175         }
176         else if(s[0]=='S') reverse(l,r);
177         else if(s[0]=='I') invert(l,r);
178         else query(l,r);
179     }
180     return 0;    
181 }
P3215代码

 


Acwing266. 超级备忘录

看了一圈代码好像都没有和我一样的码风/doge

终于依靠自己把revolve函数写出来并且调出来了!

#include<bits/stdc++.h>
//#define int long long
#define IL inline
#define R register int
#define fa(x) a[x].fa
#define ls(x) a[x].ch[0]
#define rs(x) a[x].ch[1]

using namespace std;
const int N=1e6+5,inf=0x3f3f3f3f;

IL int read() {
    int f=1;
    char ch;
    while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
    int res=ch-'0';
    while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
    return res*f;
}

int n,m,b[N],root,num;
struct hh {
    int fa,rev,tag,mi,siz,ch[2],val,ad;
}a[N];

int chk(int x) {return x==rs(fa(x));}
void Rev(int x) {a[x].rev^=1;swap(ls(x),rs(x));}
void Add(int x,int d) {a[x].val+=d;a[x].ad+=d;a[x].mi+=d;}

void pushup(int x) {
    a[x].siz=a[ls(x)].siz+a[rs(x)].siz+1;
    a[x].mi=a[x].val;
    if(ls(x)) a[x].mi=min(a[ls(x)].mi,a[x].mi);//注意判断左右孩子是否存在
    if(rs(x)) a[x].mi=min(a[rs(x)].mi,a[x].mi);
}

void pushdown(int x) {
    if(a[x].rev) {
        if(ls(x)) Rev(ls(x));
        if(rs(x)) Rev(rs(x));    
        a[x].rev=0;
    }
    if(a[x].ad) {
        if(ls(x)) Add(ls(x),a[x].ad);
        if(rs(x)) Add(rs(x),a[x].ad);
        a[x].ad=0;    
    }
}

int find(int x,int k) {
    pushdown(x);
    int cnt=a[ls(x)].siz;
    if(cnt==k-1) return x;
    if(cnt>=k) return find(ls(x),k);
    else return find(rs(x),k-cnt-1);    
}

void rotate(int x) {
    int y=fa(x),z=fa(y),k=chk(x),w=a[x].ch[k^1];
    a[z].ch[chk(y)]=x;fa(x)=z;
    a[y].ch[k]=w;fa(w)=y;
    a[x].ch[k^1]=y;fa(y)=x;
    pushup(y);pushup(x);    
}

void splay(int x,int o) {
    if(!o) root=x;
    while(fa(x)!=o) {
        int y=fa(x),z=fa(y);
        if(z!=o) {
            if(chk(x)^chk(y)) rotate(x);    
            else rotate(y);
        }
        rotate(x);
    }
}

int build(int l,int r,int f) {
    if(l>r) return 0;
    int mid=l+r>>1,id=++num;
    a[id].val=a[id].mi=b[mid];
    a[id].siz=1;fa(id)=f;
    a[id].rev=a[id].tag=0;
    ls(id)=build(l,mid-1,id);
    rs(id)=build(mid+1,r,id);
    pushup(id);
    return id;
}

void del(int k) {
    int x=find(root,k-1),y=find(root,k+1);
    splay(x,0);splay(y,x);
    ls(y)=0;    
    pushup(y);pushup(x);
}

void insert(int k,int c) {
    int x=find(root,k),y=find(root,k+1);
    splay(x,0);splay(y,x);
    int id=++num;
    ls(y)=id;fa(id)=y;
    a[id].val=a[id].mi=c;
    a[id].siz=1;a[id].rev=a[id].tag=0;
    pushup(y);pushup(x);
}

void reverse(int l,int r) {
    int x=find(root,l-1),y=find(root,r+1);
    splay(x,0);splay(y,x);
    Rev(ls(y));
    pushup(y);pushup(x);
}

void revolve(int l,int r,int t) {
    t%=r-l+1;//优化
    if(!t) return ;
    int x=find(root,l-1),y=find(root,r-t+1);
    splay(x,0);splay(y,x);
    pushdown(x);pushdown(y);
    int id1=ls(y);//id1即区间[l,r-t]的编号
    ls(y)=0;//先删去这个区间
    pushup(y);pushup(x);
    
    x=find(root,r-t-a[id1].siz),y=find(root,r-a[id1].siz+1);//a[id1].siz写成r-t-l+1亦可
    splay(x,0);splay(y,x);
    int id2=ls(y);//id2即原区间[r-t+1,r-t]的编号,因为上面已经删去了编号为id1的区间[l,r-t],所以找区间编号的时候要减去id1区间的大小a[id1].siz
    pushdown(x);pushdown(y);pushdown(id2);
    while(rs(id2)) id2=rs(id2),pushdown(id2);//令id2为区间[r-t+1,r-t]最右端的点的编号
    splay(id2,y);//将这个点移为y的左孩子
    fa(id1)=id2;rs(id2)=id1;//把编号为id1的区间[l,r-t]插到区间[r-t+1,r-t]后面,即成为id2的右孩子
    pushup(id2);pushup(y);pushup(x);
}

void mi(int l,int r) {
    int x=find(root,l-1),y=find(root,r+1);
    splay(x,0);splay(y,x);
    printf("%d\n",a[ls(y)].mi);    
}

void add(int l,int r,int d) {
    int x=find(root,l-1),y=find(root,r+1);
    splay(x,0);splay(y,x);
    Add(ls(y),d);
    pushup(y);pushup(x);
}

int main() {
    n=read();
    b[1]=-inf;b[n+2]=inf;
    for(R i=2;i<=n+1;++i) b[i]=read();
    root=build(1,n+2,0);
    m=read();
    while(m--) {
        char s[8];scanf("%s",s);
        int x=read()+1;
        if(s[0]=='D')    del(x);
        else if(s[0]=='I') {
            int p=read();
            insert(x,p);    
        }
        else {
            int y=read()+1;    
            if(s[0]=='R') {
                if(s[3]=='E') reverse(x,y);
                else {
                    int t=read();
                    revolve(x,y,t);
                }    
            }
            else if(s[0]=='M') mi(x,y);
            else {
                int d=read();
                add(x,y,d);    
            }
        }
    }
    return 0;
}
Ac Code

 

 

 

 

 

 

posted @ 2021-03-05 22:43  endl\n  阅读(320)  评论(0编辑  收藏  举报