后缀数组 杂题总结

前言

在阅读本文之前,请先确保你以了解一下东西的求法和定义:

  • \(sa_i\),就是后缀字典序的排名。

  • \(height_i\),即 \(\text{lcp}(sa_{i-1},sa_i)\)

  • \(\text{height}\) 的一些基本作用,比如求 \(\text{lcp}\)

题外话,本人对于以上数组全部依靠哈希进行求解,所以复杂度会比正常人多一个二分的 \(\log(n)\),不用在意它。

\(\texttt{1. [NOI2016] 优秀的拆分}\)

其实是一道非常 \(\texttt{Tricky}\) 的题目。

首先,不难想到,把这个 \(\texttt{AABB}\) 给它拆分成两段,分别进行计算。

具体来说,就是定义 \(f_i\),表示以 \(i\) 结束的 \(\texttt{AA}\) 串有多少个。

同理,定义 \(g_i\),表示以 \(i\) 开始的 \(\texttt{BB}\) 串有多少个。

显然,答案就是 \(\sum _{i=1}^{n-1}f_i\times g_{i+1}\)

其实,问题也就转化为了求 \(f,g\) 中的一个。(因为求完其中的一个之后,另一个倒过来求就是了)

然后我们考虑如何进行统计。

显然有一个暴力枚举左右端点,然后 \(\mathcal{O}(1)\) 检查(无论你是用哈希还是后缀数组)的一个思路。

数据很水,小小优化一下也就是 \(\texttt{95 pts}\) 的好成绩。

我们考虑如何加速这个过程。

这里需要使用一个非常奇怪的 \(\texttt{Trick}\),其实也是后缀数组题目中一个比较常见的技巧。

我们考虑枚举这个 \(\texttt{AA}\) 串的长度 \(\text{len}\),从 \(1\)\(\dfrac{n}{2}\)

然后每枚举一次,我们都在下标 \(len \times k,k\in \mathbb{Z}^{+}\) 的位置设置断点。

可以发现,每个长度为 \(\text{len}\)\(\texttt{AA}\) 串必定会经过恰好两个断点,并且一定是连续的断点。

所以我们考虑去枚举这两个断点,然后看这两个断点对于答案的贡献。

显然,如果只是枚举的话,时间复杂度是 \(\mathcal{O}(\dfrac{n}{1}+\dfrac{n}{2}+\dfrac{n}{3}+\dfrac{n}{4}...)\),也就是 \(\mathcal{O}(n\ln n)\)

然后我们考虑算答案的贡献。

首先,对于一个合法的 \(\texttt{AA}\) 串,他们经过的这两个断点应当分别属于这两个 \(\texttt{A}\) 中的一个。故我们考虑对这两个端点求 \(\texttt{lcp,lcs}\),假设长度分别为 \(\text{L,R}\)

然后假设这两个断点分别为 \(x,y\)

显然,如果 \(x+R\ge y-L+1\),那么对于点 \(y+R-1\),他就会对 \(f_{y+R-1}\) 产生 \(1\) 的贡献。

当然,更一般的,对于 \(i\in[y-L+\text{len},y+R-1]\)\(f_{i}\) 都会得到 \(1\) 的贡献。

对于 \(g\) 也同理,不再赘述,如果无法理解,画画图就可以解决。

注意,由于我们算的只能是长度为 \(\texttt{len}\) 的,为了防止算重,记得对 \(y-L+\text{len}\)\(y\)\(\max\),对 \(y+R-1\)\(y+len-1\)\(\min\)。(因为 "对于一个合法的 \(\texttt{AA}\) 串,他们经过的这两个断点应当分别属于这两个 \(\texttt{A}\) 中的一个" 这句话的钦定。)

然后,由于是区间加 \(1\),打一个差分就可以了。

至于 \(\texttt{lcs,lcp}\) 的求法,看个人喜好。可以选择 \(\texttt{ST+SA}\) 的做法,复杂度 \(\mathcal{O}(n\log n+n\ln n)\)。也可以直接二分加哈希,复杂度 \(\mathcal{O}(n\log n \ln n)\)

但是实际上,哈希比后缀数组快得多,不知道为什么。

