[雅礼集训2017 Day1]字符串(SAM+根号分治)

题目:LOJ#6031

题目描述:

\(f(s,w,l,r)\)表示字符串\(w\)的第\(l\)位到第\(r\)位,即\(w[l...r]\),在字符串\(s\)中出现的次数

给定一个长度为\(n\)的字符串\(s\),以及\(m\)个区间\([l_{i},r_{i}]\)(\(i \in [0,m)\))。有\(q\)次询问,每次询问给出一个长度为\(k\)(\(k\)为定值)的字符串\(w\),以及\(a,b\),求\(\sum_{i=a}^{b}f(s,w,l_{i},r_{i})\)

\(0 < n,m,k,q \leq 10^{5}\),\(\sum|w| \leq 10^{5}\),\(0 \leq l_{i} \leq r_{i} < k\),\(0 \leq a \leq b < m\)

蒟蒻题解:

观察数据发现\(\sum|w| \leq 10^{5}\),即\(k \cdot q \leq 10^{5}\),考虑根号分治

字符串\(s\)是固定的,且后面的询问都是问某个子串在\(s\)中出现的次数,不妨对\(s\)建后缀自动机

以下算复杂度时假设\(k \cdot q\)、\(m\)、\(n\)是同一级别的

  • \(q \geq k\)时,询问个数多,但是每个询问的字符串都较短(此时\(k^{2} \leq k \cdot q \leq 10^{5}\),即\(k \leq \sqrt{10^{5}}\))。对于每个询问的字符串\(w\),它的区间(子串)个数是\(k^{2}\)级别的:暴力枚举\(w\)的每个子串的左端点,在\(s\)建出的后缀自动机上跑,就能找出\(w\)的每个区间在\(s\)中的出现次数,复杂度是\(q \cdot k^{2} \leq n \cdot k \leq n\sqrt{n}\)。对于\(m\)个区间,对\(w\)的每个子串\([L,R]\)(其在\(s\)中的出现次数已经暴力算出),需要统计这\(m\)个区间中满足恰好\(l_{i}=L,r_{i}=R\)且\(i∈[a,b]\)的区间个数。这可以提前用动态数组按\((l_{i},r_{i})\)存下每种区间的所有编号,然后二分去找。在一次询问中,每种区间至多被二分一次,所以每次询问至多进行\(\min(m,k^{2})\)次二分,二分的总复杂度是\(O(q \cdot \min(m,k^{2}) \cdot \log m)\),总的时间复杂度就是\(O(n\sqrt{n}\log m)\)

  • \(q < k\)时,询问的字符串较长,但是询问个数较少。对于每次询问,可以把编号在\([a,b]\)内的区间\([l_{i},r_{i}]\)记在位置\(r_{i}\)上,这部分复杂度是\(O(q \cdot m)\)的。用\(w\)在后缀自动机上跑匹配,当前右端点是\(r_{i}\)时,设对应后缀自动机上的点为\(x\),当前匹配的长度为\(y\):若\(y < r_{i}-l_{i}+1\),则\(w[l_{i}...r_{i}]\)不在\(s\)中出现;否则用倍增在 parent 树上找\(x\)的祖先中满足\(len \geq r_{i}-l_{i}+1\)且\(len\)最小的节点,该节点的 endpos 大小就是出现次数。总的时间复杂度是\(O(\sum|w| + q \cdot m \cdot \log n)\),但是其实倍增很难跑满,如果被卡常的话可以将一部分换成第一种的解法

参考程序:

#include<bits/stdc++.h>
using namespace std;
// Loop-index shorthand used throughout the file. It used to expand to
// `register int`, but the `register` storage class was removed in C++17
// (clang rejects it, GCC warns), so it now expands to plain `int`.
#define Re int
using ll = long long;

// SAM node capacity: a SAM over n <= 1e5 characters has fewer than 2n states.
const int N = 200005;
// Side of the vt table. It is only used when q >= k, where k*k <= k*q <= 1e5,
// so every 1-indexed position in w stays below 320.
const int S = 320;

// One query interval [l, r] over the pattern w, stored 1-indexed.
struct info
{
	int l, r;
} a[N];

// n, m, q, k : input sizes (text length, #intervals, #queries, pattern length).
// lst        : SAM state of the whole prefix inserted so far.
// num        : number of SAM states allocated; state 1 is the root.
// len, lk    : longest length of each state / its suffix link.
// cnt, d     : counting-sort buckets by len / states listed in increasing len.
// f          : end-position count of each state (1 per prefix state, then summed up the links).
// ch         : transitions on letters 'a'...
// fa         : binary-lifting table over suffix links (fa[x][i] = 2^i-th ancestor).
int n, m, q, k, lst = 1, num = 1, len[N], lk[N], cnt[N], d[N], f[N], ch[N][28], fa[N][18];
ll ans;
// Text s (1-indexed); later reused to hold each query pattern w.
char s[N];
// vt[l][r]: ascending ids of intervals equal to [l, r];
// qq[r]   : left ends of the current query's intervals whose right end is r.
std::vector<int> vt[S][S], qq[N];

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it.
// Returns 0 if end of input is reached before any digit is found.
inline int read()
{
	// int, not char: EOF must stay distinguishable from every real character.
	int c = getchar();
	int value = 0;
	// Skip separators, but stop at EOF instead of spinning forever on truncated input.
	while (c != EOF && (c < '0' || c > '9')) c = getchar();
	while (c >= '0' && c <= '9') value = value * 10 + (c - '0'), c = getchar();
	return value;
}

