后缀数组(Suffix Array)学习笔记

感觉 OI-wiki 在相关方面讲得很牛啊!拜谢 OI-wiki 的后缀数组部分


一些约定

字符串 \(S\) 的长度为 \(|S|\),是下标从 \(1\) 开始(到 \(|S|\))的字符串。特别地,若无特殊说明,约定字符串 \(s\) 的长度为 \(n\)

后缀 \(i\) 代表以 \(i\) 开头的后缀。

若无特殊说明,则两个字符串 \(s_1\)\(s_2\) 的大小关系就是它们字典序的大小关系。

后缀数组是什么?

后缀数组(Suffix Array),可以直接表示字符串 \(s\) 的后缀的字典序大小关系。

定义 \(\text{sa}_i\) 表示按照字典序排序后,第 \(i\) 大的后缀的开头的下标;定义 \(\text{rk}_i\) 表示后缀 \(i\) 按照字典序排序后的排名。

显然,在求出后缀数组(\(\text{sa}\))后,我们有 \(\text{sa}_{\text{rk}_i} = \text{rk}_{\text{sa}_i} = i\)

如何求出后缀数组?

考虑最暴力的求法:将 \(s\) 的所有后缀存储起来,并直接 sort 起来。容易得到,该做法的时间复杂度是 \(O(n^2 \log n)\) 的。

考虑优化它:只要我们是基于比较的排序(即使用重定义 cmpsort),复杂度中的 \(O(n \log n)\) 就无法去掉。考虑如何优化比较两个字符串 \(s_1\)\(s_2\) 的大小关系。一个想法是,判断 \(s_1\)\(s_2\) 的大小关系只需要找到它们第一个不相等的字符,并直接比较即可。我们可以用二分 + 哈希的方式做这件事。该做法的时间复杂度为 \(O(n \log^2 n)\)。可能再使用一些手法可以做到 \(O(n \log n)\) 地求出后缀数组,但我不会,所以这里先不说。以后记得补一下

这种 \(O(n \log^2 n)\) 做法的代码:

点击查看代码
#include <bits/stdc++.h>
#define ll long long
#define mid (l + r + 1 >> 1)

using namespace std;

constexpr int N = 1e6 + 5;
constexpr int base1 = 31, mod1 = 998244353;
constexpr int base2 = 171, mod2 = 998244853;

int n;
int sa[N];
ll p1[N], p2[N], hs1[N], hs2[N];
char s[N];

bool check(int l1, int r1, int l2, int r2) {
	int hsx1 = (hs1[r1] - hs1[l1 - 1] * p1[r1 - l1 + 1] % mod1 + mod1) % mod1;
	int hsx2 = (hs2[r1] - hs2[l1 - 1] * p2[r1 - l1 + 1] % mod2 + mod2) % mod2;
	int hsy1 = (hs1[r2] - hs1[l2 - 1] * p1[r2 - l2 + 1] % mod1 + mod1) % mod1;
	int hsy2 = (hs2[r2] - hs2[l2 - 1] * p2[r2 - l2 + 1] % mod2 + mod2) % mod2;
	if (hsx1 == hsy1 && hsx2 == hsy2) {
		return true;
	} else {
		return false;
	}
}

signed main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr); cout.tie(nullptr);
	
	cin >> s + 1;
	n = strlen(s + 1);
	
	p1[0] = p2[0] = 1ll;
	for (int i = 1; i <= n; ++i) {
		p1[i] = p1[i - 1] * base1 % mod1;
		p2[i] = p2[i - 1] * base2 % mod2;
		hs1[i] = (hs1[i - 1] * base1 + s[i]) % mod1;
		hs2[i] = (hs2[i - 1] * base2 + s[i]) % mod2;
	}
	
	for (int i = 1; i <= n; ++i) {
		sa[i] = i;
	}
	
	sort(sa + 1, sa + n + 1, [](int x, int y) {
		int len1 = n - x + 1, len2 = n - y + 1;
		int l = 0, r = min(len1, len2);
		while (l < r) {
			if (check(x, x + mid - 1, y, y + mid - 1)) l = mid;
			else r = mid - 1;
		}
		if (l == min(len1, len2)) {
			return x > y;
		} else {
			return s[x + l] < s[y + l];
		}
	});
	
	for (int i = 1; i <= n; ++i) {
		cout << sa[i] << " \n"[i == n];
	}
	
	return 0;
}