#include <bits/stdc++.h>
using namespace std;
#define maxn 30005
#define ull unsigned long long
int t;
char a[maxn];
int n;
int f[maxn], g[maxn];
ull Has[maxn], p[maxn];
ull get(int l, int r)
{
	return Has[l] - Has[r + 1] * p[r - l + 1];
}
int main()
{
	scanf("%d", &t);
	while(t--)
	{
		scanf("%s", a + 1);
		n = strlen(a + 1);
		memset(f, 0, sizeof(f));
		memset(g, 0, sizeof(g));
		p[0] = 1, Has[n + 1] = 0;
		for (int i = n; i >= 1; --i) Has[i] = Has[i + 1] * 173 + a[i];
		for (int i = 1; i <= n; ++i) p[i] = p[i - 1] * 173;
		for (int len = 1; len * 2 <= n; ++len)
		{
			for (int i = len * 2; i <= n; i += len)
			{
				if(a[i] != a[i - len]) continue;
				int last = i - len;
				int L, R;
				int l = 1, r = len, mid, ans = 0;
				while(l <= r)
				{
					mid = (l + r) >> 1;
					if(get(last - mid + 1, last) == get(i - mid + 1, i)) l = mid + 1, ans = mid;
					else r = mid - 1; 
				}
				L = i - ans + 1, L = max(L + len - 1, i);
				l = 1, r = len, ans = 0;
				while(l <= r)
				{
					mid = (l + r) >> 1;
					if(get(last, last + mid - 1) == get(i, i + mid - 1)) l = mid + 1, ans = mid;
					else r = mid - 1;
				}
				R = i + ans - 1, R = min(R, i + len - 1);
				if(L <= R)
				{
					f[L - len * 2 + 1]++, f[R - len * 2 + 2]--;
					g[L]++, g[R + 1]--;
				}
			}
		} 
		for (int i = 1; i <= n; ++i) f[i] += f[i - 1], g[i] += g[i - 1];
		long long ans = 0;
		for (int i = 1; i <= n; ++i) ans += g[i] * 1ll * f[i + 1];
		cout << ans << endl;
	}
}

\(\texttt{2.[SP687] REPEATS - Repeats}\)

其实是对上面那个分段点技巧的一个应用,不过这道题感觉思维难度明显更高。

对于一个重复次数大于 \(1\) 的子串,我们假设把他单独拎出来成为一个字符串 \(S\)。假设其重复子串的长度为 \(\text{len}\)。显然有:\(\forall i\in [1,len - 1],S_{i}=S_{i+\text{len} \times k,k\in \mathbb{Z}}\)

故,对于原串中的位置 \(i\),如果他能成为上述重复子串中的一个起始位置,那么对于 \(i,i+\text{len}\) 的后缀,必然满足 \(\text{lcp}(i,i+\text{len})\ge \text{len}\),其重复子串的重复次数,也就是 \(\lfloor \dfrac{\text{lcp}(i,i+\text{len})}{\text{len}} \rfloor+1\)

但是显然,我们不可能既去枚举 \(i\),又去枚举 \(\text{len}\),我们只有做出取舍。

根据上一道题的思路,我们考虑去枚举这个 \(\text{len}\),然后枚举 \(\text{len}\times k,k\in \mathbb{Z}^+\)。显然,对于这些点,我们只对相邻的两个断点求一个 \(\text{lcp}\) 是不够的,因为他有可能还能往前面扩展。

其实也很简单,对于前面的贡献,其实也就是 \(\text{lcs}(i,i+\text{len})\),把刚刚我们推出来的给反过来,就可以得到前面的贡献。

故对于相邻的两个断点,\(i,i+\text{len}\),他们对答案产生的贡献,也就是 \(\lfloor \dfrac{\text{lcp}(i,i+\text{len})+\text{lcs}(i,i+\text{len})-1}{\text{len}} \rfloor+1\)。(减一是因为最长公共前后缀都包含了这个点本身,加起来会多算一次)

最后求一个 \(\max\) 即可。

时间复杂度的分析同上。

#include <bits/stdc++.h>
using namespace std;
#define maxn 300005
const int mod = 998344353;
int t;
char a[maxn];
int n;
int Has[maxn], p[maxn];
int get(int l, int r)
{
	return ((Has[r] - 1ll * Has[l - 1] * p[r - l + 1] % mod) % mod + mod) % mod;
}
int main()
{
	scanf("%d", &n);
	for (int i = 1; i <= n; ++i) cin >> a[i];
	p[0] = 1, Has[0] = 0;
	for (int i = 1; i <= n; ++i) Has[i] = (Has[i - 1] * 173ll + a[i]) % mod;
	for (int i = 1; i <= n; ++i) p[i] = (p[i - 1] * 173ll) % mod;
	int sum = 0;
	for (int len = 1; len <= n; ++len)
	{
		for (int i = len + 1; i <= n; i += len)
		{
			int last = i - len;
			if(a[i] != a[last]) continue;
			int L, R;
			int l = 1, r = last, mid, ans = 0;
			while(l <= r)
			{
				mid = (l + r) >> 1;
				if(get(last - mid + 1, last) == get(i - mid + 1, i)) l = mid + 1, ans = mid;
				else r = mid - 1; 
			}
			L = ans;
			l = 1, r = n - i + 1, ans = 0;
			while(l <= r)
			{
				mid = (l + r) >> 1;
				if(get(last, last + mid - 1) == get(i, i + mid - 1)) l = mid + 1, ans = mid;
				else r = mid - 1;
			}
			R = ans;
			sum = max((L + R - 1) / len + 1, sum);
		}
	} 
	cout << sum << endl;
}

