后缀数组(Suffix Array)学习笔记
感觉 OI-wiki 在相关方面讲得很牛啊!拜谢 OI-wiki 的后缀数组部分。
一些约定
字符串 \(S\) 的长度为 \(|S|\),是下标从 \(1\) 开始(到 \(|S|\))的字符串。特别地,若无特殊说明,约定字符串 \(s\) 的长度为 \(n\)。
后缀 \(i\) 代表以 \(i\) 开头的后缀。
若无特殊说明,则两个字符串 \(s_1\) 和 \(s_2\) 的大小关系就是它们字典序的大小关系。
后缀数组是什么?
后缀数组(Suffix Array),可以直接表示字符串 \(s\) 的后缀的字典序大小关系。
定义 \(\text{sa}_i\) 表示按照字典序排序后,第 \(i\) 大的后缀的开头的下标;定义 \(\text{rk}_i\) 表示后缀 \(i\) 按照字典序排序后的排名。
显然,在求出后缀数组(\(\text{sa}\))后,我们有 \(\text{sa}_{\text{rk}_i} = \text{rk}_{\text{sa}_i} = i\)。
如何求出后缀数组?
考虑最暴力的求法:将 \(s\) 的所有后缀存储起来,并直接 sort 起来。容易得到,该做法的时间复杂度是 \(O(n^2 \log n)\) 的。
考虑优化它:只要我们是基于比较的排序(即使用重定义 cmp 的 sort),复杂度中的 \(O(n \log n)\) 就无法去掉。考虑如何优化比较两个字符串 \(s_1\) 和 \(s_2\) 的大小关系。一个想法是,判断 \(s_1\) 和 \(s_2\) 的大小关系只需要找到它们第一个不相等的字符,并直接比较即可。我们可以用二分 + 哈希的方式做这件事。该做法的时间复杂度为 \(O(n \log^2 n)\)。可能再使用一些手法可以做到 \(O(n \log n)\) 地求出后缀数组,但我不会,所以这里先不说。以后记得补一下
这种 \(O(n \log^2 n)\) 做法的代码:
点击查看代码
#include <bits/stdc++.h>
#define ll long long
#define mid (l + r + 1 >> 1)
using namespace std;
constexpr int N = 1e6 + 5;
constexpr int base1 = 31, mod1 = 998244353;
constexpr int base2 = 171, mod2 = 998244853;
int n;
int sa[N];
ll p1[N], p2[N], hs1[N], hs2[N];
char s[N];
bool check(int l1, int r1, int l2, int r2) {
int hsx1 = (hs1[r1] - hs1[l1 - 1] * p1[r1 - l1 + 1] % mod1 + mod1) % mod1;
int hsx2 = (hs2[r1] - hs2[l1 - 1] * p2[r1 - l1 + 1] % mod2 + mod2) % mod2;
int hsy1 = (hs1[r2] - hs1[l2 - 1] * p1[r2 - l2 + 1] % mod1 + mod1) % mod1;
int hsy2 = (hs2[r2] - hs2[l2 - 1] * p2[r2 - l2 + 1] % mod2 + mod2) % mod2;
if (hsx1 == hsy1 && hsx2 == hsy2) {
return true;
} else {
return false;
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
cin >> s + 1;
n = strlen(s + 1);
p1[0] = p2[0] = 1ll;
for (int i = 1; i <= n; ++i) {
p1[i] = p1[i - 1] * base1 % mod1;
p2[i] = p2[i - 1] * base2 % mod2;
hs1[i] = (hs1[i - 1] * base1 + s[i]) % mod1;
hs2[i] = (hs2[i - 1] * base2 + s[i]) % mod2;
}
for (int i = 1; i <= n; ++i) {
sa[i] = i;
}
sort(sa + 1, sa + n + 1, [](int x, int y) {
int len1 = n - x + 1, len2 = n - y + 1;
int l = 0, r = min(len1, len2);
while (l < r) {
if (check(x, x + mid - 1, y, y + mid - 1)) l = mid;
else r = mid - 1;
}
if (l == min(len1, len2)) {
return x > y;
} else {
return s[x + l] < s[y + l];
}
});
for (int i = 1; i <= n; ++i) {
cout << sa[i] << " \n"[i == n];
}
return 0;
}
提交记录,可以看到哈希的常数和较劣的复杂度是无法通过本题的。
upd:好像基于比较的排序没有前途了 /ll
upd2:但是 wyd 讲了一个只利用哈希做到 \(O(n \log n)\) 求后缀数组的做法!该做法好像来源于 zak,拜谢。
我们首先可以将 \(s\) 的长度补全到 \(2\) 的次幂(设为 \(k\))来简化接下来的考虑(即,在 \(s\) 后面加入空字符使 \(n = 2^k\))。考虑首先处理出所有 \(n\) 个后缀的长度为 \(2^{k - 1}\) 的前缀的哈希值。这时,\(n\) 个后缀之间的关系就可以分为「哈希值相同」的部分和「哈希值不同」的部分。
对于哈希值相同的后缀,我们显然只需要比较它们长度不大于 \(2^{k - 1}\) 的后半部分的字典序即可。对于不同的每个哈希值,我们都只保留一个串进行接下来的比较。
递归地做这件事情。
复杂度证明,不会。
我怎么这么菜啊呜呜,如何变得可以自己想出这个的复杂度证明阿?
上面的 \(O(n \log^2 n)\) 做法好想也好写,不过复杂度还是稍微劣了一些,接下来将说一说最普遍的 \(O(n \log n)\) 的后缀数组求法。
我们可以认为 \(s\) 后面接着足够多的空字符(假定空字符的字典序最小),那么 \(n\) 个后缀的长度也可以被补全为一样的。
考虑对 \(n\) 个字符排序的过程。我们首先对每个后缀的第一个字符进行排序,设此时的 \(\text{rk}_i\) 表示 \(s_i\) 这个字符的排名,直接按照 \(\text{rk}_i\) 作为权值排序即可完成这一步。
接下来,我们当然可以继续做比较第 \(2\) 个字符的过程,但我们实际上可以优化它。考虑使用倍增的思想,求出新的 \(\text{rk}_i\) 表示 \(s_{[i, i + 1]}\) 的排名,那么我们再次按照 \(\text{rk}_{i}\) 排序即可,而求新的 \(\text{rk}_i\) 的过程又是 \(O(n)\) 的。一般地,设 \(\text{rk}_i\) 表示 \(s_{[i, i + 2^k - 1]}\) 的排名,并按照它排序。该做法的时间复杂度为 \(O(n \log^2 n)\)。
代码(注意因为我们对 \(s\) 进行了长度上的“补全”,所以 \(\text{rk}\) 要开到 \(|s|\) 的两倍):
点击查看代码
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 1e6 + 5;
int n, w;
int sa[N], rk[N << 1], rk2[N << 1];
char s[N];
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
cin >> s + 1;
n = strlen(s + 1);
for (int i = 1; i <= n; ++i) {
sa[i] = i;
rk[i] = s[i];
}
for (w = 1; w < n; w <<= 1) {
sort(sa + 1, sa + n + 1, [](int x, int y) {
if (rk[x] == rk[y]) {
return rk[x + w] < rk[y + w];
} else {
return rk[x] < rk[y];
}
});
for (int i = 1; i <= n; ++i) {
rk2[i] = rk[i];
}
for (int p = 0, i = 1; i <= n; ++i) {
if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
rk[sa[i]] = p;
} else {
rk[sa[i]] = ++p;
}
}
}
for (int i = 1; i <= n; ++i) {
cout << sa[i] << " \n"[i == n];
}
return 0;
}
而优化这个做法是不困难的:复杂度瓶颈在于求新的 \(\text{sa}\) 的排序而非求 \(\text{rk}\) 的过程。我们的排序只参考了 \(\text{rk}\) 数组,那么直接使用基数排序替换掉 sort 即可。
注意基数排序和计数排序的区别:计数排序是直接把所有数放到值域的桶里面,基数排序是对于多个关键字,按较低到较高优先级的关键字依次排序。
代码:
点击查看代码
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 1e6 + 5;
int n;
int sa[N], sa2[N], buc[N], rk[N << 1], rk2[N << 1];
char s[N];
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
cin >> s + 1;
n = strlen(s + 1);
int sz = max(n, 127);
for (int i = 1; i <= n; ++i) ++buc[rk[i] = s[i]];
for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
for (int i = 1; i <= n; ++i) sa[buc[rk[i]]--] = i;
for (int w = 1; w < n; w <<= 1) {
for (int i = 0; i <= sz; ++i) buc[i] = 0;
for (int i = 1; i <= n; ++i) sa2[i] = sa[i];
for (int i = 1; i <= n; ++i) ++buc[rk[sa2[i] + w]];
for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
for (int i = n; i > 0; --i) sa[buc[rk[sa2[i] + w]]--] = sa2[i];
for (int i = 0; i <= sz; ++i) buc[i] = 0;
for (int i = 1; i <= n; ++i) sa2[i] = sa[i];
for (int i = 1; i <= n; ++i) ++buc[rk[sa2[i]]];
for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
for (int i = n; i > 0; --i) sa[buc[rk[sa2[i]]]--] = sa2[i];
for (int i = 1; i <= n; ++i) rk2[i] = rk[i];
for (int cur = 0, i = 1; i <= n; ++i) {
if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
rk[sa[i]] = cur;
} else {
rk[sa[i]] = ++cur;
}
}
}
for (int i = 1; i <= n; ++i) {
cout << sa[i] << " \n"[i == n];
}
return 0;
}
当然我们还能再进行一些常数上的优化,得到更为实用的版本。
- 不对第二关键字进行计数排序。
考虑对第二关键字进行排序的实质:对于 \(i = 1\) 到 \(n\),按照 \(\text{rk}_{i + 2^k - 1}\) 进行排序,最终的结果必定是把最后的 \(n\) 个(后缀 \(i - n + 1\) 到后缀 \(n\))放到最前面,并把其余的部分按照原顺序(倍增到 \(k - 1\) 层的顺序)往后平移。
-
在每次倍增结束后动态更新值域。
-
若 \(\text{rk}\) 数组内任意两项的值都不相同,则排序过程必定已经结束,可以停止接下来的操作。
代码:
点击查看代码
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 1e6 + 5;
int n;
int buc[N], sa[N], sa2[N], rk[N << 1], rk2[N << 1];
char s[N];
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
cin >> s + 1;
n = strlen(s + 1);
int sz = 127;
for (int i = 1; i <= n; ++i) ++buc[rk[i] = s[i]];
for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
for (int i = 1; i <= n; ++i) sa[buc[rk[i]]--] = i;
for (int w = 1; w < n; w <<= 1) {
int cur = 0;
for (int i = n - w + 1; i <= n; ++i) sa2[++cur] = i;
for (int i = 1; i <= n; ++i) {
if (sa[i] > w) {
sa2[++cur] = sa[i] - w;
}
}
for (int i = 0; i <= sz; ++i) buc[i] = 0;
for (int i = 1; i <= n; ++i) ++buc[rk[i]];
for (int i = 1; i <= sz; ++i) buc[i] += buc[i - 1];
for (int i = n; i > 0; --i) sa[buc[rk[sa2[i]]]--] = sa2[i];
cur = 0;
for (int i = 1; i <= n; ++i) rk2[i] = rk[i];
for (int i = 1; i <= n; ++i) {
if (rk2[sa[i]] == rk2[sa[i - 1]] && rk2[sa[i] + w] == rk2[sa[i - 1] + w]) {
rk[sa[i]] = cur;
} else {
rk[sa[i]] = ++cur;
}
}
sz = cur;
if (cur == n) {
break;
}
}
for (int i = 1; i <= n; ++i) {
cout << sa[i] << " \n"[i == n];
}
return 0;
}

浙公网安备 33010602011771号