// Prints x in decimal followed by a newline.
// Digits are produced least-significant first into a small buffer, then
// emitted in reverse; the do/while guarantees a single '0' for x == 0.
inline void write(ll x)
{
	char digits[25];
	int top = 0;
	do
	{
		digits[top++] = x % 10 + 48;
		x /= 10;
	} while (x);
	while (top) putchar(digits[--top]);
	putchar('\n');
}

// Extends the suffix automaton by one character x (letter index, 'a' == 0).
// The new state `cur` stands for the whole current prefix and is given one
// end position (f = 1); a clone created by a split is never given one here.
inline void ins(int x)
{
	const int cur = ++num;
	len[cur] = len[lst] + 1;
	f[cur] = 1;

	// Walk suffix links from the previous last state, adding the missing x-transitions.
	int p = lst;
	for (; p && !ch[p][x]; p = lk[p]) ch[p][x] = cur;

	if (!p)
	{
		// No state on the path had an x-transition: link straight to the root.
		lk[cur] = 1;
	}
	else
	{
		const int nxt = ch[p][x];
		if (len[nxt] == len[p] + 1)
		{
			// nxt is exactly one character longer than p, so it is the suffix link.
			lk[cur] = nxt;
		}
		else
		{
			// Split nxt: the clone takes length len[p] + 1 and copies nxt's transitions.
			const int clone = ++num;
			len[clone] = len[p] + 1;
			lk[clone] = lk[nxt];
			lk[nxt] = lk[cur] = clone;
			memcpy(ch[clone], ch[nxt], sizeof ch[clone]);
			// Redirect x-transitions that pointed at nxt to the clone. ch[0] is
			// never written, so the walk stops at the latest once p reaches 0.
			for (; ch[p][x] == nxt; p = lk[p]) ch[p][x] = clone;
		}
	}
	lst = cur;
}

// Binary-lifts from state x up the suffix-link tree and returns the
// highest ancestor (possibly x itself) whose len is still >= y.
// For y >= 1 and len[x] >= y this is the state that contains the suffix of
// length y; unbuilt lifting levels point to state 0 (len 0) and are never taken.
inline int find(int x, int y)
{
	int step = 16;
	while (step >= 0)
	{
		const int up = fa[x][step];
		if (len[up] >= y) x = up;
		--step;
	}
	return x;
}

int main()
{
	n = read(), m = read(), q = read(), k = read();
	scanf("%s", s + 1);
	for (Re i = 1; i <= n; ++i) ins(s[i] - 97);
	for (Re i = 1; i <= num; ++i) ++cnt[len[i]];
	for (Re i = 1; i <= n; ++i) cnt[i] += cnt[i - 1];
	for (Re i = num; i; --i) d[cnt[len[i]]--] = i;
	for (Re i = num; i; --i) f[lk[d[i]]] += f[d[i]];
	if (q >= k)
	{
		for (Re i = 0; i < m; ++i)
		{
			int u = read() + 1, v = read() + 1;
			vt[u][v].push_back(i);
		}
		while (q--)
		{
			scanf("%s", s + 1), ans = 0;
			int u = read(), v = read();
			for (Re i = 1; i <= k; ++i)
			{
				int j = i - 1, t = 1;
				while (j < k)
				{
					++j;
					if (!ch[t][s[j] - 97]) break;
					t = ch[t][s[j] - 97];
					int w = vt[i][j].size();
					if (!w) continue;
					int l = 0, r = w - 1, ls = w, rs = -1;
					while (l <= r)
					{
						int mid = l + r >> 1;
						if (vt[i][j][mid] >= u) ls = mid, r = mid - 1;
						else l = mid + 1;
					}
					if (ls == w) continue;
					l = 0, r = w - 1;
					while (l <= r)
					{
						int mid = l + r >> 1;
						if (vt[i][j][mid] <= v) rs = mid, l = mid + 1;
						else r = mid - 1;
					}
					if (ls <= rs) ans += 1ll * f[t] * (rs - ls + 1);
				}
			}
			write(ans);
		}
	}
	else
	{
		for (Re i = 2; i <= num; ++i) fa[i][0] = lk[i];
		for (Re i = 1; (1 << i) < n; ++i)
			for (Re j = 1; j <= num; ++j) fa[j][i] = fa[fa[j][i - 1]][i - 1];
		for (Re i = 0; i < m; ++i) a[i].l = read() + 1, a[i].r = read() + 1;
		while (q--)
		{
			scanf("%s", s + 1), ans = 0;
			int u = read(), v = read();
			for (Re i = u; i <= v; ++i) qq[a[i].r].push_back(a[i].l);
			for (Re i = 1, j = 1, l = 0; i <= k; ++i)
			{
				int w = s[i] - 97;
				if (ch[j][w]) j = ch[j][w], ++l;
				else
				{
					while (j > 1 && !ch[j][w]) j = lk[j];
					if (ch[j][w]) l = len[j] + 1, j = ch[j][w];
					else l = 0;
				}
				w = qq[i].size();
				if (!w) continue;
				if (!l)
				{
					qq[i].clear();
					continue;
				}
				for (Re t = 0; t < w; ++t)
				{
					int uu = qq[i][t];
					if (i - uu + 1 > l) continue;
					ans += f[find(j, i - uu + 1)];
				}
				qq[i].clear();
			}
			write(ans);
		}
	}
	return 0;
}
posted @ 2021-04-01 23:52  clfzs  阅读(86)  评论(0)    收藏  举报