\(\texttt{3.[NOI2015] 品酒大会}\)

一道将 \(\text{height}\) 数组与数据结构结合起来的一道题目。算是一道后缀数组题目应该有的基本形式。

我们先抛开第二小问不谈,其实就是让你求 \(\text{ans}_i=\sum_{i=1}^{n}\sum_{j=i+1}^{n}(\text{lcp(i,j)}\ge i)\)

有一个显然的思路,就是先通过后缀数组维护得到 \(\text{height,sa}\) 然后枚举 \(i,j\),通过 \(\texttt{ST}\) 表查询 \(\min \{ \text{height}_{{i+1,i+2..j}} \}\) 得到 \(\text{lcp}\),然后一个差分解决。

你可以得到 \(\texttt{40 pts}\) 的好成绩。

我们需要考虑加速这个过程。

我们发现一个有趣的事情,当我们在求 \(\text{ans}_i\) 的时候,我们可以通过将所有 \(\text{height}_j\ge i\) 并且彼此相邻的 \(j\) 给合并在一起。然后我们暴力枚举所有合并在一起的点集 \(S\),答案就是 \(\text{ans}_i=\sum \dfrac{\text{siz}_S\times (\text{siz}_S-1)}{2}\)。这个应该是非常容易得到的。

那就非常简单了。

我们可以考虑倒着枚举 \(i\),然后每次将 \(\text{height}_j=i\) 的所有 \(j\),与其相邻的点合并在一起,每合并一次,算一次答案,这样只需要合并 \(n\) 次,我们就可以得到最终的 \(\text{ans}_i\)

用并查集维护 \(\text{siz}\) 和合并即可。

至于最值,由于存在负数,考虑分别对一个集合求出 \(\text{min,max}\),然后对 \(\text{min}\times \text{min},\text{max}\times \text{max}\) 求最大值即可。

除开后缀数组,时间复杂度 \(\mathcal{O}(n \times \alpha(n))\)

