【CF587F】Duff is Mad

题目

题目链接:https://codeforces.com/problemset/problem/587/F
给定 \(n\) 个字符串 \(s_{1 \dots n}\)\(q\) 次询问 \(s_{l \dots r}\)\(s_k\) 中出现了多少次。
\(n,q,\sum_{i=1}^n |s_i| \le 10^5\)

思路

看上去就很根号分治。取 \(M=350\),分别考虑当 \(|s_k|\leq M\) 以及 \(|s_k|>M\) 时分别怎么做。

\(|s_k|\leq M\) 时,我们可以建出所有字符串的 AC 自动机,我们知道求字符串 \(t\)\(s\) 中的出现次数,可以遍历 \(s\) 在 AC 自动机上的所有点,并判断每一个点在 fail 树上的祖先是否有 \(t\) 结尾的点。如果有出现次数就 \(+1\)
所以我们将询问拆成两个前缀和相减的形式,按照右端点排序。依次枚举所有的串,把这个串结尾的节点在 fail 树上的子树全部 \(+1\),然后枚举所有右端点为当前加入的串的询问,在 AC 自动机上找到这个串的路径上的节点,并对路径上的点求和即可。
把 fail 树按照 dfs 序编号,需要支持区间加一和单点查询。用树状数组即可。因为每次跳的节点数 \(\leq M\),所以复杂度为 \(O(QM\log len)\)。其中 \(len=\sum_{i=1}^n |s_i|\)。事实上采用分块可以省掉那个 \(\log\)

\(|s_k|>M\) 时,这样的串的数量不超过 \(\frac{len}{M}\) 个。我们对于相同的 \(k\) 同时求解。
依然建出 AC 自动机,把根到 \(s_k\) 的路径上的点权值都设为 \(1\)。然后对于一组询问 \(s_{l\cdots r}\),答案就是 \(l\sim r\) 的串在 fail 树上子树内的权值之和。那么就直接用前缀和记一下就可以了。时间复杂度 \(O(Q+\frac{n}{M})\)

代码

#include <bits/stdc++.h>
#define end ayfiubcyfiluwebyi
using namespace std;
typedef long long ll;

const int N=100010,M=350;
int n,m,Q1,Q2,end[N],pos[N],id[N],siz[N];
ll ans[N],sum[N];
char s[N],t[N];

struct node
{
	int l,r,k,id;
}q1[N*2],q2[N];

bool cmp1(node x,node y)
{
	return x.l<y.l;
}

bool cmp2(node x,node y)
{
	return x.k<y.k;
}

struct ACA
{
	int tot,fa[N],ch[N][26],fail[N];
	vector<int> e[N];
	
	void insert(char *s,int j)
	{
		int len=strlen(s+1),p=0;
		for (int i=1;i<=len;i++)
		{
			if (!ch[p][s[i]-'a']) ch[p][s[i]-'a']=++tot;
			fa[ch[p][s[i]-'a']]=p; p=ch[p][s[i]-'a'];
		}
		end[j]=p;
	}
	
	void build()
	{
		queue<int> q;
		for (int i=0;i<26;i++)
			if (ch[0][i]) q.push(ch[0][i]);
		while (q.size())
		{
			int u=q.front(); q.pop();
			e[fail[u]].push_back(u);
			for (int i=0;i<26;i++)
				if (ch[u][i]) q.push(ch[u][i]),fail[ch[u][i]]=ch[fail[u]][i];
					else ch[u][i]=ch[fail[u]][i];
		}
	}
	
	void dfs(int x)
	{
		id[x]=++tot; siz[x]=1;
		for (int i=0;i<e[x].size();i++)
		{
			int v=e[x][i];
			dfs(v); siz[x]+=siz[v];
		}
	}
}AC;

struct BIT
{
	ll c[N];
	
	void add(int x,ll v)
	{
		for (int i=x;i<=AC.tot;i+=i&-i)
			c[i]+=v;
	}
	
	ll query(int x)
	{
		ll ans=0;
		for (int i=x;i;i-=i&-i)
			ans+=c[i];
		return ans;
	}
}bit;

int main()
{
	scanf("%d%d",&n,&m);
	for (int i=1;i<=n;i++)
	{
		scanf("%s",t+1);
		AC.insert(t,i);
		int len=strlen(t+1); pos[i]=Q1+1;
		for (int j=1;j<=len;j++) s[++Q1]=t[j];
	}
	pos[n+1]=Q1+1; Q1=0;
	for (int i=1,l,r,x;i<=m;i++)
	{
		scanf("%d%d%d",&l,&r,&x);
		if (pos[x+1]-pos[x]<=M)
			q1[++Q1]=(node){r,1,x,i},q1[++Q1]=(node){l-1,-1,x,i};
		else
			q2[++Q2]=(node){l,r,x,i};
	}
	AC.build();
	AC.tot=0; AC.dfs(0);
	sort(q1+1,q1+1+Q1,cmp1);
	for (int i=0,j=1;i<=n;i++)
	{
		if (i)
		{
			bit.add(id[end[i]],1);
			bit.add(id[end[i]]+siz[end[i]],-1);
		}
		for (;j<=Q1 && q1[j].l==i;j++)
			for (int p=end[q1[j].k];p;p=AC.fa[p])
				ans[q1[j].id]+=q1[j].r*bit.query(id[p]);
	}
	sort(q2+1,q2+1+Q2,cmp2);
	for (int i=1,k=1;i<=n;i++)
		if (pos[i+1]-pos[i]>M)
		{
			memset(bit.c,0,sizeof(bit.c));
			for (int p=end[i];p;p=AC.fa[p])
				bit.add(id[p],1);
			for (int j=1;j<=n;j++)
				sum[j]=sum[j-1]+bit.query(id[end[j]]+siz[end[j]]-1)-bit.query(id[end[j]]-1);
			for (;k<=Q2 && q2[k].k==i;k++)
				ans[q2[k].id]=sum[q2[k].r]-sum[q2[k].l-1];
		}
	for (int i=1;i<=m;i++)
		cout<<ans[i]<<"\n";
	return 0;
}
posted @ 2021-07-19 16:21  stoorz  阅读(63)  评论(0编辑  收藏  举报