弱省互测#0 t2

题意

给定两个字符串 A 和 B,求下面四个问题的答案:
1.在 A 的子串中,不是 B 的子串的字符串的数量。
2.在 A 的子串中,不是 B 的子序列的字符串的数量。
3.在 A 的子序列中,不是 B 的子串的字符串的数量。
4.在 A 的子序列中,不是 B 的子序列的字符串的数量。
其中子串是指本质不同的子串,不同的位置相同的串也只算一个串
|S|<=2000

分析

构造俩自动机然后同时跑

题解

构造一个子序列自动机,再构造一个后缀自动机,然后从根依次转移,记录状态上的信息,记忆化一下。

#include <bits/stdc++.h>
using namespace std;
const int mo=1e9+7, Lim=100005;
int nid;
struct node {
	node *c[26], *f;
	int len, s, id;
	bool flag;
	void init() {
		memset(c, 0, sizeof c);
		f=0;
		len=s=0;
		flag=0;
	}
}Po[Lim], *iT=Po, *root[2], *last[2], *rt[2];
node *newnode() {
	iT->init();
	return iT++;
}
void init() {
	root[0]=newnode();
	root[1]=newnode();
	last[0]=root[0];
	last[1]=root[1];
	rt[0]=newnode();
	rt[1]=newnode();
}
void add1(int who, int ch) {
	node *now=last[who], *x=newnode();
	last[who]=x;
	x->len=now->len+1;
	x->id=nid++;
	for(; now && !now->c[ch]; now=now->f) {
		now->c[ch]=x;
	}
	if(!now) {
		x->f=root[who];
		return;
	}
	node *y=now->c[ch];
	if(y->len==now->len+1) {
		x->f=y;
		return;
	}
	node *z=newnode();
	*z=*y;
	z->id=nid++;
	z->len=now->len+1;
	x->f=y->f=z;
	for(; now && now->c[ch]==y; now=now->f) {
		now->c[ch]=z;
	}
}
void cal(node *x) {
	if(x->flag) {
		return;
	}
	x->flag=1;
	x->s=1;
	for(int ch=0; ch<26; ++ch) {
		if(x->c[ch]) {
			cal(x->c[ch]);
			x->s+=x->c[ch]->s;
			if(x->s>=mo) {
				x->s-=mo;
			}
		}
	}
}
void build1(int who, char *s) {
	nid=1;
	for(; *s; ++s) {
		add1(who, *s-'a');
	}
	cal(root[who]);	
}
void build2(int who, char *s) {
	nid=1;
	static node *tc[26];
	memset(tc, 0, sizeof tc);
	for(; *s; --s) {
		node *now=newnode();
		now->id=nid++;
		for(int ch=0; ch<26; ++ch) {
			now->c[ch]=tc[ch];
		}
		tc[*s-'a']=now;
	}
	for(int ch=0; ch<26; ++ch) {
		rt[who]->c[ch]=tc[ch];
	}
	cal(rt[who]);
}
int vis[5005][5005];
int getans(node *x, node *y) {
	if(!x) {
		return 0;
	}
	if(!y) {
		return x->s;
	}
	if(vis[x->id][y->id]!=-1) {
		return vis[x->id][y->id];
	}
	int ret=0;
	for(int ch=0; ch<26; ++ch) {
		ret+=getans(x->c[ch], y->c[ch]);
		if(ret>=mo) {
			ret-=mo;
		}
	}
	return vis[x->id][y->id]=ret;
}
char s[2][2005], *it[2];
int main() {
	init();
	scanf("%s%s", s[0], s[1]);
	build1(0, s[0]);
	build1(1, s[1]);
	for(int i=0; i<2; ++i) {
		for(it[i]=s[i]; *(it[i]+1); ++it[i]);
	}
	build2(0, it[0]);
	build2(1, it[1]);
	memset(vis, -1, sizeof vis); printf("%d\n", getans(root[0], root[1]));
	memset(vis, -1, sizeof vis); printf("%d\n", getans(root[0], rt[1]));
	memset(vis, -1, sizeof vis); printf("%d\n", getans(rt[0], root[1]));
	memset(vis, -1, sizeof vis); printf("%d\n", getans(rt[0], rt[1]));
	return 0;
}
posted @ 2015-11-22 18:42  iwtwiioi  阅读(309)  评论(0编辑  收藏  举报