#include <bits/stdc++.h>
using namespace std;
#define maxn 1000005
const int mod = 998244353;
char a[maxn];
int n;
int val[maxn];
int sa[maxn];
int Has[maxn], p[maxn];
int get(int l, int r)
{
	return ((Has[r] - 1ll * Has[l - 1] * p[r - l + 1] % mod) % mod + mod) % mod;
}
int lcp(int x, int y)
{
	int l = 1, r = min(n - x, n - y) + 1, mid, ans = 0;
	while(l <= r)
	{
		int mid = (l + r) >> 1;
		if(get(x, x + mid - 1) == get(y, y + mid - 1)) l = mid + 1, ans = mid;
		else r = mid - 1;
	}
	return ans;
}
bool cmp(int x, int y)
{
	if(a[x] != a[y]) return a[x] < a[y];
	if(a[x + 1] != a[y + 1]) return a[x + 1] < a[y + 1];
	if(a[x + 2] != a[y + 2]) return a[x + 2] < a[y + 2];
	if(a[x + 3] != a[y + 3]) return a[x + 3] < a[y + 3];
	if(a[x + 4] != a[y + 4]) return a[x + 4] < a[y + 4];
	int l = 1, r = min(n - x + 1, n - y + 1), mid, ans;
	while(l <= r)
	{
		mid = (l + r) >> 1;
		if(get(x, x + mid - 1) == get(y, y + mid - 1)) ans = mid, l = mid + 1;
		else r = mid - 1;
	}
	return a[ans + x] < a[ans + y];
}
int height[maxn];
int fa[maxn], siz[maxn];
long long maxx[maxn], minn[maxn], ans[maxn], sum[maxn];
long long nowsum, nowans;
int findroot(int x)
{
	if(fa[x] == x) return x;
	return fa[x] = findroot(fa[x]);
}
void unionn(int x, int y)
{
	int p = findroot(x), q = findroot(y);
	if(p == q) return;
	nowsum += 1ll * siz[p] * siz[q], nowans = max(nowans, max(maxx[p] * maxx[q], minn[q] * minn[p]));
	fa[q] = p, siz[p] += siz[q];
	maxx[p] = max(maxx[p], maxx[q]), minn[p] = min(minn[p], minn[q]);
}
vector<int> id[maxn];
int main()
{
	scanf("%d", &n);
	scanf("%s", a + 1);
	for (int i = 1; i <= n; ++i) scanf("%d", &val[i]);
	p[0] = 1, Has[0] = 0;
	for (int i = 1; i <= n; ++i) sa[i] = i, Has[i] = (Has[i - 1] * 173ll + a[i]) % mod;
	for (int i = 1; i <= n; ++i) p[i] = (p[i - 1] * 173ll) % mod;
	stable_sort(sa + 1, sa + n + 1, cmp);
	for (int i = 2; i <= n; ++i) height[i] = lcp(sa[i], sa[i - 1]), id[height[i]].push_back(i);
	id[0].push_back(1);
	for (int i = 1; i <= n; ++i) fa[i] = i, siz[i] = 1, minn[i] = maxx[i] = val[sa[i]];
	nowans = 0xf3f3f3f3f3f3f3f3;
	for (int i = n - 1; i >= 0; --i)
	{
		for (int j = 0; j < id[i].size(); ++j) unionn(id[i][j], id[i][j] - 1);
		ans[i] = nowsum, sum[i] = nowans;
	}
	for (int i = 0; i < n; ++i)
	{
		if(sum[i] == 0xf3f3f3f3f3f3f3f3) sum[i] = 0;
		printf("%lld %lld\n", ans[i], sum[i]);
	}
	return 0;
}

\(\texttt{4.[AHOI2013] 差异}\)

可以考虑直接套用上一题的 \(\text{ans}_i\)

通过差分,我们可以得到 \(\text{num}_i\) 表示后缀 \(\text{lcp}(j,k)=i\) 有多少个。

首先,对于题目中的 \(\sum \text{len}(T_i)+\text{len}(T_j)\) 这一部分的求解是非常简单的,这里不再赘述。

然后我们只需要考虑算出 \(\sum \text{lcp}(T_i,T_j)\times 2\) 即可。

咱就是说,这不就是 \(\sum 2\times i \times ans_i\) 吗。

分别计算,然后减去即可,时间复杂度同上。

当然,这样做其实有一点大财小用,完全可以直接用单调队列来解决。

