SA-IS 乱记
学习自这篇博客,讲得很好。
没有什么例子,只是看看自己有没有理解 SA-IS,顺便方便后面复习。
提醒:基本没有出题人会没事干卡带 \(\log\) 的 SA 的,所以这个东西基本没用,实在需要卡常用用就得了。
一些记号:\(|S| = n\),\(S_{[l, r]}\) 代表 \([l, r]\) 的子串。
首先在字符串后面加一个表示极小值的字符,记做 \(\#\) 方便 SA-IS 的求解。
\(\operatorname{type}\) 的定义:
接下来对于每个后缀 \(S_{[i, n]}\) 都定义一个类型 \(\operatorname{type}_i\),其中 \(\operatorname{type}_i = \texttt{L}\) 则表示 \(S_{[i, n]} > S_{[i + 1, n]}\),为 \(\texttt{S}\) 则表示 \(S_{[i, n]} < S_{[i + 1, n]}\),特殊的定义 \(\operatorname{type}_n = \texttt{S}\)。
对于 \(\operatorname{type}\) 的解可以考虑递推,这是因为如果 \(S_i\not= S_{i + 1}\) 则可以直接比较出来,否则就相当于继续往后比较 \(S_{i + 1}\) 和 \(S_{i + 2}\),就变成了 \((i + 1, i + 2)\) 的问题了。
于是可以倒着递推,如果 \(S_i = S_{i + 1}\) 则 \(\operatorname{type}_i = \operatorname{type}_{i + 1}\),否则若 \(S_i > S_{i + 1}\) 则 \(\operatorname{type}_i = \texttt{L}\),若 \(S_i < S_{i + 1}\) 则 \(\operatorname{type}_i = \texttt{S}\)。
若 \(S_{i} = S_j, \operatorname{type}_i = \texttt{L}, \operatorname{type}_i = \texttt{S}\),则 \(S_{[i, n]} < S_{[j, n]}\)。
因为 \(\operatorname{type}_i = \texttt{T}\),所以 \(S_i\ge S_{i + 1}\);因为 \(\operatorname{type}_j = \texttt{S}\),所以 \(S_j\le S_{j + 1}\)。
只要存在 \(S_i\not = S_{i + 1}\) 或 \(S_j \not = S_{j + 1}\) 就已经可以通过 \(S_{i + 1}, S_{j + 1}\) 比较出 \(S_{[i, n]} < S_{[j, n]}\);否则又可以继续考虑 \((i + 1, j + 1)\),因为此时又有 \(S_{i + 1} = S_{j + 1}, \operatorname{type}_{i + 1} = \texttt{T}, \operatorname{type}_j = \texttt{S}\),那么在比较到最后一定会出现不同的(会碰到 \(\#\))。
LMS 子串:
定义若 \(\operatorname{type}_{i} = \texttt{S}, \operatorname{type}_{i - 1} = \texttt{L}\),则 \(S_i\) 为 LMS 字符,相邻的两个 LMS 字符之间(包含两端)的子串就被称作 LMS 子串,特殊的,\(\#\) 算一个 LMS 子串。
那么 LMS 子串在 \(\operatorname{type}\) 的刻画下一定形如 \(\texttt{S}\cdots\texttt{ST}\cdots\texttt{TS}\)(除掉特殊情况),其中最前和最后的 \(\texttt{S}\) 就为 LMS 字符。
LMS 子串个数不超过 \(\lceil\frac{n}{2}\rceil\),因为最短的 LMS 串也形如 \(\texttt{STS}\),所以可以从后往前考虑 LMS 子串,至少需要两个字符才能产生 LMS 子串,并且最后结尾是没有 LMS 子串的。
LMS 子串长度总和 \(\le 1.5n\),考虑只有端点会被算到 \(2\) 次,其余至多 \(1\) 次,所以不超过 \(1.5n\)。
LMS 子串的大小比较:在基于比较字符的情况下也比较类型。即对于 \((i, j)\) 先比较 \(s_i, s_j\),如果相等继续看 \(\operatorname{type}_i, \operatorname{type}_j\) 是否相等(不等可以像上面比出大小关系)。
不存在两个 LMS 子串使得为真前缀关系(包含类型关系),这是因为如果为真前缀那更长的一个串一定会在中间出现 LMS 字符,矛盾。
如果知道了 LMS 子串的大小关系,那么对于 LMS 字符 \(i, j\),\(s_{[i, n]}, s_{[j, n]}\) 的大小关系就是其后缀的 LMS 子串的大小关系比较。
这是因为 LMS 子串不存在真前缀关系,于是只有相等或能直接比较出大小。
于是就可以考虑先得到 LMS 子串的大小关系,并且继续递归得到 LMS 子串的后缀排序,就知道 LMS 字符的后缀排序顺序,然后再得到整个串的后缀排序。
那么这个的复杂度是 \(T(n) = T(\frac{n}{2}) + F(n)\) 的,那么只要能让 \(F(n)\) 是 \(\mathcal{O}(n)\) 的,就能让 \(T(n)\) 是 \(\mathcal{O}(n)\) 的。
重头戏:诱导排序,也是名字中 IS 的由来(induced sort)。
首先考虑如果已经知道了 LMS 子串的后缀排序怎么求解整个串的后缀排序。
首先考虑前文提到的,比较后缀可以先考虑比较开头的字符,然后比较开头的类型。
那么如果说都一样,那么就需要继续往后比较。
于是一个想法是,可以根据已知的大小关系,在开头加字符,这样就能保证后面的大小关系了,进一步就可以保证这个后缀的大小关系了。
又考虑到对于 \(\texttt{L}\) 的一段,从后面往前加开头字符,大小一定递增。
对于 \(\texttt{S}\) 的一段,从后面往前加开头字符,大小一定递减。
所以可以考虑对不同的字符开头与类型 \(\texttt{S}, \texttt{L}\) 都开桶,那么从小往大遍历就是先从小到大遍历开头字符,然后依次遍历 \(\texttt{L}\) 类型,再依次遍历 \(\texttt{S}\) 类型。
对于从大往小遍历就可以倒过来。
于是可以考虑先把 LMS 字符的后缀顺序加入桶。
接下来考虑遍历一边桶,把 \(\texttt{L}\) 类型都放进去:如果当前遍历到 \(i\) 且 \(\operatorname{type}_{i - 1} = \texttt{L}\) 则把 \(i - 1\) 放进由 \(S_{i - 1}\) 开头的 \(\texttt{L}\) 类型的桶里。
那么这个时候就保证了对于相同开头相同类型(\(\texttt{L}\))的后缀的大小关系,因为这些的比较是基于去掉开头的大小顺序得到的,而这些后缀放进桶的顺序就是由去掉开头的后缀得到的,所以排出来的顺序是对的。
接下来因为除了 \(\#\) 的其他 \(\texttt{S}\) 类型字符后面一定有 \(\texttt{L}\) 类型字符,于是接下来考虑再对 \(\texttt{S}\) 型字符排序(需要注意的是,此时也需要加入 LMS 字符一起重排)。
因为 \(\texttt{S}\) 的段从后往前大小递减,所以接下来考虑倒着从大到小这个桶,过程与排序 \(\texttt{L}\) 类似,唯一需要注意的是放入桶时也应该是倒着从后面开始放(保证先放的更大),正确性证明也一样。
接下来考虑如何对 LMS 子串进行一个基本的排序。
对此,依然可以使用诱导排序,因为此时这个 LMS 子串也可以看作是由最后一个 LMS 字符不断加入开头字符得来的。
于是类似的,变为整串的后缀排序值只需要知道 LMS 字符的后缀的顺序,那么按照 LMS 字符的顺序放入桶内,进行一次诱导排序,最后每个 LMS 字符的顺序就代表以其开头的 LMS 子串的大小顺序。
接下来是一些附带的杂项了:
- 因为递归下去做 LMS 子串应该要保证正确的 LMS 子串的大小关系,所以需要把相同的 LMS 子串压成一个值。
对此,考虑到因为已经求解了 LMS 的顺序了,相等只需要判断相邻子串就可以了,又因为 LMS 子串的长度和不超过 \(1.5n\),所以直接暴力遍历判断即可。 - 如果说不存在相同的 LMS 子串,那么后缀的大小关系是可以直接判断的,就是 LMS 子串的大小关系。
如果还要递归下去当然也是可以的,这只是个剪枝。
代码实现很大程度参考了这份,不过为什么他的 sbuc 和 lbuc 数组不用清空我很好奇,我的实现看起来是必须清空的。
给出的实现对应的是 QOJ 956 后缀排序模板:
#include<bits/stdc++.h>
constexpr int maxn = 1e6 + 100;
namespace sais {
int cnt[maxn], lbuc[maxn], sbuc[maxn];
inline void induced_sort(int S, int *s, int *type, int *lcnt, int *scnt) {
memcpy(cnt + 1, lcnt, S * 4);
for (int i = 1; i <= S; i++) {
for (int j = lcnt[i - 1] + 1; j <= lcnt[i]; j++) {
if (lbuc[j] > 1 && type[lbuc[j] - 1]) {
lbuc[++cnt[s[lbuc[j] - 1]]] = lbuc[j] - 1;
}
}
for (int j = scnt[i - 1] + 1; j <= scnt[i]; j++) {
if (sbuc[j] > 1 && type[sbuc[j] - 1]) {
lbuc[++cnt[s[sbuc[j] - 1]]] = sbuc[j] - 1;
}
}
}
memcpy(cnt + 1, scnt + 1, S * 4);
for (int i = S; i >= 1; i--) {
for (int j = scnt[i]; j > scnt[i - 1]; j--) {
if (sbuc[j] > 1 && ! type[sbuc[j] - 1]) {
sbuc[cnt[s[sbuc[j] - 1]]--] = sbuc[j] - 1;
}
}
for (int j = lcnt[i]; j > lcnt[i - 1]; j--) {
if (lbuc[j] > 1 && ! type[lbuc[j] - 1]) {
sbuc[cnt[s[lbuc[j] - 1]]--] = lbuc[j] - 1;
}
}
}
}
void _sais(int n, int S, int *s, int *sa, int *type, int *lcnt, int *scnt, int *ptr) {
type[n] = 0;
for (int i = n - 1; i >= 1; i--) {
type[i] = s[i] == s[i + 1] ? type[i + 1] : s[i] > s[i + 1];
}
for (int i = 1; i <= S; i++) lcnt[i] = scnt[i] = 0;
for (int i = 1; i <= n; i++) {
(type[i] ? lcnt[s[i]] : scnt[s[i]])++;
}
for (int i = 2; i <= S; i++) {
lcnt[i] += lcnt[i - 1];
scnt[i] += scnt[i - 1];
}
memset(lbuc + 1, 0, lcnt[S] * 4);
memset(sbuc + 1, 0, scnt[S] * 4);
memcpy(cnt + 1, scnt, S * 4);
static int nxt[maxn], sa1[maxn], ptr_[maxn];
int m = 0;
for (int i = n, nw = n; i >= 2; i--) {
if (type[i - 1] && ! type[i]) {
sbuc[++cnt[s[i]]] = i;
nxt[i] = nw, ptr[++m] = nw = i;
}
}
std::reverse(ptr + 1, ptr + m + 1);
induced_sort(S, s, type, lcnt, scnt);
int c = 0;
for (int i = 1, las = 0, m_ = 0; i <= scnt[S]; i++) {
int p = sbuc[i];
if (p == 1 || ! type[p - 1]) continue;
int d = nxt[p] - p;
c += c == 0 || d != nxt[las] - las ||
memcmp(s + p, s + las, d * 4) ||
memcmp(type + p, type + las, d * 4);
sa1[p] = c, ptr_[++m_] = las = p;
}
if (c == m) {
memset(lbuc + 1, 0, lcnt[S] * 4);
memset(sbuc + 1, 0, scnt[S] * 4);
memcpy(cnt + 1, scnt + 1, S * 4);
for (int i = m; i >= 1; i--) {
int p = ptr_[i];
sbuc[cnt[s[p]]--] = p;
}
} else {
for (int i = 1; i <= m; i++) {
s[n + i + 1] = sa1[ptr[i]];
}
_sais(m, c, s + n + 1, sa + n + 1, type + n + 1, lcnt + S + 1, scnt + S + 1, ptr + m + 1);
memset(lbuc + 1, 0, lcnt[S] * 4);
memset(sbuc + 1, 0, scnt[S] * 4);
memcpy(cnt + 1, scnt + 1, S * 4);
for (int i = m; i >= 1; i--) {
int p = ptr[sa[n + i + 1]];
sbuc[cnt[s[p]]--] = p;
}
}
induced_sort(S, s, type, lcnt, scnt);
for (int i = 1, k = 0; i <= S; i++) {
for (int j = lcnt[i - 1] + 1; j <= lcnt[i]; j++) {
sa[++k] = lbuc[j];
}
for (int j = scnt[i - 1] + 1; j <= scnt[i]; j++) {
sa[++k] = sbuc[j];
}
}
}
int s[maxn * 2], type[maxn * 2], sa[maxn * 2], lcnt[maxn * 2], scnt[maxn * 2], ptr[maxn];
inline void sais(int n, char *s_, int *sa_) {
for (int i = 1; i <= n; i++) s[i] = s_[i];
_sais(n, 128, s, sa, type, lcnt, scnt, ptr);
for (int i = 1; i <= n; i++) sa_[i] = sa[i + 1];
}
}
int n;
char s[maxn];
int sa[maxn];
int main() {
scanf("%s", s + 1);
n = strlen(s + 1);
s[n + 1] = '$';
sais::sais(n + 1, s, sa);
for (int i = 1; i <= n; i++) printf("%d ", sa[i]);
return 0;
}
浙公网安备 33010602011771号