cf 452e Three Things
给出三个字符串 \(s_1,s_2,s_3\) ,对于每一个 \(1\leq l\leq \min(|s_1|,|s_2|,|s_3|)\) , 求 \(s_k[i_k\cdots i_k+l-1],(k=1,2,3)\) 都两两相等的 \((i_1,i_2,i_3)\) 个数. 模 \(10^9+7\) .
\(3\leq |s_1|+|s_2|+|s_3|\leq 3\cdot 10^5\)
如果 \(ans(i)\) 表示长度为 \(i\) 时三维组的个数.
那么必定有 \(ans(\min(|s_1|,|s_2|,|s_3|))\geq ans(n-1)\geq \cdots \geq ans(1)\).
考虑用 sam 解决这个问题.
将 \(s_1+\)'#'\(+s_2+\)'#'\(+s_3\) 建在一棵 sam 树上.
\(s_1\) , \(s_2\) , \(s_3\) 中每个位置 \(i\) 为结尾的字符串都会对应这 sam 上一个节点和以及此点的祖先.
可以用 \(cnt(x,i)\) 表示在 sam 中 \(s_i\) 有在节点 \(x\) 上的字符串.
接着,发现 \(cnt(x,i)\) 中统计的位置对应 \(x\) 都是最大的长度,它们的长度其实都可以缩短.
于是,考虑 sam 中的节点从叶节点到根加起来.
现在的 \(cnt\) 数组统计的都是 sam 中可以为长度 \(st(i).len\) 到 \(st(st(i).link).len+1\) 的.
时间复杂度:\(O(|s_1|+|s_2|+|s_3|)\)
空间复杂度: \(O(|s_1|+|s_2|+|s_3|)\)
#include<bits/stdc++.h>
using namespace std;
class state{
public:
int len,link;
map<char,int>nxt;
}st[300010*2];
int lst=0,sz=0;
void init(){
st[0].link=-1;
st[0].len=0;
sz++;
}
int cnt[3][300010*2];
void extend(char c,int id){
int cur=sz++;
if(id>=0){
cnt[id][cur]++;
}
st[cur].len=st[lst].len+1;
int p=lst;
while(p!=-1&&st[p].nxt.find(c)==st[p].nxt.end()){
st[p].nxt[c]=cur;
p=st[p].link;
}
// if(cur==13)cout<<c<<","<<p<<","<<st[p].nxt[c]<<endl;
if(p==-1){
st[cur].link=0;
}
else{
int q=st[p].nxt[c];
if(st[q].len==st[p].len+1){
st[cur].link=q;
}
else{
int nq=sz++;
st[nq].link=st[q].link;
st[nq].nxt=st[q].nxt;
st[nq].len=st[p].len+1;
while(p!=-1&&st[p].nxt[c]==q){
st[p].nxt[c]=nq;
p=st[p].link;
}
st[q].link=st[cur].link=nq;
}
}
lst=cur;
}
const int mod=1e9+7;
string s[3];
vector<int>child[300010*2];
void dfs(int x){
// cout<<x<<endl;
for(int i=0;i<(int)child[x].size();i++){
int to=child[x][i];
dfs(to);
cnt[0][x]+=cnt[0][to];
cnt[1][x]+=cnt[1][to];
cnt[2][x]+=cnt[2][to];
}
}
inline void upd(int &x,int y){
x+=y;x%=mod;
}
int ans[300010];
int main(){
cin>>s[0]>>s[1]>>s[2];
init();
for(int i=0;i<3;i++){
for(int j=0;j<(int)s[i].size();j++)extend(s[i][j],i);
if(i<2)extend('#',-1);
}
// for(int i=0;i<sz;i++)cout<<st[i].len<<" ";cout<<endl;
for(int i=0;i<sz;i++)if(st[i].link!=-1)child[st[i].link].push_back(i);
dfs(0);
for(int i=1;i<sz;i++){
int l=st[st[i].link].len+1,r=st[i].len;
// cout<<l<<","<<r<<","<<cnt[0][i]<<","<<cnt[1][i]<<","<<cnt[2][i]<<endl;
upd(ans[l],1ll*cnt[0][i]*cnt[1][i]%mod*cnt[2][i]%mod);
upd(ans[r+1],(mod-1ll*cnt[0][i]*cnt[1][i]%mod*cnt[2][i]%mod)%mod);
}
for(int i=1;i<=min((int)s[0].size(),min((int)s[1].size(),(int)s[2].size()));i++){
ans[i]=(ans[i]+ans[i-1])%mod;
cout<<ans[i]<<" ";
}
cout<<endl;
return 0;
}
/*inline? ll or int? size? min max?*/

浙公网安备 33010602011771号