字符串专题-广义后缀自动机
构建广义后缀自动姬
(1)伪做法:重置last法 - 也是一种在线做法,但是最坏情况下空间会比第三种构造法大一倍
每次插入一个字符串,插完一个字符串之后,重置last=1,然后继续插入下一个字符串,直到所有字符串插入完毕。
这个算法的时间复杂度俺不太会证明,但是平均复杂度是线性的。由于没有多余操作,最优性能比标准法要快,但是容易被卡掉?这里就不给代码了。
(2)正解/标准:广搜Trie树建SAM
涉及到使用Trie来建SAM,相当于把Trie上的所有路径插入到SAM中。正确的原因是:按照BFS插入的节点的len是单调不减的,这样就可以保证插入正确性(具体证明俺不太会)。
查看代码
struct GSAM { // root = 1, 记得开两倍空间
static const int SIGMA = 26;
int nxt[maxn][SIGMA], len[maxn], fa[maxn], samcnt;
int Psz[maxn]; // Right集合size
inline int newNode() {
++samcnt, me(nxt[samcnt], 0);
len[samcnt] = fa[samcnt] = 0;
return samcnt;
}
inline void initSAM() { samcnt = 0, newNode(); }
inline void insTrie(char *str, int n) {
for (int x = 1, i = 1; i <= n; i++) {
int c = str[i] - 'a'; // 注意字符集,有时候是:'0'~'9'
if (!nxt[x][c]) nxt[x][c] = newNode();
x = nxt[x][c];
}
}
inline void insSAM(int last, int c) {
int np = nxt[last][c], p = fa[last]; // 因为last的nxt已经赋值,直接往上跳
Psz[np] = 1, len[np] = len[last] + 1;
while (p && !nxt[p][c]) nxt[p][c] = np, p = fa[p];
if (!p) {
fa[np] = 1;
} else {
int q = nxt[p][c];
if (len[q] == len[p] + 1) {
fa[np] = q;
} else {
int nq = newNode();
Psz[nq] = 0, len[nq] = len[p] + 1;
fa[nq] = fa[q];
for (int i = 0; i < SIGMA; i++) // len不为0的节点才是在SAM上的点
nxt[nq][i] = (len[nxt[q][i]] ? nxt[q][i] : 0);
for (; p && nxt[p][c] == q; p = fa[p]) nxt[p][c] = nq;
fa[q] = fa[np] = nq;
}
}
}
inline void Build() {
queue<pair<int, int>> q;
for (int i = 0; i < SIGMA; i++)
if (nxt[1][i]) q.push(make_pair(1, i));
while (q.size()) { // BFS按层插入节点
auto [last, c] = q.front();
q.pop(), insSAM(last, c);
for (int x = nxt[last][c], i = 0; i < SIGMA; i++)
if (nxt[x][i]) q.push(make_pair(x, i));
}
}
} gsam;
(3)在线构造法:重置last法+特判
我们发现第一种重置last的方法,会多出一些空余节点,这些节点保存了无用信息。所以只需要通过特判减少/避免这类空余节点,就可以使得GSAM的节点数量是正确的,同时构造时间的复杂度同样是线性的。
查看代码
#include <bits/stdc++.h>
using namespace std;
#define me(a, b) memset(a, (b), sizeof(a))
const int maxn = 2e6 + 7;
struct GSAM { // root = 1, 注意开2倍空间
static const int SIGMA = 26;
int samcnt, nxt[maxn][SIGMA], fa[maxn];
int len[maxn], Psz[maxn];
int newNode() {
++samcnt, me(nxt[samcnt], 0), fa[samcnt] = 0;
return samcnt;
}
void initSAM() { samcnt = 0, newNode(); }
// 返回插入c创建的节点(不是分裂点),用于其他节点的last
int insert(int last, int c) {
// 需要保证nxt[p][c]!=0,然后检查nxt[p][c]是否符合要求
const auto chkPnt = [&](int p, int c) -> int {
int q = nxt[p][c];
if (len[q] == len[p] + 1) return q; // 直接返回节点q
int nq = newNode(); // q的分裂点
Psz[nq] = 0, len[nq] = len[p] + 1;
fa[nq] = fa[q], fa[q] = nq;
memcpy(nxt[nq], nxt[q], sizeof(nxt[q]));
for (; p && nxt[p][c] == q; p = fa[p]) nxt[p][c] = nq;
return nq;
};
if (nxt[last][c]) return chkPnt(last, c); // 特判已经存在的节点
int np = newNode(), p = last;
Psz[np] = 1, len[np] = len[p] + 1;
for (; p && !nxt[p][c]; p = fa[p]) nxt[p][c] = np; //记得p=fa[p]
fa[np] = p ? chkPnt(p, c) : 1;
return np;
}
} gsam;
int n;
char str[maxn];
int main() {
scanf("%d", &n);
gsam.initSAM();
for (int i = 1; i <= n; i++) {
scanf("%s", str);
int last = 1; // reset, 插入第二个字符串
for (int j = 0; str[j]; j++) {
last = gsam.insert(last, str[j] - 'a');
}
}
long long ans = 0;
for (int j = 2; j <= gsam.samcnt; j++)
ans += gsam.len[j] - gsam.len[gsam.fa[j]];
printf("%lld", ans);
}
GSAM的性质/例题
GSAM具备了SAM的一些良好性质,同时因为GSAM是多个串的自动机,可以处理多串结合的问题。
习题
(1)hihocoder #1457 : 后缀自动机四·重复旋律7 【计算SAM上每个节点的权值? 广义后缀自动机】 - 已经无法提交? 中等难度
题意:给你n个字符串(由0~9的数字组成),求这些字符串本质不同的子串的权值之和。每个子串的权值就是该串在十进制表示下的权值。
思路:我们知道SAM中每次添加一个字符,相当于把str[1:i]的i个后缀插入到旧的SAM中,而新增的本质不同的串的数量,相当于: len[np] - len[fa[np]]。那么对于这些新增的不同的子串,我们可以直接计算它们的贡献。
具体实现方法:trie上的每个节点都有一个bel属性,表示该节点属于哪个字符串(如果有重复也不要紧)。最后BFS建GSAM时,每插入一个节点,ans+=calc(1, len[np], bel[np]) - calc(len[np] - len[fa[np]], len[np], bel[np])就行了。其中的calc(l,r,i)函数是计算第i个字符串的str[i][l:r]的权值。【calc函数的实现可以利用前缀思想】
该代码未AC,因为OJ无法提交
int pw[maxn];
vector<vector<ll>> G, F;
void initPw() {
pw[0] = 1, G.emp(vector<ll>()), F.emp(vector<ll>());
for (int i = 1; i < maxn; i++) pw[i] = (10ll * pw[i - 1]) % mod;
}
void addStr(char *str, int n) {
G.emp(vector<ll>(n + 11, 0));
F.emp(vector<ll>(n + 11, 0));
int i = (int)G.size() - 1;
for (int j = 1; j <= n; j++) {
G[i][j] = (G[i][j - 1] * 10ll % mod + (str[j] - '0')) % mod;
F[i][j] = (F[i][j - 1] * 10ll % mod + 1ll * j * (str[j] - '0')) % mod;
}
}
ll calc(int l, int r, int b) {
// cerr << "l = " << l << " , r = " << r << " ! b = " << b << endl;
ll sub = F[b][l - 1] * pw[r - l + 1] % mod;
ll num = (G[b][r] - G[b][l - 1] * pw[r - l + 1] % mod + mod) % mod;
return ((F[b][r] - num * (l - 1) - sub) % mod + mod) % mod;
}
static const int SIGMA = 10;
int nxt[maxn][SIGMA], len[maxn], fa[maxn], samcnt;
int Psz[maxn], bel[maxn];
ll ans;
inline int newNode() {
++samcnt, me(nxt[samcnt], 0);
len[samcnt] = fa[samcnt] = 0;
return samcnt;
}
inline void initSAM() { samcnt = 0, newNode(); }
inline void insTrie(char *str, int n, int sid) {
for (int x = 1, i = 1; i <= n; i++) {
int c = str[i] - '0'; // 注意修改这里,字符集已经改变了
if (!nxt[x][c]) nxt[x][c] = newNode();
x = nxt[x][c], bel[x] = sid;
}
}
inline void insSAM(int last, int c) {
int np = nxt[last][c], p = fa[last]; // 因为last的nxt已经赋值,直接往上跳
Psz[np] = 1, len[np] = len[last] + 1;
while (p && !nxt[p][c]) nxt[p][c] = np, p = fa[p];
if (!p) {
fa[np] = 1;
} else {
int q = nxt[p][c];
if (len[q] == len[p] + 1) {
fa[np] = q;
} else {
int nq = newNode();
Psz[nq] = 0, len[nq] = len[p] + 1;
fa[nq] = fa[q];
for (int i = 0; i < SIGMA; i++) // len不为0的节点才是在SAM上的点
nxt[nq][i] = (len[nxt[q][i]] ? nxt[q][i] : 0);
for (; p && nxt[p][c] == q; p = fa[p]) nxt[p][c] = nq;
fa[q] = fa[np] = nq;
}
}
// 每次累积贡献,最后就是答案
ans = (ans + calc(1, len[np], bel[np]) - calc(len[np] - len[fa[np]] + 1, len[np], bel[np])) % mod;
if (ans < 0) ans += mod;
}
inline void Build() {
queue<pair<int, int>> q;
for (int i = 0; i < SIGMA; i++)
if (nxt[1][i]) q.push(make_pair(1, i));
while (q.size()) { // BFS按层插入节点
int last = q.front().fi;
int c = q.front().se;
q.pop(), insSAM(last, c);
for (int x = nxt[last][c], i = 0; i < SIGMA; i++)
if (nxt[x][i]) q.push(make_pair(x, i));
}
}
int n;
char str[maxn];
int main() {
scanf("%d", &n);
initSAM(), initPw();
for (int i = 1; i <= n; i++) {
scanf("%s", str + 1);
int len = strlen(str + 1);
insTrie(str, len, i);
addStr(str, len);
}
Build();
printf("%lld\n", ans);
}
(2)P3181 [HAOI2016]找相同字符【维护每个节点在每个字符串中Right集合的大小】 - 中等
题意:给定两个字符串S、T,两个子串不同,当且仅当位置不同。现在让你求出有多少对(a,b),其中a是S中的一个子串,b是T中的一个子串。
思路:由于是位置不同决定子串,所以要考虑上Right集合的大小。使用GSAM来解决这道题是最方便的,维护出每个字符串在每个节点的Right大小,然后答案就是:\( \sum_{x∈gsam} sz[x][0]*sz[x][1]*(len[x] - len[fa[x]])\)。
其中right集合的维护方法大概是这样的:①对于Trie树建SAM来说,插入Tried的时候对路径上每个节点标记上sz[x][str_id]=1就行了,这个节点在sam上就是对应一个前缀的节点(即对应一个right_pos)。②对于在线建树的方法:因为我们每一次insert都会返回刚刚插入的节点,所以直接sz[insert(last,c)]=1就行了。 - 又因为每个string一个前缀对应一个节点,所以直接赋值为1就行了。
这道题使用SA、单串sam来求解可能过于复杂,但是也不是不能做(如果题目改成给4、5、6..个串呢?只能用gsam了)。
(3)P3346 [ZJOI2015]诸神眷顾的幻想乡【暴力】
题意:给定一棵无根树,每个节点有一个字母,求路上所有路径形成的字符串中,本质不同的有多少个。【数据:叶子不超过20个】
tm读错题,我以为每个节点的度数不超过20。因为树上所有路径都可以看作是以一个叶子为根时的一个子串,所以暴力插入20个以叶子为根的树就行了。
(4)CF666E Forensic Examination【树上线段树合并 + 在SAM上找到str[l:r]所属的节点】