#include <bits/stdc++.h>
using namespace std;
#define maxn 1000005
const int mod = 998244853;
char a[maxn];
int n;
int val[maxn];
int sa[maxn];
int Has[maxn], p[maxn];
int get(int l, int r)
{
	return ((Has[r] - 1ll * Has[l - 1] * p[r - l + 1] % mod) % mod + mod) % mod;
}
int lcp(int x, int y)
{
	int l = 1, r = min(n - x, n - y) + 1, mid, ans = 0;
	while(l <= r)
	{
		int mid = (l + r) >> 1;
		if(get(x, x + mid - 1) == get(y, y + mid - 1)) l = mid + 1, ans = mid;
		else r = mid - 1;
	}
	return ans;
}
bool cmp(int x, int y)
{
	if(a[x] != a[y]) return a[x] < a[y];
	if(a[x + 1] != a[y + 1]) return a[x + 1] < a[y + 1];
	if(a[x + 2] != a[y + 2]) return a[x + 2] < a[y + 2];
	if(a[x + 3] != a[y + 3]) return a[x + 3] < a[y + 3];
	if(a[x + 4] != a[y + 4]) return a[x + 4] < a[y + 4];
	if(a[x + 5] != a[y + 5]) return a[x + 5] < a[y + 5];
	if(a[x + 6] != a[y + 6]) return a[x + 6] < a[y + 6];
	if(a[x + 7] != a[y + 7]) return a[x + 7] < a[y + 7];
	if(a[x + 8] != a[y + 8]) return a[x + 8] < a[y + 8];
	if(a[x + 9] != a[y + 9]) return a[x + 9] < a[y + 9];
	if(a[x + 10] != a[y + 10]) return a[x + 10] < a[y + 10];
	if(a[x + 11] != a[y + 11]) return a[x + 11] < a[y + 11];
	if(a[x + 12] != a[y + 12]) return a[x + 12] < a[y + 12];
	if(a[x + 13] != a[y + 13]) return a[x + 13] < a[y + 13];
	if(a[x + 14] != a[y + 14]) return a[x + 14] < a[y + 14];
	if(a[x + 15] != a[y + 15]) return a[x + 15] < a[y + 15];
	int l = 1, r = min(n - x + 1, n - y + 1), mid, ans;
	while(l <= r)
	{
		mid = (l + r) >> 1;
		if(get(x, x + mid - 1) == get(y, y + mid - 1)) ans = mid, l = mid + 1;
		else r = mid - 1;
	}
	return a[ans + x] < a[ans + y];
}
int height[maxn];
int fa[maxn], siz[maxn];
long long ans[maxn];
long long nowsum, sum;
int findroot(int x)
{
	if(fa[x] == x) return x;
	return fa[x] = findroot(fa[x]);
}
void unionn(int x, int y)
{
	int p = findroot(x), q = findroot(y);
	nowsum += 1ll * siz[p] * siz[q];
	fa[q] = p, siz[p] += siz[q];
}
vector<int> id[maxn];
int main()
{
	scanf("%s", a + 1);
	n = strlen(a + 1);
	for (int i = 1; i <= n; ++i) sum += 1ll * (n - i + 1) * (n - 1);
	p[0] = 1, Has[0] = 0;
	for (int i = 1; i <= n; ++i) sa[i] = i, Has[i] = (Has[i - 1] * 1331ll + a[i]) % mod;
	for (int i = 1; i <= n; ++i) p[i] = (p[i - 1] * 1331ll) % mod;
	stable_sort(sa + 1, sa + n + 1, cmp);
	for (int i = 2; i <= n; ++i) height[i] = lcp(sa[i], sa[i - 1]), id[height[i]].push_back(i);
	id[0].push_back(1);
	for (int i = 1; i <= n; ++i) fa[i] = i, siz[i] = 1;
	for (int i = n - 1; i >= 0; --i)
	{
		for (int j = 0; j < id[i].size(); ++j) unionn(id[i][j], id[i][j] - 1);
		ans[i] = nowsum;
	}
	for (int i = 0; i <= n - 1; ++i) ans[i] -= ans[i + 1], sum -= 2ll * i * ans[i];
	cout << sum << endl;
	return 0;
}

\(\texttt{5.[HAOI2016] 找相同字符}\)

其实,本质上也是求 \(\sum i\times \text{ans}_i\) 的一个形式。

显然是将 \(b\) 串接在 \(a\) 串后面,然后跑后缀数组。当然,记得在中间补一个特殊字符断开。

然后我们考虑如何计算答案。

其实也是一样的。

我们考虑根据上面两个题目的思路,在并查集合并的时候,将 \(\text{ans}\) 加上 \(\text{siz}_x\times \text{siz}_y\)

需要注意的一点是,我们算要出的这个长度是用 \(a\) 串去与 \(b\) 串匹配,而如果你直接相乘,那么就会导致 \(a\)\(a\) 匹配,\(b\)\(b\) 匹配的不合法方案存在。

故,我们考虑对每一个集合分别维护 \(\text{siz1},\text{siz2}\),分别表示这个集合中,\(a\) 中元素的数量,与 \(b\) 中元素的数量。

那么在合并的时候,我们只需要让答案加上 \(\text{siz1}_x\times \text{siz2}_y+\text{siz2}_x\times \text{siz1}_y\) 即可。

时间复杂度同上。

