[BZOJ2555]SubString

SAM+LCT模板题。
题目相当于求询问串在SAM上走到的状态的right集合大小，它等于parent树上该节点子树中前缀节点的个数（前缀节点指每次extend时新建的np节点，在LCT中权值记为1；复制出的nq节点权值为0）。若询问串无法在SAM上走完，答案为0。
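因为强制在线（询问串需用mask解码），而parent树会随插入不断变化（克隆节点时要先cut再link），所以parent树用LCT维护。注意LCT要维护的是子树和而不是路径和：每个节点额外记录虚儿子子树的权值和（代码中的sum2），在access切换实儿子以及link时同步更新。代码写起来和想起来都不难。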

#include <bits/stdc++.h>

using namespace std;

// Capacity bound shared by the token buffer and the automaton/LCT arrays
// (which use N<<1 slots). Presumably chosen to exceed the maximum text length
// of 600000 with some slack -- NOTE(review): confirm against the problem limits.
const int N = 600000 + 1000;

// Scratch buffer that holds the most recently read (and decoded) token.
char str[N];
// Reads one whitespace-delimited token into str and unscrambles it in place.
// p is the seed (the caller passes the current mask). For each index i, in
// order, p becomes (p*131+i) mod length and characters i and p are exchanged.
void getstr(int p){
	scanf("%s",str);
	const int n = strlen(str);
	int i = 0;
	while (i < n){
		p = (p*131+i)%n;
		const char held = str[i];
		str[i] = str[p];
		str[p] = held;
		++i;
	}
}

// One link-cut-tree node.
struct Node{
	int son[2];  // splay-tree children: [0] = left (shallower), [1] = right (deeper)
	int fa;      // splay parent, or the path-parent when this node is a splay root
	int sum;     // aggregate of the splay subtree: own weights + virtual subtree totals
	int sum1;    // this node's own weight
	int sum2;    // total weight hanging below this node through virtual (non-preferred) edges
	Node(){}
	// Isolated node with parent _fa and weight _sum; no children, no virtual weight.
	Node(int _fa,int _sum):fa(_fa),sum(_sum),sum1(_sum),sum2(0){
		son[0] = 0;
		son[1] = 0;
	}
};
// Link-cut tree over the suffix-link (parent) tree of the automaton.
// Trees are never re-rooted (there is no make-root/reversal), so only access,
// link and cut are needed. Besides the usual splay aggregate, every node keeps
// sum2 = total weight attached through virtual edges, which lets getans report
// the weight of a whole subtree rather than of a path.
struct Lct{
	Node nod[N<<1];
	// Resets node p to an isolated node (no parent, no children) with weight d.
	void insert(int p,int d){
		nod[p] = Node(0,d);
	}
	// Attaches x below y. Assumes x is currently the root of its own tree
	// (NOTE(review): not checked; callers pass a fresh node or one just cut).
	void link(int x,int y){
		// For a tree root, access+splay leaves x alone in its splay tree, so
		// nod[x].sum is the weight of x's whole subtree.
		access(x);
		access(y);
		splay(x);
		// y becomes the splay root of the root-to-y path, so updating only y's
		// fields keeps every aggregate on that path consistent.
		splay(y);
		nod[x].fa = y;
		// x hangs below y through a virtual edge.
		nod[y].sum2 += nod[x].sum;
		nod[y].sum += nod[x].sum;
	}
	// Detaches x from its parent (no-op if x is already a root).
	void cut(int x){
		access(x);
		splay(x);
		// The left splay subtree holds exactly x's ancestors on the path.
		int u = nod[x].son[0];
		if (u){
			nod[x].son[0] = 0;
			update(x);
			nod[u].fa = 0;
		}
	}
	// Makes the path from the tree root to x preferred, with x as its deepest node.
	void access(int x){
		int y = 0;
		while (x){
			splay(x);
			// Switch x's preferred child from u to y, keeping sum2 exact:
			// the old right child u becomes virtual (its weight moves into sum2),
			// and y stops being virtual (its weight leaves sum2). On the first
			// iteration y == 0 and nod[0].fa != x, so nothing is subtracted.
			int u = nod[x].son[1];
			if (u) nod[x].sum2 += nod[u].sum;
			if (nod[y].fa == x) nod[x].sum2 -= nod[y].sum;
			nod[x].son[1] = y;
			update(x);
			// x is a splay root here, so fa is its path-parent.
			y = x;x = nod[x].fa;
		}
	}
	// Rotates x up to the root of its splay tree. When x and its parent are
	// children on the same side, the parent is rotated first; otherwise x is
	// rotated once and the loop re-examines its new position.
	void splay(int x){
		int w;
		while ((w = check(x)) != -1){
			int y = nod[x].fa;
			if (w == check(y)) rotate(y,w^1);
			rotate(x,w^1);
		}
	}
	// Rotates x above its parent y; afterwards y is x's child on side d
	// (nod[x].son[d] == y). If y was a splay root, x inherits y's path-parent.
	void rotate(int x,int d){
		int y = nod[x].fa,z = nod[y].fa,w = check(y);
		nod[x].fa = z;
		if (w != -1) nod[z].son[w] = x;
		nod[y].son[d^1] = nod[x].son[d];
		if (nod[x].son[d]) nod[nod[x].son[d]].fa = y;
		nod[y].fa = x;
		nod[x].son[d] = y;
		update(y);
		update(x);
	}
	// Recomputes p's aggregate: own weight + virtual weight + both splay children.
	void update(int p){
		int u = nod[p].son[0],v = nod[p].son[1];
		nod[p].sum = nod[p].sum1+nod[p].sum2;
		if (u) nod[p].sum += nod[u].sum;
		if (v) nod[p].sum += nod[v].sum;
	}
	// Returns which child (0 = left, 1 = right) x is of its splay parent, or
	// -1 when x is a splay root (no parent, or fa is only a path-parent).
	int check(int x){
		int y = nod[x].fa;
		if (!y) return -1;
		if (nod[y].son[0] == x) return 0;
		if (nod[y].son[1] == x) return 1;
		return -1;
	}
	// Total weight in the subtree rooted at p; 0 for p == 0 (no such state).
	int getans(int p){
		if (p == 0) return 0;
		access(p);
		splay(p);
		// p is the deepest node of its path, so it has no right child here.
		// Removing the ancestors (left subtree) leaves own + virtual weight,
		// i.e. the whole subtree. Relies on nod[0].sum staying 0 when there
		// is no left child.
		return nod[p].sum-nod[nod[p].son[0]].sum;
	}
}lct;

// One suffix-automaton state.
struct State{
	int go[26];  // transition per letter ('A' == 0); 0 means "no transition"
	int par;     // suffix link (parent in the parent tree); 0 for the root
	int val;     // length of the longest string in this state
	State(){}
	// State of length _val with no suffix link and no transitions yet.
	State(int _val):par(0),val(_val){
		for (int c = 0;c < 26;++c)
			go[c] = 0;
	}
}state[N<<1];
// Automaton bookkeeping: state 1 is the root, len is the highest state index
// in use, last is the state of the whole current text.
int root = 1;
int len = 1;
int last = 1;

// Remaining operation count, online decoding mask, and the previous answer.
int q;
int mask;
int lastans;
// Operation keyword read for each line.
char opt[10];

void extend(int);
int trans();
// Builds the automaton for the initial text, then processes q operations.
// Every token is decoded with the current mask. An opt starting with 'A'
// appends the token; any other opt counts the token's occurrences, prints
// the count, and XORs it into mask.
int main(){
	scanf("%d",&q);
	scanf("%s",str);
	for (int i = 0,n = strlen(str);i < n;i++)
		extend(str[i]-'A');
	while (q--){
		scanf("%s",opt);
		getstr(mask);
		if (opt[0] != 'A'){
			lastans = lct.getans(trans());
			printf("%d\n",lastans);
			mask ^= lastans;
			continue;
		}
		for (int i = 0,n = strlen(str);i < n;i++)
			extend(str[i]-'A');
	}
	return 0;
}
// Appends letter w ('A' == 0) to the suffix automaton and mirrors every
// suffix-link change in the link-cut tree. The new prefix state np gets LCT
// weight 1 and a cloned state nq gets weight 0, so a state's LCT subtree
// weight is the number of prefix states below it.
void extend(int w){
	int p = last,np = ++len;
	state[np] = State(state[p].val+1);		
	// np ends a new prefix of the text, so it contributes 1.
	lct.insert(np,1);
	// Give every suffix of the old text that lacks a w-transition one to np.
	while (p && state[p].go[w] == 0) 
		state[p].go[w] = np,p = state[p].par;
	if (p == 0){
		// No suffix had a w-transition: np's suffix link is the root.
		state[np].par = root;
		lct.link(np,root);
	}
	else{
		// NOTE: this local q shadows the global operation counter q.
		int q = state[p].go[w];
		if (state[q].val == state[p].val+1){
			// q's longest string is exactly (suffix of p) + w: link np to q.
			state[np].par = q;
			lct.link(np,q);
		}
		else{
			// Split q: clone it as nq with the shorter length and weight 0.
			int nq = ++len;state[nq] = State(state[p].val+1);
			lct.insert(nq,0);
			memcpy(state[nq].go,state[q].go,sizeof(state[q].go));
			state[nq].par = state[q].par;
			state[q].par = nq;
			state[np].par = nq;
			// Mirror in the LCT: detach q, put nq in q's old place, then
			// hang both q and np below nq.
			lct.cut(q);
			lct.link(nq,state[nq].par);
			lct.link(q,nq);
			lct.link(np,nq);
			// Redirect the w-transitions that pointed to q over to the clone.
			while (p && state[p].go[w] == q)
				state[p].go[w] = nq,p = state[p].par;
		}
	}
	last = np;
}
// Walks the automaton from the root along the token currently in str and
// returns the state reached. Once a transition is missing the walk lands on
// state 0, which never receives transitions, so the result stays 0.
int trans(){
	int cur = root;
	for (const char *c = str;*c;++c)
		cur = state[cur].go[*c-'A'];
	return cur;
}
posted @ 2017-04-27 16:06  VictBr  阅读(114)  评论(0)  编辑  收藏  举报