思路是十分板的,先把S串插入到SAM中,记录下S串每个前缀对应的节点pos;之后再插入m个T串,插入的时候,还要更新last节点的线段树。(为什么先把S插入到SAM中,后面再插T串不会影响之前记录的pos数组呢?纵观insert函数,只有节点分裂会影响一个节点,但因为S串每个前缀在对应的状态点中都是长度最大的,即使分裂,也是保存在q节点之后,线段树没有发生变化,所以我们可以先插入S,再插入Ti。)。
然后把所有的询问离线下来,使用SAM中学到的倍增思想,快速获取到str[pl,pr]对应的节点。
最后再dfs一遍,就是基本的数据结构思路了。
启发: GSAM的Right集合 和 SAM中的Right理解方式可以不一样,因为SAM只有一个串,Right集合中的元素意义就是相当于以R[i]结尾;而我更倾向于GSAM的Right集合是一个二元组集合,它的元素是这样的一个二元组(str_id, end_pos)。而且,对于GSAM中的每一个串的每个前缀,假设它对应的状态点是w(str_id, end_pos),那么一定有:len[w(str_id, end_pos)] = str[str_id][1:end_pos].size() = end_pos 。所以即使在节点分裂的时候,也不会影响到已经赋值的pos数组(因为每个前缀都是它所在状态点中最长的字符串,即使分裂,也保留在q节点不动)。
查看代码
struct QUERY {
int i, l, r;
};
const int maxn = 1.2e6;
char str[maxn];
int n, m, q, pos[maxn];
PII ans[maxn];
namespace seg {
int rt[maxn], ls[maxn * 4], rs[maxn * 4], segnum;
PII tr[maxn * 4];
PII operator+(PII a, PII b) {
if (a.fi == b.fi) return {a.fi, a.se + b.se};
if (a.se == b.se && a.fi < b.fi) return a;
return (a.se > b.se) ? a : b;
}
int merge(int x, int y, int l, int r) {
if (!x || !y) return x | y;
if (l == r) return tr[x].se += tr[y].se, x;
ls[x] = merge(ls[x], ls[y], l, mseg);
rs[x] = merge(rs[x], rs[y], mseg + 1, r);
tr[x] = tr[ls[x]] + tr[rs[x]];
return x;
}
void insert(int& ro, int l, int r, int x) {
if (!ro) ro = ++segnum;
if (l == r) return tr[ro].fi = l, tr[ro].se++, void();
x <= mseg ? insert(ls[ro], l, mseg, x) : insert(rs[ro], mseg + 1, r, x);
tr[ro] = tr[ls[ro]] + tr[rs[ro]];
}
PII query(int x, int l, int r, int s, int e) {
if (s <= l && r <= e) return x ? tr[x] : mp(s, 0);
if (e <= mseg) return query(ls[x], l, mseg, s, e);
if (s > mseg) return query(rs[x], mseg + 1, r, s, e);
return query(ls[x], l, mseg, s, e) + query(rs[x], mseg + 1, r, s, e);
}
}; // namespace seg
static const int SIGMA = 26;
int samcnt, nxt[maxn][SIGMA], fa[20][maxn];
int len[maxn];
vector<vector<QUERY>> qr;
vector<vector<int>> son;
int newNode() { return ++samcnt; }
void initSAM() { samcnt = 0, newNode(); }
// 返回插入c创建的节点(不是分裂点),用于其他节点的last
int insert(int last, int c) {
// 需要保证nxt[p][c]!=0,然后检查nxt[p][c]是否符合要求
const auto chkPnt = [&](int p, int c) -> int {
int q = nxt[p][c];
if (len[q] == len[p] + 1) return q; // 直接返回节点q
int nq = newNode(); // q的分裂点
len[nq] = len[p] + 1;
fa[0][nq] = fa[0][q], fa[0][q] = nq;
memcpy(nxt[nq], nxt[q], sizeof(nxt[q]));
for (; p && nxt[p][c] == q; p = fa[0][p]) nxt[p][c] = nq;
return nq;
};
if (nxt[last][c]) return chkPnt(last, c); // 特判已经存在的节点
int np = newNode(), p = last;
len[np] = len[p] + 1;
for (; p && !nxt[p][c]; p = fa[0][p]) nxt[p][c] = np; //记得p=fa[p]
fa[0][np] = p ? chkPnt(p, c) : 1;
return np;
}
void BuildTree() {
for (int d = 1; d < 20; d++)
for (int i = 1; i <= samcnt; i++) fa[d][i] = fa[d - 1][fa[d - 1][i]];
qr.assign(samcnt + 1, {});
son.assign(samcnt + 1, {});
for (int i = 2; i <= samcnt; i++) son[fa[0][i]].emp(i);
}
void dfs(int x) {
for (int& v : son[x]) {
dfs(v);
seg::rt[x] = seg::merge(seg::rt[x], seg::rt[v], 1, m);
}
for (auto [i, l, r] : qr[x]) {
ans[i] = seg::query(seg::rt[x], 1, m, l, r);
}
}
void insertStr(char* str, int ID) {
int last = 1;
for (int i = 1; str[i]; i++) {
last = insert(last, str[i] - 'a');
seg::insert(seg::rt[last], 1, m, ID);
}
}
void solve() {
cin >> str + 1;
initSAM();
n = 1;
for (int i = 1, last = 1; str[i]; i++, n++)
pos[i] = last = insert(last, str[i] - 'a');
cin >> m;
for (int i = 1; i <= m; i++) cin >> str + 1, insertStr(str, i);
BuildTree();
cin >> q;
for (int i = 1, pl, pr, l, r; i <= q; i++) {
cin >> l >> r >> pl >> pr;
int L = pr - pl + 1;
int x = pos[pr];
for (int d = 19; ~d; d--)
if (len[fa[d][x]] >= L) x = fa[d][x];
qr[x].emp(QUERY{i, l, r});
}
dfs(1);
for (int i = 1; i <= q; i++) cout << ans[i].fi << " " << ans[i].se << '\n';
}
[参考资料]
(2)广义后缀自动机学习笔记 - 饕餮传奇 - 博客园 (cnblogs.com)
(4)OIWIKI-GSAM

浙公网安备 33010602011771号