#include <bits/stdc++.h>
using namespace std;
#define maxn 1000005
const int mod = 998244853;
char a[maxn], b[maxn];
int n, m;
int val[maxn];
int sa[maxn];
unsigned long long Has[maxn], p[maxn];
unsigned long long get(int l, int r)
{
	return Has[r] - Has[l - 1] * p[r - l + 1];
}
int lcp(int x, int y)
{
	int l = 1, r = min(n - x, n - y) + 1, mid, ans = 0;
	while(l <= r)
	{
		int mid = (l + r) >> 1;
		if(get(x, x + mid - 1) == get(y, y + mid - 1)) l = mid + 1, ans = mid;
		else r = mid - 1;
	}
	return ans;
}
bool cmp(int x, int y)
{
	if(a[x] != a[y]) return a[x] < a[y];
	if(a[x + 1] != a[y + 1]) return a[x + 1] < a[y + 1];
	if(a[x + 2] != a[y + 2]) return a[x + 2] < a[y + 2];
	if(a[x + 3] != a[y + 3]) return a[x + 3] < a[y + 3];
	if(a[x + 4] != a[y + 4]) return a[x + 4] < a[y + 4];
	if(a[x + 5] != a[y + 5]) return a[x + 5] < a[y + 5];
	if(a[x + 6] != a[y + 6]) return a[x + 6] < a[y + 6];
	if(a[x + 7] != a[y + 7]) return a[x + 7] < a[y + 7];
	if(a[x + 8] != a[y + 8]) return a[x + 8] < a[y + 8];
	if(a[x + 9] != a[y + 9]) return a[x + 9] < a[y + 9];
	if(a[x + 10] != a[y + 10]) return a[x + 10] < a[y + 10];
	if(a[x + 11] != a[y + 11]) return a[x + 11] < a[y + 11];
	if(a[x + 12] != a[y + 12]) return a[x + 12] < a[y + 12];
	if(a[x + 13] != a[y + 13]) return a[x + 13] < a[y + 13];
	if(a[x + 14] != a[y + 14]) return a[x + 14] < a[y + 14];
	if(a[x + 15] != a[y + 15]) return a[x + 15] < a[y + 15];
	int l = 15, r = min(n - x + 1, n - y + 1), mid, ans;
	while(l <= r)
	{
		mid = (l + r) >> 1;
		if(get(x, x + mid - 1) == get(y, y + mid - 1)) ans = mid, l = mid + 1;
		else r = mid - 1;
	}
	return a[ans + x] < a[ans + y];
}
int height[maxn];
int fa[maxn], siz1[maxn], siz2[maxn];
long long ans[maxn];
long long nowsum, sum;
int findroot(int x)
{
	if(fa[x] == x) return x;
	return fa[x] = findroot(fa[x]);
}
void unionn(int x, int y)
{
	int p = findroot(x), q = findroot(y);
	if(siz1[p] + siz2[p] < siz1[q] + siz2[q]) swap(p, q);
	nowsum += 1ll * siz1[p] * siz2[q];
	nowsum += 1ll * siz2[p] * siz1[q];
	fa[q] = p, siz1[p] += siz1[q], siz2[p] += siz2[q];
}
vector<int> id[maxn];
int main()
{
	scanf("%s", a + 1);
	a[(int)strlen(a + 1) + 1] = '.';
	scanf("%s", b + 1);
	n = strlen(a + 1), m = strlen(b + 1);
	for (int i = n + 1; i <= n + m; ++i) a[i] = b[i - n];
	n += m;
	m = n - m;
	p[0] = 1, Has[0] = 0;
	for (int i = 1; i <= n; ++i) sa[i] = i, Has[i] = Has[i - 1] * 137 + a[i], p[i] = p[i - 1] * 137;
		cerr << "qwq" << endl;
	stable_sort(sa + 1, sa + n + 1, cmp);
	for (int i = 2; i <= n; ++i) height[i] = lcp(sa[i], sa[i - 1]), id[height[i]].push_back(i);
	id[0].push_back(1);
	for (int i = 1; i <= n; ++i) fa[i] = i, siz1[i] = (sa[i] <= m), siz2[i] = (sa[i] > m);
	for (int i = n - 1; i >= 0; --i)
	{
		for (int j = 0; j < id[i].size(); ++j) unionn(id[i][j], id[i][j] - 1);
		ans[i] = nowsum;
	}
	for (int i = 0; i <= n - 1; ++i) ans[i] -= ans[i + 1], sum += i * 1ll * ans[i];
	cout << sum << endl;
	return 0;
}

\(\texttt{6. [JSOI2007] 字符加密}\)

\(\texttt{Joker}\) 题,就是个板子。

首先破环为链,在原字符串后面再复制一遍,然后跑后缀数组。

显然,此时已经排好序了,直接从 \(1\)\(n\times 2\) 枚举 \(sa\) 数组,如果 \(sa_i\le n\),输出 \(a_{sa_i+n-1}\) 即可。

没什么好说的,不过就是一个正常的破环为链的技巧。

