bzoj4566 [Haoi2016]找相同字符

[Haoi2016]找相同字符

Time Limit: 20 Sec Memory Limit: 256 MB

Description

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。

Input

两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母

Output

输出一个整数表示答案

Sample Input

aabb

bbaa

Sample Output

10






我后缀自动机入门啦啦啦~
这道题有两种做法:
做法一:你可以把两个串分别建一个后缀自动机,然后再把两个串用 # 连接起来再建一个后缀自动机。每个自动机都可以得到有多少个不同的子串。所以答案等于前面两个小的自动机的答案减去后面这个大的自动机的答案。
我写的是做法二:
就是说你先用第一个串建一个后缀自动机,然后拿第二个串去跑 。(就是匹配不到了该跳 fail 就跳,和 ac 自动机差不多的那种。)
然后你记录一下每个节点被进入了多少次。再记录一下你当前匹配的长度。注意最后统计的时候是直接统计的匹配长度是 【L,R】 的,所以进每个节点时先减去不合法的。这样最后统计答案才正确。并且你跳链的时候匹配的长度要和你当前匹配长度取min。因为每个节点的管理区间是有限的,所以只能匹配你长度在区间里的串。




#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 6;
struct lpl{
	int son[30], parent, num, len, k;
}node[maxn << 1];
char s[maxn];
int cnt = 1, last = 1, pr[maxn], rk[maxn << 1];
long long ans;

void insert(int t, int L){
	node[++cnt].len = L; node[cnt].num = 1; int i = last; last = cnt;
	while(!node[i].son[t] && i){node[i].son[t] = cnt; i = node[i].parent;}
	if(!i){node[last].parent = 1; return;} int now = node[i].son[t];
	if(node[now].len == node[i].len + 1){node[cnt].parent = now; return;}
	int nw = ++cnt; for(int j = i; j && node[j].son[t] == now; j = node[j].parent) node[j].son[t] = nw;
	node[nw] = node[now]; node[nw].len = node[i].len + 1; node[nw].num = 0;
	node[now].parent = nw; node[last].parent = nw;
}

void match(int u, int pos, int L){
	node[u].k++; ans -= 1ll * node[u].num * (node[u].len - L);
	if(s[pos] == 0) return;
	for(;!node[u].son[s[pos] - 'a' + 1] && u; u = node[u].parent);
	if(!u) match(u + 1, pos + 1, 0);
	else match(node[u].son[s[pos] - 'a' + 1], pos + 1, min(L, node[u].len) + 1); 
}

int main()
{
	scanf("%s", s + 1); int len = strlen(s + 1);
	for(int i = 1; i <= len; ++i) insert(s[i] - 'a' + 1, i);
	for(int i = 1; i <= cnt; ++i) pr[node[i].len]++;
	for(int i = 1; i <= len; ++i) pr[i] += pr[i - 1];
	for(int i = 1; i <= cnt; ++i){rk[pr[node[i].len]] = i; pr[node[i].len]--;}
	for(int i = cnt; i >= 1; --i) node[node[rk[i]].parent].num += node[rk[i]].num;
	scanf("%s", s + 1); match(1, 1, 0);
	for(int i = cnt; i >= 1; --i){
		int now = rk[i];
		node[node[now].parent].k += node[now].k;
		ans += 1ll * node[now].k * (node[now].len - node[node[now].parent].len) * node[now].num;
	}
	printf("%lld", ans);
	return 0;
}

posted @ 2018-09-21 07:54  沛霖  阅读(214)  评论(0编辑  收藏  举报