cf 452e Three Things

给出三个字符串 \(s_1,s_2,s_3\) ,对于每一个 \(1\leq l\leq \min(|s_1|,|s_2|,|s_3|)\) , 求 \(s_k[i_k\cdots i_k+l-1],(k=1,2,3)\) 都两两相等的 \((i_1,i_2,i_3)\) 个数. 模 \(10^9+7\) .

\(3\leq |s_1|+|s_2|+|s_3|\leq 3\cdot 10^5\)

如果 \(ans(i)\) 表示长度为 \(i\) 时三维组的个数.

那么必定有 \(ans(\min(|s_1|,|s_2|,|s_3|))\geq ans(n-1)\geq \cdots \geq ans(1)\).

考虑用 sam 解决这个问题.

\(s_1+\)'#'\(+s_2+\)'#'\(+s_3\) 建在一棵 sam 树上.

\(s_1\) , \(s_2\) , \(s_3\) 中每个位置 \(i\) 为结尾的字符串都会对应这 sam 上一个节点和以及此点的祖先.

可以用 \(cnt(x,i)\) 表示在 sam 中 \(s_i\) 有在节点 \(x\) 上的字符串.

接着,发现 \(cnt(x,i)\) 中统计的位置对应 \(x\) 都是最大的长度,它们的长度其实都可以缩短.

于是,考虑 sam 中的节点从叶节点到根加起来.

现在的 \(cnt\) 数组统计的都是 sam 中可以为长度 \(st(i).len\)\(st(st(i).link).len+1\) 的.

时间复杂度:\(O(|s_1|+|s_2|+|s_3|)\)

空间复杂度: \(O(|s_1|+|s_2|+|s_3|)\)

#include<bits/stdc++.h>
using namespace std;
class state{
public:
	int len,link;
	map<char,int>nxt;
}st[300010*2];
int lst=0,sz=0;
void init(){
	st[0].link=-1;
	st[0].len=0;
	sz++;
}
int cnt[3][300010*2];
void extend(char c,int id){
	int cur=sz++;
	if(id>=0){
		cnt[id][cur]++;
	}
	st[cur].len=st[lst].len+1;
	int p=lst;
	while(p!=-1&&st[p].nxt.find(c)==st[p].nxt.end()){
		st[p].nxt[c]=cur;
		p=st[p].link;
	}
//	if(cur==13)cout<<c<<","<<p<<","<<st[p].nxt[c]<<endl;
	if(p==-1){
		st[cur].link=0; 
	}
	else{
		int q=st[p].nxt[c];
		if(st[q].len==st[p].len+1){
			st[cur].link=q;
		}
		else{
			int nq=sz++;
			st[nq].link=st[q].link;
			st[nq].nxt=st[q].nxt;
			st[nq].len=st[p].len+1;
			while(p!=-1&&st[p].nxt[c]==q){
				st[p].nxt[c]=nq;
				p=st[p].link;
			}
			st[q].link=st[cur].link=nq;
		}
	}
	lst=cur;
}
const int mod=1e9+7;
string s[3];
vector<int>child[300010*2];
void dfs(int x){
//	cout<<x<<endl;
	for(int i=0;i<(int)child[x].size();i++){
		int to=child[x][i];
		dfs(to);
		cnt[0][x]+=cnt[0][to];
		cnt[1][x]+=cnt[1][to];
		cnt[2][x]+=cnt[2][to];
	}
}
inline void upd(int &x,int y){
	x+=y;x%=mod;
}
int ans[300010];
int main(){
	cin>>s[0]>>s[1]>>s[2];
	init();
	for(int i=0;i<3;i++){
		for(int j=0;j<(int)s[i].size();j++)extend(s[i][j],i);
		if(i<2)extend('#',-1);
	}
//	for(int i=0;i<sz;i++)cout<<st[i].len<<" ";cout<<endl;
	for(int i=0;i<sz;i++)if(st[i].link!=-1)child[st[i].link].push_back(i);
	dfs(0);
	for(int i=1;i<sz;i++){
		int l=st[st[i].link].len+1,r=st[i].len;
	//	cout<<l<<","<<r<<","<<cnt[0][i]<<","<<cnt[1][i]<<","<<cnt[2][i]<<endl;
		upd(ans[l],1ll*cnt[0][i]*cnt[1][i]%mod*cnt[2][i]%mod);
		upd(ans[r+1],(mod-1ll*cnt[0][i]*cnt[1][i]%mod*cnt[2][i]%mod)%mod); 
	}
	
	for(int i=1;i<=min((int)s[0].size(),min((int)s[1].size(),(int)s[2].size()));i++){
		ans[i]=(ans[i]+ans[i-1])%mod;
		cout<<ans[i]<<" ";
	}
	cout<<endl;
	return 0;
}
/*inline? ll or int? size? min max?*/

posted @ 2021-07-08 23:54  xyangh  阅读(42)  评论(0)    收藏  举报