提交记录,可以看到哈希的常数和较劣的复杂度是无法通过本题的。

upd:好像基于比较的排序没有前途了 /ll

upd2:但是 wyd 讲了一个只利用哈希做到 \(O(n \log n)\) 求后缀数组的做法!该做法好像来源于 zak,拜谢。

我们首先可以将 \(s\) 的长度补全到 \(2\) 的次幂(设为 \(k\))来简化接下来的考虑(即,在 \(s\) 后面加入空字符使 \(n = 2^k\))。考虑首先处理出所有 \(n\) 个后缀的长度为 \(2^{k - 1}\) 的前缀的哈希值。这时,\(n\) 个后缀之间的关系就可以分为「哈希值相同」的部分和「哈希值不同」的部分。

对于哈希值相同的后缀,我们显然只需要比较它们长度不大于 \(2^{k - 1}\) 的后半部分的字典序即可。对于不同的每个哈希值,我们都只保留一个串进行接下来的比较。

递归地做这件事情。

复杂度证明,不会。

我怎么这么菜啊呜呜,如何变得可以自己想出这个的复杂度证明阿?


上面的 \(O(n \log^2 n)\) 做法好想也好写,不过复杂度还是稍微劣了一些,接下来将说一说最普遍的 \(O(n \log n)\) 的后缀数组求法。

我们可以认为 \(s\) 后面接着足够多的空字符(假定空字符的字典序最小),那么 \(n\) 个后缀的长度也可以被补全为一样的。

考虑对 \(n\) 个字符排序的过程。我们首先对每个后缀的第一个字符进行排序,设此时的 \(\text{rk}_i\) 表示 \(s_i\) 这个字符的排名,直接按照 \(\text{rk}_i\) 作为权值排序即可完成这一步。

接下来,我们当然可以继续做比较第 \(2\) 个字符的过程,但我们实际上可以优化它。考虑使用倍增的思想,求出新的 \(\text{rk}_i\) 表示 \(s_{[i, i + 1]}\) 的排名,那么我们再次按照 \(\text{rk}_{i}\) 排序即可,而求新的 \(\text{rk}_i\) 的过程又是 \(O(n)\) 的。一般地,设 \(\text{rk}_i\) 表示 \(s_{[i, i + 2^k - 1]}\) 的排名,并按照它排序。该做法的时间复杂度为 \(O(n \log^2 n)\)

代码(注意因为我们对 \(s\) 进行了长度上的“补全”,所以 \(\text{rk}\) 要开到 \(|s|\) 的两倍):

点击查看代码
#include <bits/stdc++.h>

using namespace std;

constexpr int N = 1e6 + 5;

int n, w;
int sa[N], rk[N << 1], rk2[N << 1];
char s[N];

signed main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr); cout.tie(nullptr);
	
	cin >> s + 1;
	n = strlen(s + 1);
	
	for (int i = 1; i <= n; ++i) {
		sa[i] = i;
		rk[i] = s[i];
	}
	
	for (w = 1; w < n; w <<= 1) {
		sort(sa + 1, sa + n + 1, [](int x, int y) {
			if (rk[x] == rk[y]) {
				return rk[x + w] < rk[y + w];
			} else {
				return rk[x] < rk[y];
			}
		});
		
		for (int i = 1; i <= n; ++i) {
			rk2[i] = rk[i];
		}
		
		for (int p = 0, i = 1; i <= n; ++i) {
			if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
				rk[sa[i]] = p;
			} else {
				rk[sa[i]] = ++p;
			}
		}
	}
	
	for (int i = 1; i <= n; ++i) {
		cout << sa[i] << " \n"[i == n];
	}
	
	return 0;
}

而优化这个做法是不困难的:复杂度瓶颈在于求新的 \(\text{sa}\) 的排序而非求 \(\text{rk}\) 的过程。我们的排序只参考了 \(\text{rk}\) 数组,那么直接使用基数排序替换掉 sort 即可。

注意基数排序和计数排序的区别:计数排序是直接把所有数放到值域的桶里面,基数排序是对于多个关键字,按较低到较高优先级的关键字依次排序。

