[BZOJ2555]SubString
[BZOJ2555]SubString
试题描述
懒得写背景了,给你一个字符串init,要求你支持两个操作
(1):在当前字符串的后面插入一个字符串
(2):询问字符串s在当前字符串中出现了几次?(作为连续子串)
你必须在线支持这些操作。
输入
第一行一个数Q表示操作个数
第二行一个字符串表示初始字符串init
接下来Q行,每行2个字符串Type,Str
Type是ADD的话表示在后面插入字符串。
Type是QUERY的话表示询问某字符串在当前字符串中出现了几次。
为了体现在线操作,你需要维护一个变量mask,初始值为0
读入串Str之后,使用这个过程将之解码成真正询问的串TrueStr。
询问的时候,对TrueStr询问后输出一行答案Result
然后mask = mask xor Result
插入的时候,将TrueStr插到当前字符串后面即可。
HINT:ADD和QUERY操作的字符串都需要解压
输出
对于每个询问输出字符串出现次数。
输入示例
2 A QUERY B ADD BBABBBBAAB
输出示例
0
数据规模及约定
40 % 的数据字符串最终长度 <= 20000,询问次数<= 1000,询问总长度<= 10000
100 % 的数据字符串最终长度 <= 600000,询问次数<= 10000,询问总长度<= 3000000
题解
用 splay 动态维护 dfs 序(括号序列),这样每次 extend 的时候就相当于插入一个节点或者是把一颗子树连到另一个节点上;对应 splay 操作就是每次将一个区间拎出来插到另一个缝隙中,或是对一个区间进行查询。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
if(Head == Tail) {
int l = fread(buffer, 1, BufferSize, stdin);
Tail = (Head = buffer) + l;
}
return *Head++;
}
int read() {
int x = 0, f = 1; char c = Getchar();
while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
return x * f;
}
#define maxn 1200010
#define maxa 26
int Rt, tot, ch[maxn<<1][2], fa[maxn<<1], dl[maxn<<1], dr[maxn<<1];
struct Node {
int val, sum;
Node() {}
Node(int _): val(_) {}
} ns[maxn<<1];
inline void maintain(int o) {
ns[o].sum = ns[o].val;
for(int i = 0; i < 2; i++) if(ch[o][i])
ns[o].sum += ns[ch[o][i]].sum;
return ;
}
inline void rotate(int u) {
int y = fa[u], z = fa[y], l = 0, r = 1;
if(z) ch[z][ch[z][1]==y] = u;
if(ch[y][1] == u) swap(l, r);
fa[u] = z; fa[y] = u; fa[ch[u][r]] = y;
ch[y][l] = ch[u][r]; ch[u][r] = y;
maintain(y); maintain(u);
return ;
}
inline void splay(int u) {
while(fa[u]) {
int y = fa[u], z = fa[y];
if(z) {
if(ch[y][0] == u ^ ch[z][0] == y) rotate(u);
else rotate(y);
}
rotate(u);
}
return ;
}
inline void split(int& lrt, int& rrt, int id) {
splay(id);
lrt = id; rrt = ch[id][1];
ch[id][1] = fa[rrt] = 0;
maintain(lrt);
return ;
}
inline void splitl(int& lrt, int& rrt, int id) {
splay(id);
lrt = ch[id][0]; rrt = id;
ch[id][0] = fa[lrt] = 0;
maintain(rrt);
return ;
}
inline int merge(int a, int b) {
if(!a) return b;
if(!b) return a;
while(ch[a][1]) a = ch[a][1];
splay(a);
ch[a][1] = b; fa[b] = a;
return maintain(a), a;
}
inline void Insert(int pos, int mrt) {
int lrt, rrt;
split(lrt, rrt, pos);
lrt = merge(lrt, mrt); merge(lrt, rrt);
return ;
}
inline void Insert2(int pos, int mrt, int m2) {
int lrt, rrt;
split(lrt, rrt, pos);
lrt = merge(lrt, mrt); lrt = merge(lrt, m2); merge(lrt, rrt);
return ;
}
inline int Create(int o, int v) {
dl[o] = ++tot; dr[o] = ++tot;
ns[dl[o]] = Node(v); ns[dr[o]] = Node(0);
ch[dl[o]][1] = dr[o]; fa[dr[o]] = dl[o];
maintain(dr[o]); maintain(dl[o]);
return dl[o];
}
int rt, last, ToT, to[maxn][maxa], par[maxn], Max[maxn];
void extend(int x) {
int p = last, np = ++ToT; Max[np] = Max[p] + 1; last = np;
int Np = Create(np, 1);
while(p && !to[p][x]) to[p][x] = np, p = par[p];
if(!p){ par[np] = rt; Insert(dl[rt], Np); return ; }
int q = to[p][x];
if(Max[q] == Max[p] + 1){ par[np] = q; Insert(dl[q], Np); return ; }
int nq = ++ToT, Nq = Create(nq, 0); Max[nq] = Max[p] + 1;
memcpy(to[nq], to[q], sizeof(to[q]));
par[nq] = par[q]; Insert(dl[par[q]], Nq);
par[q] = par[np] = nq;
int lrt, Q, rrt;
splitl(lrt, Q, dl[q]); split(Q, rrt, dr[q]); merge(lrt, rrt);
Insert2(dl[nq], Q, Np);
while(p && to[p][x] == q) to[p][x] = nq, p = par[p];
return ;
}
char cmd[10], S[maxn];
void decode(char* S, int mark) {
int n = strlen(S);
for(int i = 0; i < n; i++) {
mark = (mark * 131 + i) % n;
swap(S[mark], S[i]);
}
return ;
}
int main() {
// freopen("data.in", "r", stdin);
// freopen("data.out", "w", stdout);
rt = last = ToT = 1; Rt = Create(rt, 0);
int q = read(), mark = 0;
char tc = Getchar(); while(!isalpha(tc)) tc = Getchar();
int n = 0;
while(isalpha(tc)) S[n++] = tc, tc = Getchar();
for(int i = 0; i < n; i++) extend(S[i] - 'A');
while(q--) {
while(!isalpha(tc)) tc = Getchar();
n = 0; while(isalpha(tc)) cmd[n++] = tc, tc = Getchar();
cmd[n] = 0;
while(!isalpha(tc)) tc = Getchar();
n = 0; while(isalpha(tc)) S[n++] = tc, tc = Getchar();
S[n] = 0;
decode(S, mark);
if(!strcmp(cmd, "ADD"))
for(int i = 0; i < n; i++) extend(S[i] - 'A');
else {
int p = rt; n = strlen(S);
for(int i = 0; i < n; i++) p = to[p][S[i]-'A'];
if(!p){ puts("0"); continue; }
int lrt, mrt, rrt;
splitl(lrt, mrt, dl[p]); split(mrt, rrt, dr[p]);
printf("%d\n", ns[mrt].sum); mark ^= ns[mrt].sum;
lrt = merge(lrt, mrt); merge(lrt, rrt);
}
}
return 0;
}
应该能看出来我卡常的痕迹。。。

浙公网安备 33010602011771号