[BZOJ2555]SubString

[BZOJ2555]SubString

试题描述

懒得写背景了,给你一个字符串init,要求你支持两个操作
(1):在当前字符串的后面插入一个字符串
(2):询问字符串s在当前字符串中出现了几次?(作为连续子串)
你必须在线支持这些操作。

输入

第一行一个数Q表示操作个数
第二行一个字符串表示初始字符串init
接下来Q行,每行2个字符串Type,Str 
Type是ADD的话表示在后面插入字符串。
Type是QUERY的话表示询问某字符串在当前字符串中出现了几次。
为了体现在线操作,你需要维护一个变量mask,初始值为0
    
读入串Str之后,使用这个过程将之解码成真正询问的串TrueStr。
询问的时候,对TrueStr询问后输出一行答案Result
然后mask = mask xor Result  
插入的时候,将TrueStr插到当前字符串后面即可。

HINT:ADD和QUERY操作的字符串都需要解压

输出

对于每个询问输出字符串出现次数。

输入示例

2
A
QUERY B
ADD BBABBBBAAB

输出示例

0

数据规模及约定

40 % 的数据字符串最终长度 <= 20000,询问次数<= 1000,询问总长度<= 10000
100 % 的数据字符串最终长度 <= 600000,询问次数<= 10000,询问总长度<= 3000000

题解

用 splay 动态维护 dfs 序(括号序列),这样每次 extend 的时候就相当于插入一个节点或者是把一颗子树连到另一个节点上;对应 splay 操作就是每次将一个区间拎出来插到另一个缝隙中,或是对一个区间进行查询。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;

const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
	if(Head == Tail) {
		int l = fread(buffer, 1, BufferSize, stdin);
		Tail = (Head = buffer) + l;
	}
	return *Head++;
}
int read() {
	int x = 0, f = 1; char c = Getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = Getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = Getchar(); }
	return x * f;
}

#define maxn 1200010
#define maxa 26

int Rt, tot, ch[maxn<<1][2], fa[maxn<<1], dl[maxn<<1], dr[maxn<<1];
struct Node {
	int val, sum;
	Node() {}
	Node(int _): val(_) {}
} ns[maxn<<1];
inline void maintain(int o) {
	ns[o].sum = ns[o].val;
	for(int i = 0; i < 2; i++) if(ch[o][i])
		ns[o].sum += ns[ch[o][i]].sum;
	return ;
}
inline void rotate(int u) {
	int y = fa[u], z = fa[y], l = 0, r = 1;
	if(z) ch[z][ch[z][1]==y] = u;
	if(ch[y][1] == u) swap(l, r);
	fa[u] = z; fa[y] = u; fa[ch[u][r]] = y;
	ch[y][l] = ch[u][r]; ch[u][r] = y;
	maintain(y); maintain(u);
	return ;
}
inline void splay(int u) {
	while(fa[u]) {
		int y = fa[u], z = fa[y];
		if(z) {
			if(ch[y][0] == u ^ ch[z][0] == y) rotate(u);
			else rotate(y);
		}
		rotate(u);
	}
	return ;
}
inline void split(int& lrt, int& rrt, int id) {
	splay(id);
	lrt = id; rrt = ch[id][1];
	ch[id][1] = fa[rrt] = 0;
	maintain(lrt);
	return ;
}
inline void splitl(int& lrt, int& rrt, int id) {
	splay(id);
	lrt = ch[id][0]; rrt = id;
	ch[id][0] = fa[lrt] = 0;
	maintain(rrt);
	return ;
}
inline int merge(int a, int b) {
	if(!a) return b;
	if(!b) return a;
	while(ch[a][1]) a = ch[a][1];
	splay(a);
	ch[a][1] = b; fa[b] = a;
	return maintain(a), a;
}
inline void Insert(int pos, int mrt) {
	int lrt, rrt;
	split(lrt, rrt, pos);
	lrt = merge(lrt, mrt); merge(lrt, rrt);
	return ;
}
inline void Insert2(int pos, int mrt, int m2) {
	int lrt, rrt;
	split(lrt, rrt, pos);
	lrt = merge(lrt, mrt); lrt = merge(lrt, m2); merge(lrt, rrt);
	return ;
}
inline int Create(int o, int v) {
	dl[o] = ++tot; dr[o] = ++tot;
	ns[dl[o]] = Node(v); ns[dr[o]] = Node(0);
	ch[dl[o]][1] = dr[o]; fa[dr[o]] = dl[o];
	maintain(dr[o]); maintain(dl[o]);
	return dl[o];
}

int rt, last, ToT, to[maxn][maxa], par[maxn], Max[maxn];
void extend(int x) {
	int p = last, np = ++ToT; Max[np] = Max[p] + 1; last = np;
	int Np = Create(np, 1);
	while(p && !to[p][x]) to[p][x] = np, p = par[p];
	if(!p){ par[np] = rt; Insert(dl[rt], Np); return ; }
	int q = to[p][x];
	if(Max[q] == Max[p] + 1){ par[np] = q; Insert(dl[q], Np); return ; }
	int nq = ++ToT, Nq = Create(nq, 0); Max[nq] = Max[p] + 1;
	memcpy(to[nq], to[q], sizeof(to[q]));
	par[nq] = par[q]; Insert(dl[par[q]], Nq);
	par[q] = par[np] = nq;
	int lrt, Q, rrt;
	splitl(lrt, Q, dl[q]); split(Q, rrt, dr[q]); merge(lrt, rrt);
	Insert2(dl[nq], Q, Np);
	while(p && to[p][x] == q) to[p][x] = nq, p = par[p];
	return ;
}

char cmd[10], S[maxn];

void decode(char* S, int mark) {
	int n = strlen(S);
	for(int i = 0; i < n; i++) {
		mark = (mark * 131 + i) % n;
		swap(S[mark], S[i]);
	}
	return ;
}

int main() {
//	freopen("data.in", "r", stdin);
//	freopen("data.out", "w", stdout);
	rt = last = ToT = 1; Rt = Create(rt, 0);
	int q = read(), mark = 0;
	char tc = Getchar(); while(!isalpha(tc)) tc = Getchar();
	int n = 0;
	while(isalpha(tc)) S[n++] = tc, tc = Getchar();
	for(int i = 0; i < n; i++) extend(S[i] - 'A');
	while(q--) {
		while(!isalpha(tc)) tc = Getchar();
		n = 0; while(isalpha(tc)) cmd[n++] = tc, tc = Getchar();
		cmd[n] = 0;
		while(!isalpha(tc)) tc = Getchar();
		n = 0; while(isalpha(tc)) S[n++] = tc, tc = Getchar();
		S[n] = 0;
		decode(S, mark);
		if(!strcmp(cmd, "ADD"))
			for(int i = 0; i < n; i++) extend(S[i] - 'A');
		else {
			int p = rt; n = strlen(S);
			for(int i = 0; i < n; i++) p = to[p][S[i]-'A'];
			if(!p){ puts("0"); continue; }
			int lrt, mrt, rrt;
			splitl(lrt, mrt, dl[p]); split(mrt, rrt, dr[p]);
			printf("%d\n", ns[mrt].sum); mark ^= ns[mrt].sum;
			lrt = merge(lrt, mrt); merge(lrt, rrt);
		}
	}
	
	return 0;
}

应该能看出来我卡常的痕迹。。。

posted @ 2017-03-14 14:43  xjr01  阅读(148)  评论(0编辑  收藏  举报