代码:

点击查看代码
#include <bits/stdc++.h>

using namespace std;

constexpr int N = 1e6 + 5;

int n;
int sa[N], sa2[N], buc[N], rk[N << 1], rk2[N << 1];
char s[N];

signed main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr); cout.tie(nullptr);
	
	cin >> s + 1;
	n = strlen(s + 1);
	
	int sz = max(n, 127);
	for (int i = 1; i <= n; ++i) ++buc[rk[i] = s[i]];
	for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
	for (int i = 1; i <= n; ++i) sa[buc[rk[i]]--] = i;
	
	for (int w = 1; w < n; w <<= 1) {
		for (int i = 0; i <= sz; ++i) buc[i] = 0;
		for (int i = 1; i <= n; ++i) sa2[i] = sa[i];
		for (int i = 1; i <= n; ++i) ++buc[rk[sa2[i] + w]];
		for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
		for (int i = n; i > 0; --i) sa[buc[rk[sa2[i] + w]]--] = sa2[i];
		
		for (int i = 0; i <= sz; ++i) buc[i] = 0;
		for (int i = 1; i <= n; ++i) sa2[i] = sa[i];
		for (int i = 1; i <= n; ++i) ++buc[rk[sa2[i]]];
		for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
		for (int i = n; i > 0; --i) sa[buc[rk[sa2[i]]]--] = sa2[i];
		
		for (int i = 1; i <= n; ++i) rk2[i] = rk[i];
		for (int cur = 0, i = 1; i <= n; ++i) {
			if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
				rk[sa[i]] = cur;
			} else {
				rk[sa[i]] = ++cur;
			}
		}
	}
	
	for (int i = 1; i <= n; ++i) {
		cout << sa[i] << " \n"[i == n];
	}
	
	return 0;
}

当然我们还能再进行一些常数上的优化,得到更为实用的版本。

  1. 不对第二关键字进行计数排序。

考虑对第二关键字进行排序的实质:对于 \(i = 1\)\(n\),按照 \(\text{rk}_{i + 2^k - 1}\) 进行排序,最终的结果必定是把最后的 \(n\) 个(后缀 \(i - n + 1\) 到后缀 \(n\))放到最前面,并把其余的部分按照原顺序(倍增到 \(k - 1\) 层的顺序)往后平移。

  1. 在每次倍增结束后动态更新值域。

  2. \(\text{rk}\) 数组内任意两项的值都不相同,则排序过程必定已经结束,可以停止接下来的操作。

代码:

点击查看代码
#include <bits/stdc++.h>

using namespace std;

constexpr int N = 1e6 + 5;

int n;
int buc[N], sa[N], sa2[N], rk[N << 1], rk2[N << 1];
char s[N];

signed main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr); cout.tie(nullptr);
	
	cin >> s + 1;
	n = strlen(s + 1);
	
	int sz = 127;
	for (int i = 1; i <= n; ++i) ++buc[rk[i] = s[i]];
	for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
	for (int i = 1; i <= n; ++i) sa[buc[rk[i]]--] = i;
	
	for (int w = 1; w < n; w <<= 1) {
		int cur = 0;
		for (int i = n - w + 1; i <= n; ++i) sa2[++cur] = i;
		for (int i = 1; i <= n; ++i) {
			if (sa[i] > w) {
				sa2[++cur] = sa[i] - w;
			}
		}
		
		for (int i = 0; i <= sz; ++i) buc[i] = 0;
		for (int i = 1; i <= n; ++i) ++buc[rk[i]];
		for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
		for (int i = n; i > 0; --i) sa[buc[rk[sa2[i]]]--] = sa2[i];
		
		cur = 0;
		for (int i = 1; i <= n; ++i) rk2[i] = rk[i];
		for (int i = 1; i <= n; ++i) {
			if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
				rk[sa[i]] = cur;
			} else {
				rk[sa[i]] = ++cur;
			}
		}
		sz = cur;
		
		if (cur == n) {
			break;
		}
	}
	
	for (int i = 1; i <= n; ++i) {
		cout << sa[i] << " \n"[i == n];
	}
	
	return 0;
}
posted @ 2025-07-04 20:37  zyb_txdy  阅读(14)  评论(0)    收藏  举报