bzoj 2555: SubString 后缀自动机+LCT

题目:

给你一个字符串init,要求你支持两个操作

  • 在当前字符串的后面插入一个字符串
  • 询问字符串s在当前字符串中出现了几次?(作为连续子串)
    你必须在线支持这些操作。

题解:

询问某个字符串出现了几次,我们可以直接在后缀自动机上跑
输出最终走到的状态的right集合大小
如果失配则没有出现过

不难发现这样是正确的
所以现在问题在于求right集合的大小.由于这道题强制在线
所以我们需要动态维护right集合的大小
就是说每次加入一个cur后,把cur的parent树的所有祖先的siz全部+1
如果我们直接跳parent树会TLE.不要问我是怎么知道的
想为什么会TLE
如果是这样的字符串

input
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa...(省略若干a)

这样每一次修改所有的祖先的时候复杂度都是\(O(n)\)的,也就是整体复杂度退化到了\(O(n^2)\)

所以我们需要用LCT来维护parent树
这样就可以做到\(O(nlogn)\)

一直在刷后缀自动机
为了杠这道题还特意花了一个下午去学LCT

Code

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
typedef long long ll;
inline void read(int &x){
	x=0;char ch;bool flag = false;
	while(ch=getchar(),ch<'!');if(ch == '-') ch=getchar(),flag = true;
	while(x=10*x+ch-'0',ch=getchar(),ch>'!');if(flag) x=-x;
}
const int maxn = 600010;
const int maxm = 3000010;
namespace lct{
	struct Node{
		Node *ch[2],*fa;
		int w,lazy;
	}*null;
	Node mem[maxn<<1],*it;
	inline void init(){
		it = mem;null = new Node;
		null->ch[0] = null->ch[1] = null->fa = null;
		null->w = null->lazy = 0;
	}
	inline Node* newNode(int val){
		Node *p = it++;p->w = val;p->lazy = 0;
		p->ch[0] = p->ch[1] = p->fa = null;
		return p;
	}
	inline void pushdown(Node *p){
		if(p->lazy == 0 || p == null) return;
		if(p->ch[0] != null){
			p->ch[0]->w += p->lazy;
			p->ch[0]->lazy += p->lazy;
		}
		if(p->ch[1] != null){
			p->ch[1]->w += p->lazy;
			p->ch[1]->lazy += p->lazy;
		}
		p->lazy = 0;
	}
	inline void rotate(Node *p,Node *x){
		int k = p == x->ch[1];
		Node *y = p->ch[k^1],*z = x->fa;
		if(z->ch[0] == x) z->ch[0] = p;
		if(z->ch[1] == x) z->ch[1] = p;
		if(y != null) y->fa = x;
		p->fa = z;p->ch[k^1] = x;
		x->fa = p;x->ch[k] = y;
	}
	inline bool isroot(Node *p){
		return p == null || (p->fa->ch[0] != p && p->fa->ch[1] != p);
	}
	inline void splay(Node *p){
		pushdown(p);
		while(!isroot(p)){
			Node *x = p->fa,*y = x->fa;
			pushdown(y);pushdown(x);pushdown(p);
			if(isroot(x)) rotate(p,x);
			else if((x->ch[0] == p) ^ (y->ch[0] == x)) rotate(p,x),rotate(p,y);
			else rotate(x,y),rotate(p,x);
		}
	}
	inline void Access(Node *x){
		Node *y = null;
		while(x != null){
			splay(x);x->ch[1] = y;
			y = x;x = x->fa;
		}
	}
	inline void link(Node *u,Node *v){
		Access(u);splay(u);
		u->fa = v;
	}
	inline void cut(Node *u){
		Access(u);splay(u);
		u->ch[0] = u->ch[0]->fa = null;
	}
	inline int query(Node *p){
		Access(p);splay(p);
		return p->w;
	}
	inline void inc(Node* x){
		Access(x);splay(x);
		x->lazy ++ ;x->w ++ ;
	}
}
struct Node{
	int nx[26];
	int len,fa;
}T[maxn<<1];
int last,nodecnt;
inline void init(){
	last = nodecnt = 0;
	T[0].fa = -1;
}
inline void insert(char cha){
	int c = cha - 'A',cur = ++ nodecnt,p;lct::newNode(0);
	T[cur].len = T[last].len + 1;
	for(p = last;p != -1 && !T[p].nx[c];p = T[p].fa) T[p].nx[c] = cur;
	if(p == -1){
		T[cur].fa = 0;
		lct::link(lct::mem+cur,lct::mem);
	}else{
		int q = T[p].nx[c];
		if(T[q].len == T[p].len + 1){
			T[cur].fa = q;
			lct::link(lct::mem+cur,lct::mem+q);
		}else{
			int co = ++ nodecnt;lct::newNode(query((lct::mem+q)));
			T[co] = T[q];T[co].len = T[p].len+1;
			lct::link(lct::mem+co,lct::mem+T[q].fa);
			for(;p != -1 && T[p].nx[c] == q;p = T[p].fa) T[p].nx[c] = co;
			lct::cut(lct::mem+q);
			lct::link(lct::mem+cur,lct::mem+co);
			lct::link(lct::mem+q,lct::mem+co);
			T[cur].fa = T[q].fa = co;
		}
	}last = cur;
	lct::inc(lct::mem+cur);
}
char cmd[10];
char s[maxm];
void decode(int mask){
	int len = strlen(s);
	for(int i=0;i<len;++i){
		mask = (mask*131 + i) % len;
		swap(s[i],s[mask]);
	}
}
int mask = 0;
inline int find(){
	int len = strlen(s),nw = 0;
	for(int i=0;i<len;++i){
		if(T[nw].nx[s[i]-'A']) nw = T[nw].nx[s[i]-'A'];
		else return 0;
	}
	int x = lct::query(lct::mem+nw);
	mask ^= x;return x;
}
int main(){
	init();lct::init();lct::newNode(0);
	int m;read(m);scanf("%s",s);
	int len = strlen(s);for(int i=0;i<len;++i) insert(s[i]);
	while(m--){
		scanf("%s",cmd);scanf("%s",s);
		decode(mask);
		if(*cmd == 'A'){
			int len = strlen(s);
			for(int i=0;i<len;++i) insert(s[i]);
		}else printf("%d\n",find());
	}
	getchar();getchar();
	return 0;
}
posted @ 2017-03-09 07:13  Sky_miner  阅读(182)  评论(0编辑  收藏  举报