#include <bits/stdc++.h>
using namespace std;
#define maxn 1000005
const int mod = 998244853;
char a[maxn], b[maxn];
int n, m;
int val[maxn];
int sa[maxn];
unsigned long long Has[maxn], p[maxn];
unsigned long long get(int l, int r)
{
	return Has[r] - Has[l - 1] * p[r - l + 1];
}
int lcp(int x, int y)
{
	int l = 1, r = min(n - x, n - y) + 1, mid, ans = 0;
	while(l <= r)
	{
		int mid = (l + r) >> 1;
		if(get(x, x + mid - 1) == get(y, y + mid - 1)) l = mid + 1, ans = mid;
		else r = mid - 1;
	}
	return ans;
}
bool cmp(int x, int y)
{
	if(a[x] != a[y]) return a[x] < a[y];
	if(a[x + 1] != a[y + 1]) return a[x + 1] < a[y + 1];
	if(a[x + 2] != a[y + 2]) return a[x + 2] < a[y + 2];
	if(a[x + 3] != a[y + 3]) return a[x + 3] < a[y + 3];
	if(a[x + 4] != a[y + 4]) return a[x + 4] < a[y + 4];
	if(a[x + 5] != a[y + 5]) return a[x + 5] < a[y + 5];
	if(a[x + 6] != a[y + 6]) return a[x + 6] < a[y + 6];
	if(a[x + 7] != a[y + 7]) return a[x + 7] < a[y + 7];
	if(a[x + 8] != a[y + 8]) return a[x + 8] < a[y + 8];
	if(a[x + 9] != a[y + 9]) return a[x + 9] < a[y + 9];
	if(a[x + 10] != a[y + 10]) return a[x + 10] < a[y + 10];
	if(a[x + 11] != a[y + 11]) return a[x + 11] < a[y + 11];
	if(a[x + 12] != a[y + 12]) return a[x + 12] < a[y + 12];
	if(a[x + 13] != a[y + 13]) return a[x + 13] < a[y + 13];
	if(a[x + 14] != a[y + 14]) return a[x + 14] < a[y + 14];
	if(a[x + 15] != a[y + 15]) return a[x + 15] < a[y + 15];
	int l = 15, r = min(n - x + 1, n - y + 1), mid, ans;
	while(l <= r)
	{
		mid = (l + r) >> 1;
		if(get(x, x + mid - 1) == get(y, y + mid - 1)) ans = mid, l = mid + 1;
		else r = mid - 1;
	}
	return a[ans + x] < a[ans + y];
}
int main()
{
	scanf("%s", a + 1);
	n = strlen(a + 1);
	for (int i = n + 1; i <= n + n; ++i) a[i] = a[i - n];
	n += n;
	p[0] = 1, Has[0] = 0;
	for (int i = 1; i <= n; ++i) sa[i] = i, Has[i] = Has[i - 1] * 137 + a[i], p[i] = p[i - 1] * 137;
	stable_sort(sa + 1, sa + n + 1, cmp);
	for (int i = 1; i <= n; ++i)
	{
		if(sa[i] > n / 2) continue;
		printf("%c", a[sa[i] + n / 2 - 1]);
	}
	puts("");
	return 0;
}

\(\texttt{7. [JSOI2015] 串分割}\)

首先,字符串里面没有 \(0\),其次给定了你要分成 \(k\) 个,以及求得是每段对应的十进制数的最大值的最小值。显然可以判定,这个作为最大值的子串长度一定为 \(\lceil\dfrac{n}{k}\rceil\)

然后先破环为链,把字符串直接复制一遍。

然后我们考虑一个二分,因为所有长度为定值的字符串必然可以在后缀数组得到的排名中一一对应,故我们直接二分其在后缀数组的排名。

接着为了方便,我们令 \(\text{len} = \lceil\dfrac{n}{k}\rceil,rnk_{sa_i}=i\)

首先先思考一个比较劣的对于 \(x\)\(\text{check}\)。(注意 \(x\) 只是一个排名)

不难想到枚举一个起始点,然后向后匹配 \(k\) 次,判断能否匹配出一个 \(\ge n\) 的长度。

对于每次匹配,假设当前枚举到 \(i\),如果 \(rnk_i\le x\),那么直接向后匹配 \(\text{len}\) 位,\(i\to i+\text{len}\),否则,至少也可以匹配 \(\text{len}-1\) 位,\(i\to i+\text{len}-1\)

我们看能否有一个起始点匹配长度到 \(n\) 即可。

但是你每次尽可能多的匹配 \(\text{len}\) 个真的是最优的吗?

但其实本质上,你从 \(i\) 可以匹配 \(\text{len}\) 位的情况下匹配 \(\text{len}-1\) 位,你最后也一定需要匹配回来,而你为了以后的那个匹配回来放弃当前的最优,不如在这里用 \(\text{len}\),然后让那里更劣,本质上效果是一样的。

故这个尽可能多的匹配的策略是正确的。

至于那个起始点,经过思考,其实你只需要枚举 \(i\in[1,\text{len}]\) 的就够了。

