后缀自动机练习题
版题,用来熟悉 SAM 以及其树形结构的用途。
难度都差不多,就没怎么注意排序。
SDOI2016 生成魔咒
对一个长度为 \(n\) 的字符串,每次动态地往 SAM 中插入一个字符,求每次插入之后不同子串的个数。
提一下另一种计算不同子串个数的方法。
由于一个子串必处于一个状态且仅在一个状态中,所以对子串计数即对每个状态的子串数量计数。
一个状态 \(u\) 的子串即 \(\operatorname{long}(u)\) 的一段连续后缀,一直到 \(\operatorname{short}\) 为止,即共 \(\operatorname{len}(u)-\operatorname{minlen}(u)+1\)。
因为 \(\operatorname{minlen}(u)=\operatorname{len}(\operatorname{link}(u))+1\),化简一下,一个结点的贡献就是 \(\operatorname{len}(u)-\operatorname{len}(\operatorname{link}(u))\)。
对每个结点计数求和即可。
回到本题,加入一个字符会多一个状态,对这个新状态求增量即可。
构造 SAM 的过程中分裂出的新节点 \(new\) 为什么不算?
因为分裂 \(new\) 的本质即把一个状态拆成两个,总和不变,只算新的即可。
#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define f(i,a,b) for(int i=(a);i>=(b);i--)
#define MIN(x,y) (x<y?x:y)
#define MAX(x,y) (x>y?x:y)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=2e5+10;
struct state
{
int link,len;
sd map<int,int> nex;
}st[N];
int last,siz;
void init()
{
st[0].link=-1;
st[0].len=0;
siz++,last=0;
}
int extend(int c)
{
int cur=siz++,p=last;
st[cur].len=st[last].len+1;
while(p!=-1&&!st[p].nex.count(c))
{
st[p].nex[c]=cur;
p=st[p].link;
}
if(p==-1)
{
st[cur].link=0;
}
else
{
int q=st[p].nex[c];
if(st[q].len==st[p].len+1)
{
st[cur].link=q;
}
else
{
int nw=siz++;
st[nw].nex=st[q].nex;
st[nw].len=st[p].len+1;
st[nw].link=st[q].link;
while(p!=-1&&st[p].nex[c]==q)
{
st[p].nex[c]=nw;
p=st[p].link;
}
st[cur].link=st[q].link=nw;
}
}
last=cur;
return cur;
}
int n,ans;
void solve()
{
n=read();
init();
F(i,1,n)
{
int x=read();
int now=extend(x);
ans+=st[now].len-st[st[now].link].len;
put(ans);
}
}
signed main()
{
int T=1;
// T=read();
while(T--) solve();
return 0;
}
TJOI2019 甲苯先生和大中锋的字符串
一开始读错题了。。。
一个状态会将 \([\operatorname{minlen}(u),\operatorname{len}(u)]\) 这一段的数量加 1,差分维护即可。
#include<bits/stdc++.h>
#define sd std::
#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define f(i,a,b) for(int i=(a);i>=(b);i--)
#define MIN(x,y) (x<y?x:y)
#define MAX(x,y) (x>y?x:y)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=1e6+10;
struct state
{
int link,len;
sd map<int,int> nex;
}st[N];
int siz,last;
char s[N];
void init()
{
st[0].link=-1;
st[0].len=0;
siz++;
last=0;
}
int dp[N];
void extend(char c)
{
int cur=siz++;
dp[cur]=1;
st[cur].len=st[last].len+1;
int p=last;
while(p!=-1&&!st[p].nex.count(c))
{
st[p].nex[c]=cur;
p=st[p].link;
}
if(p==-1) st[cur].link=0;
else
{
int q=st[p].nex[c];
if(st[q].len==st[p].len+1) st[cur].link=q;
else
{
int nw=siz++;
st[nw].nex=st[q].nex;
st[nw].link=st[q].link;
st[nw].len=st[p].len+1;
while(p!=-1&&st[p].nex[c]==q)
{
st[p].nex[c]=nw;
p=st[p].link;
}
st[cur].link=st[q].link=nw;
}
}
last=cur;
}
struct node
{
int nex,to;
}a[N];
int tot,head[N];
void add(int u,int v)
{
a[++tot].nex=head[u];
head[u]=tot;
a[tot].to=v;
}
void dfs(int u)
{
for(int i=head[u];i;i=a[i].nex)
{
int v=a[i].to;
dfs(v);
dp[u]+=dp[v];
}
}
int n,K;
int cnt[N];
void clear()
{
F(i,0,siz-1) dp[i]=0,st[i].nex.clear(),st[i].len=head[i]=0;
siz=0;
tot=0;
}
void solve()
{
clear();
scanf("%s",s+1);
n=strlen(s+1);K=read();
F(i,1,n) cnt[i]=0;
init();
F(i,1,n) extend(s[i]);
F(i,1,siz-1) add(st[i].link,i);
dfs(0);
int fl=0;
F(i,1,siz-1) if(dp[i]==K)
{
fl=1;
//len(link(i))+1~len(i)
int l=st[st[i].link].len+1,r=st[i].len;
cnt[l]++;
cnt[r+1]--;
}
if(!fl) return put(-1);
F(i,1,n) cnt[i]+=cnt[i-1];
int ans=0,nice=0;
f(i,n,1) if(cnt[i]>nice) ans=i,nice=cnt[i];
put(ans);
}
signed main()
{
int T=1;
T=read();
while(T--) solve();
return 0;
}
CF802I Fake News (hard)
处理出 \(cnt_u\),一个 \(u\) 的贡献是其子串数量乘以 \(cnt^2_u\),随便算算即可。
#include<bits/stdc++.h>
#define sd std::
//#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define f(i,a,b) for(int i=(a);i>=(b);i--)
#define MIN(x,y) (x<y?x:y)
#define MAX(x,y) (x>y?x:y)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(long long x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(long long x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(long long x){print(x);putchar('\n');}
void printk(long long x){print(x);putchar(' ');}
const int N=2e5+10;
struct state
{
int link,len;
sd map<int,int> nex;
}st[N];
int siz,last;
void init()
{
st[0].link=-1;
st[0].len=0;
siz++,last=0;
}
int cnt[N];
void extend(char c)
{
int cur=siz++,p=last;
st[cur].len=st[last].len+1;
cnt[cur]=1;
while(p!=-1&&!st[p].nex.count(c))
{
st[p].nex[c]=cur;
p=st[p].link;
}
if(p==-1)
{
st[cur].link=0;
}
else
{
int q=st[p].nex[c];
if(st[q].len==st[p].len+1)
{
st[cur].link=q;
}
else
{
int nw=siz++;
st[nw].len=st[p].len+1;
st[nw].link=st[q].link;
st[nw].nex=st[q].nex;
while(p!=-1&&st[p].nex[c]==q)
{
st[p].nex[c]=nw;
p=st[p].link;
}
st[cur].link=st[q].link=nw;
}
}
last=cur;
}
struct node
{
int nex;
int to;
}a[N];
int tot,head[N];
void add(int u,int v)
{
a[++tot].nex=head[u];
head[u]=tot;
a[tot].to=v;
}
void dfs(int u)
{
for(int i=head[u];i;i=a[i].nex)
{
int v=a[i].to;
dfs(v);
cnt[u]+=cnt[v];
}
}
void clear()
{
F(i,0,siz-1) st[i].nex.clear(),head[i]=0,cnt[i]=0;
tot=siz=0;
init();
}
char s[N];
int n;
void solve()
{
clear();
scanf("%s",s+1);
n=strlen(s+1);
F(i,1,n) extend(s[i]);
F(i,1,siz-1) add(st[i].link,i);
dfs(0);
long long ans=0;
F(i,1,siz-1)
{
int l=st[st[i].link].len+1,r=st[i].len;
ans+=1ll*(r-l+1)*cnt[i]*cnt[i];
}
put(ans);
}
int main()
{
int T=1;
T=read();
while(T--) solve();
return 0;
}
APIO2014 回文串
伪广义后缀自动机练习题。
考虑将 \(s\) 和 \(s\) 的反串拼起来,中间用特殊字符分隔。
则变为大字符串的出现过两次且 \(s\) 和 \(s\) 的反串中都出现过的子串贡献最大值。
考虑处理出 \(cnt_u\) 代表状态 \(u\) 的 \(\operatorname{endpos}\) 大小。
显然,实际出现次数为 \(\dfrac{cnt_u}{2}\)。注意 \(cnt_u\) 为奇数的情况,相当于在 \(s\) 或在 \(s\) 反串中多出现一次,是不计算贡献的,即 \(\left \lfloor \dfrac{cnt_u}{2} \right \rfloor\)。
(这是调代码过程中出现的,其实我也不太清楚一个回文串为什么正串反串出现次数能不一样的,可能是我代码实现有问题?)
这个做法因为复制了原串,加上 SAM 本身的两倍空间,要开 \(4n\) 空间,有可能会爆。
以下是经过卡空间之后的代码,可以通过洛谷数据。
#include<bits/stdc++.h>
#define sd std::
#define F(i,a,b) for(i=(a);i<=(b);i++)
#define f(i,a,b) for(i=(a);i>=(b);i--)
#define MIN(x,y) (x<y?x:y)
#define MAX(x,y) (x>y?x:y)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
const int N=1.2e6+5;
struct state
{
int len,link,nex[27];
}st[N];
int last,siz;
void init()
{
st[0].link=-1;
siz++;
}
sd bitset<N> dp[2];//记录两个
int cnt[N];
void extend(char c,int op)
{
int cur=siz++,p=last;
dp[op][cur]=1;
cnt[cur]=1;
st[cur].len=st[last].len+1;
while(p!=-1&&!st[p].nex[c])
{
st[p].nex[c]=cur;
p=st[p].link;
}
if(p==-1) st[cur].link=0;
else
{
int q=st[p].nex[c];
if(st[q].len==st[p].len+1)
{
st[cur].link=q;
}
else
{
int nw=siz++;
st[nw].link=st[q].link;
st[nw].len=st[p].len+1;
for(int i=0;i<26;i++) st[nw].nex[i]=st[q].nex[i];
while(p!=-1&&st[p].nex[c]==q)
{
st[p].nex[c]=nw;
p=st[p].link;
}
st[cur].link=st[q].link=nw;
}
}
last=cur;
}
#define a st
#define to nex[1]
#define NEX nex[2]
int tot,head[N];
void add(int u,int v)
{
a[++tot].NEX=head[u];
head[u]=tot;
a[tot].to=v;
}
void dfs(int u)
{
for(int i=head[u];i;i=a[i].NEX)
{
int v=a[i].to;
dfs(v);
dp[0][u]=dp[0][u]|dp[0][v];
dp[1][u]=dp[1][u]|dp[1][v];
cnt[u]+=cnt[v];
}
}
int n,i;
char s[N];
void solve()
{
scanf("%s",s+1);
n=strlen(s+1);
init();
F(i,1,n) extend(s[i]-'a',0);
extend(26,0);
f(i,n,1) extend(s[i]-'a',1);
for(int i=1;i<=siz-1;i++) st[i].nex[0]=st[i].nex[1]=0;
F(i,1,siz-1) add(st[i].link,i);
dfs(0);
long long ans=0;
F(i,0,siz-1) if(dp[1][i]&dp[0][i]) ans=MAX(ans,1ll*cnt[i]/2*st[i].len);
printf("%lld",ans);
}
signed main()
{
int T=1;
// T=read();
while(T--) solve();
return 0;
}
HAOI2016 找相同字符
也是伪后缀自动机。
将两个串拼起来形成大串。
记录每个串在两个分串中出现的次数 \(cnt_1\) 和 \(cnt_2\)。稍微算算即可。
一个状态的贡献是 \(p_u\times cnt_1(u)\times cnt_2(u)\),\(p_u\) 就是这个状态内有多少个子串,这个说过怎么算的。
#include<bits/stdc++.h>
#define sd std::
//#define int long long
#define F(i,a,b) for(int i=(a);i<=(b);i++)
#define f(i,a,b) for(int i=(a);i>=(b);i--)
#define MIN(x,y) (x<y?x:y)
#define MAX(x,y) (x>y?x:y)
#define me(x,y) memset(x,y,sizeof x)
#define pii sd pair<int,int>
#define X first
#define Y second
#define Fr(a) for(auto it:a)
int read(){int w=1,c=0;char ch=getchar();for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;for(;ch>='0'&&ch<='9';ch=getchar()) c=(c<<3)+(c<<1)+ch-48;return w*c;}
void printt(int x){if(x>9) printt(x/10);putchar(x%10+48);}
void print(int x){if(x<0) putchar('-'),printt(-x);else printt(x);}
void put(int x){print(x);putchar('\n');}
void printk(int x){print(x);putchar(' ');}
const int N=2e6+10;
char s[N],t[N];
int n,m;
struct state
{
int len,link;
sd map<int,int> nex;
}st[N];
int siz,last;
void init()
{
st[0].link=-1;
st[0].len=0;
siz++,last=0;
}
int dp[N][2],cnt[N][2];
void extend(char c,int op)
{
int cur=siz++,p=last;
dp[cur][op]=1;
cnt[cur][op]=1;
st[cur].len=st[last].len+1;
while(p!=-1&&!st[p].nex.count(c))
{
st[p].nex[c]=cur;
p=st[p].link;
}
if(p==-1)
{
st[cur].link=0;
}
else
{
int q=st[p].nex[c];
if(st[q].len==st[p].len+1)
{
st[cur].link=q;
}
else
{
int nw=siz++;
st[nw].link=st[q].link;
st[nw].len=st[p].len+1;
st[nw].nex=st[q].nex;
while(p!=-1&&st[p].nex[c]==q)
{
st[p].nex[c]=nw;
p=st[p].link;
}
st[cur].link=st[q].link=nw;
}
}
last=cur;
}
struct node
{
int nex;
int to;
}a[N<<1];
int tot,head[N];
void add(int u,int v)
{
a[++tot].nex=head[u];
head[u]=tot;
a[tot].to=v;
}
void dfs(int u)
{
for(int i=head[u];i;i=a[i].nex)
{
int v=a[i].to;
dfs(v);
dp[u][1]|=dp[v][1];
dp[u][0]|=dp[v][0];
cnt[u][1]+=cnt[v][1];
cnt[u][0]+=cnt[v][0];
}
}
void solve()
{
scanf("%s%s",s+1,t+1);
n=strlen(s+1);m=strlen(t+1);
init();
F(i,1,n) extend(s[i],0);
extend('{',0);
F(i,1,m) extend(t[i],1);
F(i,1,siz-1) add(st[i].link,i);
dfs(0);
long long ans=0;
F(i,1,siz-1) if(dp[i][0]&&dp[i][1])
{
int l=st[st[i].link].len+1,r=st[i].len;
ans+=1ll*(r-l+1)*cnt[i][0]*cnt[i][1];
}
printf("%lld",ans);
}
int main()
{
int T=1;
// T=read();
while(T--) solve();
return 0;
}

浙公网安备 33010602011771号