后缀平衡树学习笔记
前置知识
重量平衡树。
重量平衡树的定义为每次插入新节点对应的子树大小为期望或均摊 \(O(\log n)\) 的平衡树。
主要的重量平衡树有:Treap、WBLT、替罪羊树等。
本文将使用 FHQTreap 进行讲解。
算法简介
定义
后缀平衡树维护一个包含字符串所有后缀的有序集合,并支持动态加入后缀(在字符串前添加字符)。通常采用重量平衡树实现。后缀平衡树的中序遍历即为后缀数组。
后缀平衡树可以实现动态在开头加字符,删字符,查询后缀排名,查询字符串排名等。
过程
先考虑暴力构造。每次向平衡树里插入一个后缀,然后暴力比较新字符串和当前遍历到的节点。一次比较复杂度 \(O(n)\),一次插入复杂度 \(O(n\log n)\),总复杂度 \(O(n^2\log n)\)。显然,这个复杂度是不可接受的。
既然后缀平衡树的中序遍历是后缀数组,那么我们可以考虑像后缀数组一样用已经求出的后缀的排名关系优化字符串比较的过程。
每次比较新后缀和一个已有的后缀,先比较对应位置上字符的大小,若相等,则比较上一个后缀的大小。比较已有后缀可以直接查排名,也可以用 Hash + 二分,复杂度 \(O(\log n)\),总复杂度 \(O(n\log^2 n)\)。
这样的复杂度仍然不够优秀,考虑进一步优化。观察到比较两个位置的字符是 \(O(1)\) 的,但比较两个后缀是 \(O(\log n)\) 的,这很不优。所以我们考虑加速两个后缀的比较。
我们令平衡树上每个点对应一个区间 \([l,r]\),定义这个点的权值 \(val\) 为 \(\frac{l+r}{2}\),并且令每个节点左儿子对应的区间为 \([l,\frac{l+r}{2}]\),右儿子为 \([\frac{l+r}{2},r]\),则所有节点的权值一定满足平衡树左儿子 \(<\) 自身 \(<\) 右儿子的特性。
那么,比较两个在平衡树上的节点就只需要比较它们的 \(val\) 就可以了。这就完成了 \(O(1)\) 的两个后缀比较。然后在插入新节点或删除节点后将子树内的 \(val\) 值重新计算即可。一次插入复杂度 \(O(\log n)\),总复杂度 \(O(n\log n)\)。删除同理。
查询后缀排名就是查询平衡树内某个节点的排名,是平衡树的基本操作之一,复杂度 \(O(\log n)\),不再赘述。查询字符串的排名因为字符串不一定在原串中出现,所以只能朴素地字符串比较。总复杂度 \(O(|S|\log n)\),其中 \(|S|\) 表示查询字符串的长度。
复杂度证明
Treap 期望树高 \(O(\log n)\),所以一次插入或删除共进行 \(O(\log n)\) 次比较,每次比较复杂度 \(O(1)\),故插入或删除复杂度为 \(O(\log n)\)。
根据重量平衡树的性质:一次插入或删除找到的节点的子树大小期望或均摊 \(O(\log n)\),可以知道插入或删除后暴力重构子树内 \(val\) 值是 \(O(\log n)\) 的。
查询后缀排名的复杂度与树高相同,是 \(O(\log n)\) 的。
查询字符串排名的复杂度,一次比较为 \(O(|S|)\),比较次数与树高相同,为 \(O(\log n)\),因此总复杂度为 \(O(|S|log n)\)。
所以算法总复杂度为 \(O(n\log n+|S|\log n)\)。
代码实现
本文使用 FHQTreap 实现。
在字符串后插入或删除字符字符可以将字符串翻转,然后改为在字符串前插入或删除字符。
对于插入操作,找到要插入的位置,然后将该位置分裂,两个根节点作为新节点的两个子节点。
对于删除操作,找到要删除的位置,将两个儿子对应的子树合并作为当前节点。
查询字符串 \(s\) 出现次数,实际上就是问 \(\text {rank}(s+char(inf))-\text {rank}(s+char(0))\),其中 \(\text{rank}(s)\) 表示 \(s\) 在当前集合中的排名,\(s+c\) 表示在字符串 \(s\) 后面加一个字符 \(c\) 形成的字符串。
维护的权值 \(val\) 最好使用浮点数类型,并给根节点赋一个较大的区间,不容易爆精度。
#include <bits/stdc++.h>
using namespace std;
constexpr int N=2e6+6,inf=1e9+7;
mt19937 f(chrono::steady_clock::now().time_since_epoch().count());
inline int rd(){return f()%inf;}//随机函数
struct FHQTreap{
int ls[N],rs[N],size[N],pri[N],cnt,rt;
char c[N];//当前节点字符
double val[N];//当前节点权值
inline int newnode(char v){//新建节点
int p=++cnt;
ls[p]=rs[p]=0;
size[p]=1;
c[p]=v;
pri[p]=rd();
return p;
}
inline bool cmp(int x,int y){//比较后缀大小
return c[x]>c[y]||(c[x]==c[y]&&val[x-1]>val[y-1]);
}
inline void pushup(int p){//维护子树size
size[p]=size[ls[p]]+size[rs[p]]+1;
}
void get_val(int p,double l,double r){//更新p子树内的权值
if(!p)return;
val[p]=(l+r)/2;
get_val(ls[p],l,val[p]);
get_val(rs[p],val[p],r);
}
void split_rank(int rt,int rank,int &rtl,int &rtr){//按排名分裂
if(!rt){rtl=rtr=0;return;}
if(size[ls[rt]]>=rank){
rtr=rt;
split_rank(ls[rtr],rank,rtl,ls[rtr]);
pushup(rtr);
}
else{
rtl=rt;
split_rank(rs[rtl],rank-size[ls[rtl]]-1,rs[rtl],rtr);
pushup(rtl);
}
}
int merge(int rtl,int rtr){//合并
if(!rtl)return rtr;
if(!rtr)return rtl;
if(pri[rtl]<pri[rtr]){
rs[rtl]=merge(rs[rtl],rtr);
pushup(rtl);
return rtl;
}
else{
ls[rtr]=merge(rtl,ls[rtr]);
pushup(rtr);
return rtr;
}
}
int get_rank(int p,int node){//获得节点node在节点p为根的子树内的排名
int rank=0;
while(p){
if(cmp(node,p))rank+=size[ls[p]]+1,p=rs[p];
else p=ls[p];
}
return rank;
}
void insert(int &p,int node,double l,double r){//向以p为根的子树内插入节点node
if(!p||pri[p]>pri[node]){
split_rank(p,get_rank(p,node),ls[node],rs[node]);
get_val(p=node,l,r);
pushup(p);
return;
}
if(cmp(node,p))insert(rs[p],node,val[p],r);
else insert(ls[p],node,l,val[p]);
pushup(p);
}
void del(int &p,int rank,double l,double r){//在以p为根的子树内删除排名为rank的节点
if(size[ls[p]]==rank){
cnt--;
get_val(p=merge(ls[p],rs[p]),l,r);
return;
}
if(size[ls[p]]>rank)del(ls[p],rank,l,val[p]);
else del(rs[p],rank-size[ls[p]]-1,val[p],r);
pushup(p);
}
inline bool scmp(string s,int p){//字符串比较
for(char v:s){
if(!p)return 1;
if(v>c[p])return 1;
if(v<c[p])return 0;
p--;
}
return 0;
}
int get_rank(string s){//获取字符串s在整棵树内的排名
int p=rt,rank=0;
while(p){
if(scmp(s,p))rank+=size[ls[p]]+1,p=rs[p];
else p=ls[p];
}
return rank;
}
}tr;
int mask;
void decodeWithMask(string &s,int mask){
for(int j=0;j<s.size();j++){
mask=(mask*131+j)%s.size();
swap(s[j],s[mask]);
}
}
int n;
string s;
signed main(){
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n;
cin>>s;
tr.insert(tr.rt,tr.newnode(1),0,1e18);//先加入一个极小值
for(char c:s){
tr.insert(tr.rt,tr.newnode(c),0,1e18);
}
for(int i=1;i<=n;i++){
string type,x;
int ans,y;
cin>>type;
if(type=="QUERY"){
cin>>x;
decodeWithMask(x,mask);
reverse(x.begin(),x.end());//先翻转再查询
ans=tr.get_rank(x+char(127))-tr.get_rank(x+char(0));
cout<<ans<<'\n',mask^=ans;
}
else if(type=="ADD"){
cin>>x;
decodeWithMask(x,mask);
for(char c:x){
tr.insert(tr.rt,tr.newnode(c),0,1e18);
}
}
else{
cin>>y;
for(int i=1;i<=y;i++){
tr.del(tr.rt,tr.get_rank(tr.rt,tr.size[tr.rt]),0,1e18);
}
}
}
return 0;
}

浙公网安备 33010602011771号