后缀自动机学习
后缀自动机
学习参考:
一. ACM算法日常微信公众号-后缀自动机入门系列:
史上全网最清晰后缀自动机学习(二)后缀自动机的线性时间构造算法
史上全网最清晰后缀自动机学习(五)后缀自动机和最长公共子串问题
三. 陈立杰后缀自动机PPT
模板
struct node{
int trans[P], link, len;
void clear(){
memset(trans, 0, sizeof trans);
link = len = 0;
}
};
struct SAM{
node S[N];
int p, np, size;
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
np = size = p = 1;
}
void insert(char ch){
int x = ch - 'a';
np = ++size;
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
}
}sam;
基础题
1. 求本质不同的子串个数
#1445 : 后缀自动机二·重复旋律5
SAM中每个状态的 \(len\) 表示该 \(endpos\) 等价类中最长字串的长度,所以这个状态的不同子串个数是 \(len-S[S[i].link].len\)
//在SAM结构体中添加该函数调用即可
ll get(){
ll res = 0;
for(int i=1;i<=size;i++){
res += S[i].len - S[S[i].link].len;
}
return res;
}
int main() {
scanf("%s",s);
int n = strlen(s);
for(int i=0;i<n;i++)sam.insert(s[i]);
printf("%lld\n", sam.get());
return 0;
}
#2033. 「SDOI2016」生成魔咒
本题特殊点在于字符集为[1,1e9], 所以并不能用数组下标直接查询,因为总结点数不会\(2*n\),边不超过\(3*n\), 所以只需要给每个点加一个map就可以了。
本题需要求每插入一个字符后,SAM中本质不同的字符串的个数,可以发现每次分裂产生的新结点只是变成了一个新的\(endpos\)类,但本质不同字符串个数并没有增加,所以该个数增加只会发生在新插入的结点上
#include <cstdio>
#include <iostream>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
#include <unordered_map>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int N = 200000 + 5;
struct node{
int link, len;
unordered_map<int,int> trans;
void clear(){
trans.clear();
link = len = 0;
}
};
struct SAM{
node S[N];
int p, np, size;
ll res;
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
np = size = p = 1;
}
void insert(int ch){
int x = ch;
np = ++size;
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
res += S[p].len - S[S[p].link].len; // res只会在这里更新
}
}sam;
int main() {
int n;scanf("%d",&n);
sam.res = 0;
while(n--){
int x;scanf("%d",&x);
sam.insert(x);
printf("%lld\n",sam.res);
}
return 0;
}
2. 求 \(endpos\) 集合大小
#1449 : 后缀自动机三·重复旋律6
沿后缀\(link\) 建立反向树,\(dfs\) 进行统计
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int N = 3000000 + 5;
const int P = 26;
int head[N], ver[N], nxt[N], tot;
ll res[N];
void add(int x, int y){
ver[++tot] = y, nxt[tot] = head[x], head[x] = tot;
}
struct node{
int trans[P], link, len;
ll sum;
void clear(){
memset(trans, 0, sizeof trans);
link = len = 0;
}
};
struct SAM{
node S[N];
int p, np, size;
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
np = size = p = 1;
}
void insert(char ch){
int x = ch - 'a';
np = ++size;
S[np].sum = 1; // np 为一个新的endpos类,大小为1
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].sum = 0; // 分离出的结点,赋值为0
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
}
void dfs(int x){
for(int i=head[x];i;i=nxt[i]){
int y = ver[i];
dfs(y);
S[x].sum += S[y].sum;
}
}
void get(){
for(int i=2;i<=size;i++){ //沿后缀link反向建树
int to = S[i].link;
if(to != 0) add(to, i);
}
dfs(1); // dfs
for(int i=2;i<=size;i++){ // 更新答案
res[S[i].len] = max(res[S[i].len], S[i].sum);
}
}
}sam;
char s[N];
int main() {
scanf("%s",s);
int n = strlen(s);
for(int i=0;i<n;i++)sam.insert(s[i]);
sam.get();
for(int i=n-1;i>=1;i--) res[i] = max(res[i], res[i+1]);
for(int i=1;i<=n;i++){
printf("%lld\n",res[i]);
}
return 0;
}
P3804 【模板】后缀自动机 (SAM)
求出 \(endpos\) 之后扫一遍所有结点即可
ll get(){
for(int i=2;i<=size;i++){
int to = S[i].link;
if(to != 0) add(to, i);
}
dfs(1);
ll res = 0;
for(int i=2;i<=size;i++){
if(S[i].sum == 1) continue;
res = max(res, S[i].len * S[i].sum);
}
return res;
}
3. SAM上DP
#1457 : 后缀自动机四·重复旋律7
SAM是一个天然的DAG,某个结点作为一个最终结果接纳状态,可以表示若干个子串,\(endpos\) 的大小则表示了他们的出现次数。
本题中先考虑一个串的情况,DP时,某个结点沿着某条出边走向下一个结点,这个转移意味着给一些子串后面加一个字符,如果这个结点的数字之和为 \(sum\),\(endpos\) 大小表示这些数字在原串中共出现了多少次,那么沿着 \('1'\) 这条边走到下一个状态,贡献将产生 :\(sum*10 + 1 * endpos\)。
这里的\(endpos\) 可以顺便用拓扑序 \(dp\) 求出,当多个串连起来时,不考虑非法字符的转移即可。
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
const int N = 3000000 + 5;
const int P = 11;
int tot;
ll res[N];
struct node{
int trans[P], link, len;
ll sum, validnum;
void clear(){
memset(trans, 0, sizeof trans);
link = len = 0;
}
};
struct SAM{
node S[N];
int p, np, size, deg[N];
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
np = size = p = 1;
}
void insert(char ch){
int x = ch - '0';
np = ++size;
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
}
ll get(){
queue<int> q;
q.push(1);
S[1].validnum = 1;
for(int i=1;i<=size;i++){
for(int j=0;j<P;j++){
if(S[i].trans[j]) deg[S[i].trans[j]]++; //初始化结点入度
}
}
//按照拓扑序进行DP,DAG上的DP
ll res = 0;
while(q.size()){
int x = q.front();q.pop();
for(int i=0;i<P;i++){
int y = S[x].trans[i];
if(!y)continue;
if(--deg[y] == 0)q.push(y);
// i = 10时表示非法字符,不进行转移
if(i != P-1){
S[y].validnum += S[x].validnum;
S[y].sum += S[x].sum * 10 % mod + i * S[x].validnum % mod;
S[y].sum %= mod;
}
}
res = (res + S[x].sum)%mod;
}
return res;
}
}sam;
int n;
char s[N];
int main() {
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%s",s);
int len = strlen(s);
for(int j=0;j<len;j++)sam.insert(s[j]);
if(i < n) sam.insert(':'); //‘:'的 ascii码为58,刚好等于 '9' + 1
}
printf("%lld\n", sam.get());
return 0;
}
4. 公共子串
Longest Common Substring
给两个串,求最长公共子串的长度。
将第一个串构建SAM,然后对第二个串的每个位置 i, 维护以 i 为结尾的最长字串的长度 l, 那么整个过程 l 的最大值就是答案。
如果维护?设当前状态结点为 \(u\),长度为$ l$,如果匹配不到当前字符,那 \(u = u.link\), 以直到 \(u = 1\),此过程中 $l $一直在缩小,令 \(l = min(l, u.len)\), 如果最终也没能匹配,那么\(u=1,l=0\),否则\(l++,u = u.trans[i]\);
#include <cstdio>
#include <iostream>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int N = 5000000 + 5;
const int P = 26;
char a[N], b[N];
struct node{
int link, len, trans[P];
void clear(){
memset(trans,0, sizeof trans);
link = len = 0;
}
};
struct SAM{
node S[N];
int p, np, size;
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
np = size = p = 1;
}
void insert(char ch){
int x = ch - 'a';
np = ++size;
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
}
void solve(char *b){
int u = 1, l = 0, n = strlen(b);
int res = 0;
for(int i=0;i<n;i++){
int c = b[i] - 'a';
while(u != 1 && !S[u].trans[c]) u = S[u].link, l = min(l, S[u].len);
if(S[u].trans[c]){
u = S[u].trans[c];
l++;
}
res = max(res, l);
}
cout << res << endl;
}
}sam;
int main() {
scanf("%s%s",a,b);
int n = strlen(a);
for(int i=0;i<n;i++)sam.insert(a[i]);
sam.solve(b);
return 0;
}
#1465 : 后缀自动机五·重复旋律8
给出一个字符串,有若干次询问,每次询问给出一个字符串,求该字符串的所有循环表示在原串中出现过多少次。
循环表示一般的套路就是复制一遍接到尾巴上,\(s[1...n] \Rightarrow S[1...n~1....n]\) , 枚举每个位置作为循环串的结尾,假设为 S[1...i],在SAM中找到一个结点 \(u\) , \(u\) 在SAM中表示的状态是 \(S[1...i]\) 的后缀,并且在 \(u\) 的 \(len >= n\) 前提下,使得\(endpos\) 尽可能大,在这样情况下求出来的 \(u\) 就是一个以字符\(S[i]\) 结尾的循环子串,在原串中可以出现的所有位置。
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
const int N = 3000000 + 5;
const int P = 26;
ll res[N];
int head[N], ver[N], nxt[N], tot;
void add(int x, int y){
ver[++tot] = y, nxt[tot] = head[x], head[x] = tot;
}
struct node{
int trans[P], link, len;
int endpos;
void clear(){
memset(trans, 0, sizeof trans);
link = len = 0;
}
};
struct SAM{
node S[N];
int p, np, size;
int vis[N];
vector<int> vislist;
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
np = size = p = 1;
}
void insert(char ch){
int x = ch - 'a';
np = ++size;
S[np].endpos = 1;
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].endpos = 0;
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
}
void dfs(int x){
for(int i=head[x];i;i=nxt[i]){
dfs(ver[i]);
S[x].endpos += S[ver[i]].endpos;
}
}
void getEndpos(){
for(int i=2;i<=size;i++){
add(S[i].link, i);
}
dfs(1);
}
void solve(char *s){
int len = strlen(s), newLen = len * 2 - 1; // 长度为2*len-1的原因是防止重复计算
for(int i=len;i<newLen;i++)s[i] = s[i-len];
int u = 1, l = 0, res = 0;
for(int i=0;i<newLen;i++){
int c = s[i] - 'a';
while(u != 1 && !S[u].trans[c]){ // 找不到可以匹配 c 的 u
u = S[u].link;
l = S[u].len;
}
if(S[u].trans[c]){ // 可以匹配 c 的 u 存在
u = S[u].trans[c];
++ l;
}else { //不存在能匹配 c 的 u,可以可以break了
break;
}
if(l > len){
while(S[S[u].link].len >= len){
u = S[u].link;
l = S[u].len;
}
}
if(l >= len && !vis[u]){
vis[u] = 1;
vislist.push_back(u);
res += S[u].endpos;
}
}
for(int i=0;i<vislist.size();i++){
vis[vislist[i]] = 0;
}
printf("%d\n",res);
}
}sam;
int n;
char s[N];
int main() {
scanf("%s",s);
n = strlen(s);
for(int i=0;i<n;i++)sam.insert(s[i]);
scanf("%d",&n);
sam.getEndpos();
for(int i=1;i<=n;i++){
scanf("%s",s);
sam.solve(s);
}
return 0;
}
Longest Common Substring II
多个串之间的LCS,与第一个问题大同小异,先构建第一个串的SAM,然后对于后来的每个串去按照第一个题进行一样的处理,并标记SAM中每个点能够匹配的最大长度,这也意味着这些点的link也具有这样的长度,所以要在扫描结束后自底向上的更新max值,然后对于所有的串,取min值即可。
#include <cstdio>
#include <iostream>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int N = 200000 + 5;
const int P = 26;
char a[N], b[N];
struct node{
int link, len, trans[P], mx, mi;
void clear(){
memset(trans,0, sizeof trans);
link = len = 0;
}
};
struct SAM{
node S[N];
int p, np, size;
int b[N], c[N];
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
np = size = p = 1;
}
void insert(char ch){
int x = ch - 'a';
np = ++size;
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
}
void sort(){//按照len值排序,保证从后往前遍历能够正确更新mx值
for(int i=1;i<=size;i++)++c[S[i].len];
for(int i=1;i<=size;i++)c[i] += c[i-1];
for(int i=1;i<=size;i++)b[c[S[i].len]--] = i;
for(int i=1;i<=size;i++)S[i].mi = inf;
}
void solve(char *s){
int u = 1, l = 0, n = strlen(s);
int res = 0;
for(int i=0;i<n;i++){
int c = s[i] - 'a';
while(u != 1 && !S[u].trans[c]) u = S[u].link, l = S[u].len;
if(S[u].trans[c]){
u = S[u].trans[c];
l++;
S[u].mx = max(S[u].mx, l);
}else u = 1, l = 0;
}
for(int i=size;i>=1;i--){
int u = b[i], fa = S[u].link;
S[fa].mx = max(S[fa].mx, min(S[u].mx, S[fa].len));
S[u].mi = min(S[u].mi, S[u].mx);
S[u].mx = 0;
}
}
}sam;
int main() {
scanf("%s",a);
int n = strlen(a);
for(int i=0;i<n;i++)sam.insert(a[i]);
sam.sort();
while(~scanf("%s",a))sam.solve(a);
int res = 0;
for(int i=1;i<=sam.size;i++) res = max(res, sam.S[i].mi);
printf("%d\n",res);
return 0;
}
5. 字典序第 K 大子串
Lexicographical Substring Search
在SAM中的DAG上可以直接求出每个结点以及其后续状态的所有子串,然后依次从小到大试填。
#include <cstdio>
#include <iostream>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int N = 200000 + 5;
const int P = 26;
char s[N], a[N];
struct node{
int link, len, trans[P];
void clear(){
memset(trans,0, sizeof trans);
link = len = 0;
}
};
struct SAM{
node S[N];
int p, np, size;
ll d[N];
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
np = size = p = 1;
}
void insert(char ch){
int x = ch - 'a';
np = ++size;
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
}
void dfs(int x){
if(d[x])return; // 记忆化
d[x] = 1; // 它本身肯定算作一个
for(int i=0;i<P;i++){
if(S[x].trans[i]){
dfs(S[x].trans[i]);
d[x] += d[S[x].trans[i]];
}
}
}
void init(){
dfs(1);
d[1] --;//空串本身不算,所以减一
}
void get(int k){
int pos = 0, u = 1;
while(k){
if(u != 1) k --; //非空串占用一个
if(k == 0)break;
for(int i=0;i<P;i++){
int y = S[u].trans[i];
if(!y)continue;
if(d[y] >= k){
u = S[u].trans[i];
a[pos++] = 'a' + i;
break;
}
k -= d[y];
}
}
a[pos] = 0;
puts(a);
}
}sam;
int main() {
scanf("%s",s);
int n = strlen(s);
for(int i=0;i<n;i++)sam.insert(s[i]);
sam.init();
int q;scanf("%d",&q);
while(q--){
int k;scanf("%d",&k);
sam.get(k);
}
return 0;
}
#2102. 「TJOI2015」弦论
这个题稍稍复杂一些,包含了对于不同位置的相同子串的计算,那么SAM上每个结点都有 \(endpos\) 次出现,所以我们在计算出\(endpos\)之后再DP就好了。
// TJOI 2015
#include <cstdio>
#include <iostream>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int N = 1000000 + 5;
const int P = 26;
int head[N], ver[N], nxt[N], tot;
void add(int x, int y){ver[++tot] = y, nxt[tot] = head[x], head[x] = tot; }
char s[N], a[N];
struct node{
int link, len, trans[P], endpos;
void clear(){
memset(trans,0, sizeof trans);
link = len = 0;
}
};
struct SAM{
node S[N];
int p, np, size;
ll d[N][2]; // 0 表示不同位置子串算一个,1 表示不同位置子串算多个
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
np = size = p = 1;
}
void insert(char ch){
int x = ch - 'a';
np = ++size;
S[np].endpos = 1;
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].endpos = 0;
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
}
void getendpos(int x){
for(int i=head[x];i;i=nxt[i]){
getendpos(ver[i]);
S[x].endpos += S[ver[i]].endpos;
}
}
void dfs(int x){
if(d[x][0])return; // 记忆化
d[x][0] = 1; // 它本身肯定算作一个
d[x][1] = S[x].endpos;
for(int i=0;i<P;i++){
if(S[x].trans[i]){
dfs(S[x].trans[i]);
d[x][0] += d[S[x].trans[i]][0];
d[x][1] += d[S[x].trans[i]][1];
}
}
}
void init(){
for(int i=2;i<=size;i++){
if(S[i].link) add(S[i].link, i);
}
getendpos(1);
dfs(1);
d[1][0] --;//空串本身不算,所以减一
}
void get(int k, int t){
int pos = 0, u = 1;
if(d[u][t] < k){
puts("-1");
return;
}
while(k){
if(u != 1){
if(t == 0) k--; // t = 0 时 只占用一个串
else k -= S[u].endpos; //t = 1 时,占用endpos个串
}
if(k <= 0)break;
for(int i=0;i<P;i++){
int y = S[u].trans[i];
if(!y)continue;
if(d[y][t] >= k){
u = S[u].trans[i];
a[pos++] = 'a' + i;
break;
}
k -= d[y][t];
}
}
a[pos] = 0;
puts(a);
}
}sam;
int main() {
scanf("%s",s);
int n = strlen(s);
for(int i=0;i<n;i++)sam.insert(s[i]);
sam.init();
int t,k;
scanf("%d%d",&t,&k);
sam.get(k,t);
return 0;
}
综合题
1. SAM与DAG上的博弈
这个题等价于同时进行两个DAG有向图游戏,出度为 0 的点是必败态,每次只能往前走一步。所以可以直接\(DFS\)求每个点的\(SG\)函数,而有向图游戏的和就是简单的将他们现在的状态的\(SG\)值异或起来,如果等于0则为必败态。
要求字典序第 \(k\) 小的答案,所以我们还要统计两个DAG中某个状态以及该状态的后继状态中的所有必胜态个数,由于每个状态最多只有26条出边,所以\(sg\) 值并不是很大。用 \(S[x].cnt[i]\) 表示\(x\) 以及 \(x\) 的所有后继状态的 \(sg\) 值等于 i 的个数,\(S[x].rcnt[i]\) 表示 x 以及 x 的所有后继状态中 sg 值不等于 i 的个数。
统计所有必胜态个数,如果小于 k,那么输出 NO。
然后从 a 为空串开始试填,假设填到某个地方,a串成为答案,a串SAM中结点走到了 u(因为每次匹配一个字符都要往前走一格),那么必胜态只有 \(B.S[1].rcnt[A.S[u].sg]\) , 其中 S[1] 表示SAM中的源点,这个值跟 k 去作比较,如果 k 小于等于它,那么这时的 a 串就已经是答案了,否则考虑 a 的后继状态,也是一样的思考方法,就不再赘述了。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int N = 200000 + 5;
const int P = 27;
struct node{
int trans[P], link, len, sg;
ll cnt[P], rcnt[P], tot;
node(){
clear();
}
void clear(){
memset(trans, 0, sizeof trans);
memset(cnt,0, sizeof cnt);
memset(rcnt, 0, sizeof rcnt);
link = len = 0;
sg = -1;
}
};
struct SAM{
node S[N];
int p, np, size;
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
np = size = p = 1;
}
void insert(char ch){
int x = ch - 'a';
np = ++size;
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
}
void getSg(int x){
//要记得记忆化,防止复杂度暴增
if(S[x].sg != -1)return;
S[x].sg = 0;
bool st[P] = {false};
for(int i=0;i<P;i++){
int y = S[x].trans[i];
if(y){
getSg(y);
st[S[y].sg] = true;
}
}
while(st[S[x].sg])S[x].sg++;
}
void getcnt(int x){
if(S[x].cnt[S[x].sg])return;
ll sum = 0;
S[x].cnt[S[x].sg]++;
for(int i=0;i<P;i++){
int y = S[x].trans[i];
if(y){
getcnt(y);
for(int j=0;j<P;j++){
S[x].cnt[j] += S[y].cnt[j];
}
}
}
for(int i=0;i<P;i++) sum += S[x].cnt[i];
for(int i=0;i<P;i++) S[x].rcnt[i] = sum - S[x].cnt[i];
}
void insert(char *s){
int n = strlen(s);
for(int i=0;i < n;i++){
insert(s[i]);
}
getSg(1); //递归求SG
getcnt(1); //递归求cnt以及rcnt
}
} A, B;
ll k;
char a[N], b[N];
ll calc(int u){
if(A.S[u].tot) return A.S[u].tot;
ll res = 0;
for(int i = 0;i<P;i++){
res += A.S[u].cnt[i] * B.S[1].rcnt[i];
}
return A.S[u].tot = res;
}
bool get(){
ll sum = 0;
bool flag = false;
for(int i=0;i<P;i++){
sum = sum + A.S[1].cnt[i] * B.S[1].rcnt[i];
if(sum >= k){ // 防止sum爆ll,大于等于 k就立即退出
flag = true;
break;
}
}
if(!flag) return false;
// 开始依次试填 a, u 为 A中当前结点,pos表示a串试填长度
int pos = 0, u = 1;
while(k){
sum = B.S[1].rcnt[A.S[u].sg]; // 计算a串为u状态下,所有的必胜态
if(sum >= k) break; // a 为 u 状态可行
k -= sum;
for(int i=0;i<P;i++){
int y = A.S[u].trans[i];
if(!y)continue;
sum = calc(y);//计算 y后继所有必胜态
if(sum >= k){ // 可以用这个来作为前缀
a[pos++] = i + 'a';
u = y;
break;
}
k -= sum;
}
}
// 此时 A 的状态为 u, 要在 B 中找 第 k 个 不等于 sg 的结点
int sg = A.S[u].sg;
pos = 0, u = 1;
while(k){
if(B.S[u].sg != sg){
k --;
if(k == 0) break;
}
for(int i=0;i<P;i++){
int y = B.S[u].trans[i];
if(!y)continue;
sum = B.S[y].rcnt[sg];
if(sum >= k){
b[pos++] = i + 'a';
u = y;
break;
}
k -= sum;
}
}
return true;
}
int main() {
scanf("%lld%s%s",&k,a,b);
A.insert(a);B.insert(b);
memset(a, 0, sizeof a);
memset(b, 0, sizeof b);
if(get()){
printf("%s\n%s",a,b);
}else{
puts("NO");
}
return 0;
}
2. SAM维护字符串DP
2019年杭电多校的题目,粘一下出题人题解:
对于\(i\)从小到大处理,维护使得\(s[j:i]\in s[1:j-1]\)的最小的\(j\)(\(s[l:r]\)表示子串\(s_ls_{l+1}...s_{r}\)),那么记\(f[i]\)为输出前\(i\)个字符的最小代价,则\(f[i]=\min\{f[i-1]+p,f[j-1]+q\}\)。
用SAM维护\(s[1:j-1]\),若\(s[1:j-1]\)中包含\(s[j:i+1]\),即加入第\(i+1\)个字符仍然能复制,就不需要做任何处理。否则,重复地将第\(j\)个字符加入后缀自动机并\(j=j+1\),相应维护\(s[j:i+1]\)在后缀自动机上新的匹配位置,直到\(s[j,i+1]\in s[1,j-1]\)。
时限给了\(1.5s\),应该是卡掉了所有非线性做法
#include <cstdio>
#include <iostream>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int N = 400000 + 5;
const int P = 26;
char a[N];
ll p, q, d[N];
struct node{
int trans[P], link, len;
ll sum;
void clear(){
memset(trans, 0, sizeof trans);
link = len = 0;
}
};
struct SAM{
node S[N];
int p, np, size, u;
SAM():p(1),np(1),size(1){}
void clear(){
for(int i=0;i<=size;i++)S[i].clear();
u = np = size = p = 1;
}
void insert(char ch){
int x = ch - 'a';
np = ++size;
S[np].sum = 1;
S[np].len = S[p].len + 1;
while(p != 0 && !S[p].trans[x]) S[p].trans[x] = np, p = S[p].link;
if(p == 0)S[np].link = 1;
else{
int q, nq;
q = S[p].trans[x];
if(S[q].len == S[p].len + 1) S[np].link = q;
else{
nq = ++size;
S[nq] = S[q];
S[nq].sum = 0;
S[nq].len = S[p].len + 1;
S[np].link = S[q].link = nq;
while(p != 0 && S[p].trans[x] == q) S[p].trans[x] = nq, p = S[p].link;
}
}
p = np;
}
bool find(char ch){
return S[u].trans[ch-'a'];
}
void solve(char *a, ll p, ll q){
int n = strlen(a+1);
int j = 1;
u = 1;
d[0] = 0;
for(int i=1;i<=n;i++){
while(!find(a[i])){//如果[1..j-1]所构成的SAM无法匹配到 a[j..i]
if(u == 1){
insert(a[j++]);
}else{
//刚到这里 u 表示最长的可匹配到 a[j...i-1],然后进一步缩短 u 所表示的长度,增加字符到SAM
u = S[u].link;
while(j < i - S[u].len){//[i-S[u].len,....i-1] 为u匹配的字符串,所以最多添加到 i-S[u].len-1
insert(a[j++]);
}
}
}
u = S[u].trans[a[i]-'a'];
d[i] = d[i-1] + p;
if(j <= i)
d[i] = min(d[i], d[j-1] + q);
}
printf("%lld\n",d[n]);
}
}sam;
int main() {
while(~scanf("%s",a+1)){
scanf("%lld%lld",&p,&q);
sam.solve(a, p, q);
sam.clear();
}
return 0;
}

浙公网安备 33010602011771号