BZOJ 4502 串

Description

传送门


Solution

首先注意到很多字符串有很多不同的方式进行划分,这样子为了避免不重不漏,就需要对其按特定形式划分。考虑到\(AC\)自动机上\(fail\)指针的含义是状态\(x\)的最长可识别后缀,那么我们就把划分方式定义为使划分的后半部分子串最长即可。

这样的话考虑在\(AC\)自动机上对这样一个串进行匹配会发生什么,注意到我们最后停止在了某个状态,该状态的\(fail\)指针指向的就是最长的可识别后缀,那么只需关心前缀能否识别即可。当我们的串第一次跳\(fail\)指针跳到的状态就是该串的最长可识别前缀,所以只要我们划分的前缀的长度比我们第一次跳到的状态的长度小的话就是可识别的。也就是说,一个好串在跳了第一次\(fail\)指针后在\(AC\)自动机上走的步数应该小于最后停留的状态的长度也就是该状态在\(AC\)自动机上的深度。

现在考虑枚举好串的长度来进行\(dp\)\(dp_{i, j}\)表示在第一次跳了\(fail\)指针后走了\(i\)步,现在在\(j\)状态上的好串个数,这样子就能在\(AC\)自动机上转移了,有\(dp\)状态转移方程

\[dp_{i + 1, trie_{k, c}} += dp_{i, k} (dep_{trie_{k, c}} > i) \]

其中\(c\)是当前位置枚举的字符。

初始化就是在\(AC\)自动机上找到所有跳一次\(fail\)指针能跳到的位置将\(dp_{1, x} + 1\),可以在构造\(AC\)自动机分类讨论的时候随便实现,同理\(dep\)数组也可以在构造\(AC\)自动机的时候随便实现。

到这里这个题差不多就做完了,但是其实还有一种情况没有考虑到,如果一个串没有跳过\(fail\)指针,那么只要他能分成两个非空的串,它就是一个好串,也就是只要\(fail_i != 0\),就会多出来一个好串,暴力遍历整个\(AC\)自动机上所有节点就行。

这题数据好泄,没开\(long long\)只有\(10pts\),害的我还去人工查错了\(5min\)


Code

#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>

using namespace std;

#define ll long long

const int N = 300050;
const int M = 50; 

int trie[N][26], cnt, n, root, fail[N], dep[N], maxx;

ll ans, dp[70][N];

char st[N];

int Max(int a, int b)
{
	return a > b ? a : b;
}

void Insert(char st[N])
{
	int l = strlen(st + 1);
	maxx = Max(maxx, 2 * l);
	int now = root;
	for (int i = 1; i <= l; i++)
	{
		int c = st[i] - 'a';
		if (!trie[now][c]) trie[now][c] = ++cnt;
		dep[trie[now][c]] = dep[now] + 1;
		now = trie[now][c];
	}
	return;
}

void Build_AC_auto()
{
	queue<int> q;
	for (int i = 0; i <= 25; i++) if (trie[root][i]) q.push(trie[root][i]);
	while (!q.empty())
	{
		int u = q.front(); q.pop();
		for (int i = 0; i <= 25; i++) 
			if (trie[u][i]) fail[trie[u][i]] = trie[fail[u]][i], q.push(trie[u][i]);
			else trie[u][i] = trie[fail[u]][i], dp[1][trie[u][i]]++;
	}	
	return;
}

int main()
{
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) scanf("%s", st + 1), Insert(st);
	Build_AC_auto();
	for (int i = 1; i <= cnt; i++) 
		if (fail[i]) ans++;
	for (int i = 1; i <= maxx; i++)
		for (int j = 1; j <= cnt; j++)
			if (dp[i][j])
			{
				ans += dp[i][j];
				for (int c = 0; c <= 25; c++)
					if (dep[trie[j][c]] > i) dp[i + 1][trie[j][c]] += dp[i][j];
			}
	printf("%lld", ans);
	return 0;
}
posted @ 2020-06-12 10:01  Tian-Xing  阅读(93)  评论(0编辑  收藏  举报