#include <bits/stdc++.h>
using namespace std;
#define maxn 1000005
const int mod = 998244853;
char a[maxn], b[maxn];
int n, m, len;
int val[maxn];
int sa[maxn], rnk[maxn];
unsigned long long Has[maxn], p[maxn];
unsigned long long get(int l, int r)
{
	return Has[r] - Has[l - 1] * p[r - l + 1];
}
int lcp(int x, int y)
{
	int l = 1, r = min(n - x, n - y) + 1, mid, ans = 0;
	while(l <= r)
	{
		int mid = (l + r) >> 1;
		if(get(x, x + mid - 1) == get(y, y + mid - 1)) l = mid + 1, ans = mid;
		else r = mid - 1;
	}
	return ans;
}
bool cmp(int x, int y)
{
	if(a[x] != a[y]) return a[x] < a[y];
	if(a[x + 1] != a[y + 1]) return a[x + 1] < a[y + 1];
	if(a[x + 2] != a[y + 2]) return a[x + 2] < a[y + 2];
	if(a[x + 3] != a[y + 3]) return a[x + 3] < a[y + 3];
	if(a[x + 4] != a[y + 4]) return a[x + 4] < a[y + 4];
	if(a[x + 5] != a[y + 5]) return a[x + 5] < a[y + 5];
	if(a[x + 6] != a[y + 6]) return a[x + 6] < a[y + 6];
	if(a[x + 7] != a[y + 7]) return a[x + 7] < a[y + 7];
	if(a[x + 8] != a[y + 8]) return a[x + 8] < a[y + 8];
	if(a[x + 9] != a[y + 9]) return a[x + 9] < a[y + 9];
	if(a[x + 10] != a[y + 10]) return a[x + 10] < a[y + 10];
	if(a[x + 11] != a[y + 11]) return a[x + 11] < a[y + 11];
	if(a[x + 12] != a[y + 12]) return a[x + 12] < a[y + 12];
	if(a[x + 13] != a[y + 13]) return a[x + 13] < a[y + 13];
	if(a[x + 14] != a[y + 14]) return a[x + 14] < a[y + 14];
	if(a[x + 15] != a[y + 15]) return a[x + 15] < a[y + 15];
	int l = 15, r = min(n - x + 1, n - y + 1), mid, ans;
	while(l <= r)
	{
		mid = (l + r) >> 1;
		if(get(x, x + mid - 1) == get(y, y + mid - 1)) ans = mid, l = mid + 1;
		else r = mid - 1;
	}
	return a[ans + x] < a[ans + y];
}
bool check(int mid)
{
	for (int i = 1; i <= len; ++i)
	{
		int id = i;
		for (int j = 1; j <= m; ++j)
		{
			if(rnk[id] > mid) id += len - 1;
			else id += len;
		}
		if(id - i >= n / 2) return true;
	}
	return false;
}
int main()
{
	scanf("%d %d", &n, &m);
	len = ceil(n / (m * 1.0));
	scanf("%s", a + 1);
	for (int i = n + 1; i <= n + n; ++i) a[i] = a[i - n];
	n += n;
	p[0] = 1, Has[0] = 0;
	for (int i = 1; i <= n; ++i) sa[i] = i, Has[i] = Has[i - 1] * 137 + a[i], p[i] = p[i - 1] * 137;
	stable_sort(sa + 1, sa + n + 1, cmp);
	for (int i = 1; i <= n; ++i) rnk[sa[i]] = i;
	int l = 0, r = n, mid, ans = 0;
	while(l <= r)
	{
		mid = (l + r) >> 1;
		if(check(mid)) ans = mid, r = mid - 1;
		else l = mid + 1;
	}
	int id = 0;
	for (int i = 1; i <= n / 2; ++i) if(rnk[i] == ans)
	{
		id = i;
		break;
	}
	for (int i = id; i <= id + len - 1; ++i) putchar(a[i]);
	return 0;
}

后记

总的来说,其实后缀数组的题目本质上分为一下三类:

  • 通过确定需要的字符串的长度(无论是去直接枚举,还是二分),然后通过相邻的两个断点,结算得到其对于答案的贡献,注意不要算重。

  • 通过对 \(\text{height}\) 数组,使用数据结构进行维护,达到你要求的答案的目的。

  • 通过将一段后缀转化为排名,更好的去储存或者遍历字符串,方便操作(比如第七题的二分)。

posted @ 2024-04-09 20:10  Saltyfish6  阅读(56)  评论(0编辑  收藏  举报
Document