【XSY3320】string AC自动机 哈希 点分治

题目大意

  给一棵树,每条边上有一个字符,求有多少对 \((x,y)(x<y)\),满足 \(x\)\(y\) 路径上的边上的字符按顺序组成的字符串为回文串。

  \(1\leq n\leq 50000,1\leq x_i,y_i\leq n,z_i\in\{0,1\}\)

题解

  观察一条经过重心的回文串是长什么样的

  \(S\) 是一个任意的字符串,\(T\) 是一个回文串。

  建出根到每个节点对应的串的AC自动机。

  那么 \(x\) 这边的 \(S\) 串就是 \(x\) 对应的AC自动机节点的一个后缀, \(T\) 串是一个前缀。

  dfs 整棵树的 fail 树,先统计每个点作为 \(x\) 点的贡献,再把作为 \(y\) 点的贡献加到数据结构中。

  开 \(\sqrt n\) 个长度为 \(\sqrt n\) 的数组 \(c_{1,\sqrt n}\)\(c_{i,j}\) 表示当前节点有多少个长度 \(\bmod i=j\) 的祖先。

  当一个点是 \(y\) 点的时候,令对应长度的字符串的出现次数 \(+1\),还要对于 \(\leq \sqrt n\) 的所有数 \(i\),令 \(c_{i,\lvert S \rvert \bmod i}++\)

  当一个点是 \(x\) 点的时候,一个回文串的所有回文前缀可以被表示为 \(O(\log n)\) 个等差数列,公差 \(\leq \sqrt n\) 的那部分在 \(c\) 里面查,剩下的暴力查就好了。

  记一个等差数列的首项为 \(a_1\),公差为 \(d\),末项为 \(a_n\),那么贡献就是 dfs 到深度为 \(a_n\) 的点时 \(c_{d,a_1\bmod d}\) 的值减掉 dfs 到深度为 \(a_1-d\) 的点时 \(c_{d,a_1\bmod d}\) 的值。

  先 dfs 一遍把所有询问的信息插到 vector 中,再 dfs 一遍计算答案。

  求一个串的所有回文前缀可以直接哈希。

  时间复杂度:\(f(n)=O(n^\frac{3}{2})+O(n\log^2 n)=O(n^\frac{3}{2})\)

  \(T(n)=2T(\frac{n}{2})+f(n)=2T(\frac{n}{2})+O(n^\frac{3}{2})=O(n^\frac{3}{2})\)

代码

  把这份代码中的后缀自动机换成 AC自动机,回文自动机换成哈希就好了。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<functional>
#include<cmath>
#include<vector>
#include<queue>
#include<assert.h>
//using namespace std;
using std::min;
using std::max;
using std::swap;
using std::sort;
using std::reverse;
using std::random_shuffle;
using std::lower_bound;
using std::upper_bound;
using std::unique;
using std::vector;
using std::queue;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef std::pair<int,int> pii;
typedef std::pair<ll,ll> pll;
void open(const char *s){
#ifndef ONLINE_JUDGE
	char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
void open2(const char *s){
#ifdef DEBUG
	char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;}
void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');}
int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;}
int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;}
const int N=50010;
vector<pii> g[N];
int sz[N];
int totsz,rt,rtsz;
int b[N];
int n;
int f[N];
ll* ss[N];
ll ss2[N];
ll ans=0;
int _log[N];
struct info
{
	int x;
	int y;
	int z;
	info(int a=0,int b=0,int c=0):x(a),y(b),z(c){}
};
int cmp(info a,info b)
{
	if(a.x!=b.x)
		return a.x<b.x;
	return a.z<b.z;
}
void dfs1(int x,int fa)
{
	sz[x]=1;
	for(auto v:g[x])
		if(v.first!=fa&&!b[v.first])
		{
			dfs1(v.first,x);
			sz[x]+=sz[v.first];
		}
}
void dfs2(int x,int fa)
{
	int mx=totsz-sz[x];
	for(auto v:g[x])
		if(v.first!=fa&&!b[v.first])
		{
			dfs2(v.first,x);
			mx=max(mx,sz[v.first]);
		}
	if(mx<rtsz)
	{
		rtsz=mx;
		rt=x;
	}
}
void dfs3(int x,int fa)
{
	f[x]=fa;
	for(auto v:g[x])
		if(v.first!=fa&&!b[v.first])
			dfs3(v.first,x);
}
int tot;
int str[N];
namespace sam
{
	int next[2*N][2];
	int fail[2*N];
	int len[2*N];
	int last,cnt;
	int b[2*N];
	int a[2*N][2];
	int s[2*N]; 
	void init()
	{
		while(cnt)
		{
			next[cnt][0]=next[cnt][1]=0;
			a[cnt][0]=a[cnt][1]=0;
			b[cnt]=0;
			s[cnt]=0;
			cnt--;
		}
		cnt=1;
		last=1;
	}
	int insert(int p,int c)
	{
		if(next[p][c])
		{
			last=next[p][c];
			s[last]++;
			return last;
		}
//		int p=last;
		int np=++cnt;
		len[np]=len[p]+1;
		s[np]=1;
		for(;p&&!next[p][c];p=fail[p])
			next[p][c]=np;
		if(!p)
			fail[np]=1;
		else
		{
			int q=next[p][c];
			if(len[q]==len[p]+1)
				fail[np]=q;
			else
			{
				int nq=++cnt;
				len[nq]=len[p]+1;
				memcpy(next[nq],next[q],sizeof next[q]);
				fail[nq]=fail[q];
				fail[q]=fail[np]=nq;
				for(;p&&next[p][c]==q;p=fail[p])
					next[p][c]=nq;
			}
		}
		return last=np;
	}
}
namespace pam
{
	int next[N][2];
	int trans[N][2];
	int fail[N];
	int len[N];
	int diff[N];
	int link[N];
	int top[N];
	int last;
	int cnt;
	void init()
	{
		while(cnt>=0)
		{
			next[cnt][0]=next[cnt][1]=0;
			trans[cnt][0]=trans[cnt][1]=0;
			cnt--;
		}
		cnt=1;
		str[0]=-1;
		fail[0]=1;
		fail[1]=0;
		len[0]=0;
		len[1]=-1;
		last=0;
		link[0]=0;
		diff[0]=1;
		diff[1]=0;
		top[0]=0;
		top[1]=1;
		trans[0][0]=trans[0][1]=trans[1][0]=trans[1][1]=1;
	}
	int find(int x,int c)
	{
		return str[tot-len[x]-1]==c?x:trans[x][c];
	}
	void insert(int c)
	{
		str[++tot]=c;
		last=find(last,c);
		int now=last;
		if(!next[now][c])
		{
			int cur=++cnt;
			len[cur]=len[now]+2;
			last=find(fail[last],c);
			fail[cur]=next[last][c];
			diff[cur]=len[cur]-len[fail[cur]];
			if(diff[cur]==diff[fail[cur]])
			{
				link[cur]=link[fail[cur]];
				top[cur]=top[fail[cur]];
			}
			else
			{
				link[cur]=fail[cur];
				top[cur]=cur;
			}
			if(!link[cur])
				link[cur]=cur;
			memcpy(trans[cur],trans[fail[cur]],sizeof trans[cur]);
			trans[cur][str[tot-len[fail[cur]]]]=fail[cur];
			next[now][c]=cur;
		}
		last=next[now][c];
	}
}
namespace trie
{
	int a[N][2];
	int s[N];
	int cnt;
	void clear()
	{
		while(cnt)
		{
			a[cnt][0]=a[cnt][1]=0;
			s[cnt]=0;
			cnt--;
		}
		cnt=1;
	}
}
ll s,s2;
int pos[N];
int pos2[N];
int pos3[N];
int pos4[N];
int q[N];
int len[N],id[N],top;
int head,tail;
vector<int> e[2*N];
int sq;
vector<info> h[2*N];
int orzzjt,orzzjt2;
void bfs(int x)
{
	sam::init();
//	sam::s[1]=1;
	pos[x]=1;
	head=1;
	tail=0;
	q[++tail]=x;
	trie::clear();
	pos4[x]=1;
	while(tail>=head)
	{
		int y=q[head++];
		s+=trie::s[pos4[y]];
		trie::s[pos4[y]]++;
		for(auto v:g[y])
			if(!b[v.first]&&v.first!=f[y])
			{
				pos[v.first]=sam::insert(pos[y],v.second);
				q[++tail]=v.first;
				if(trie::a[pos4[y]][v.second])
					pos4[v.first]=trie::a[pos4[y]][v.second];
				else
					pos4[v.first]=trie::a[pos4[y]][v.second]=++trie::cnt;
			}
	}
}
void dfs(int x,int fa)
{
	for(int y=pos[x];y!=1&&!sam::b[y];y=sam::fail[y])
	{
		sam::a[sam::fail[y]][str[tot-sam::len[sam::fail[y]]]]=y;
		sam::b[y]=1;
	}
	//这样建出来的后缀树不是完整的,但已经够用了 
	
	int now=pam::last;
	pos2[x]=now;
	if(pam::len[now]==tot)
	{
		if(fa)
			s2++;
		pos3[x]=now;
	}
	else
		pos3[x]=pos3[fa];
	for(auto v:g[x])
		if(!b[v.first]&&v.first!=fa)
		{
			pam::last=now;
			pam::insert(v.second);
			dfs(v.first,x);
			tot--;
		}
}
void dfs4(int x)
{
	len[++top]=sam::len[x];
	id[top]=x;
	for(auto v:e[x])
		for(int y=pos3[v];y>1;)
			if(pam::diff[y]<=sq)
			{
				h[id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[y]-pam::diff[y])-len]].push_back(info(sam::len[x]-pam::len[y]-pam::diff[y],pam::diff[y],-1));
				h[id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[pam::link[y]])-len]].push_back(info(sam::len[x]-pam::len[pam::link[y]],pam::diff[y],1));
				//h.push_back(info(sam::len[x]-pam::len[y],id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[y])-len],1));
//				h.push_back(info(sam::len[x]-pam::len[pam::link[y]]+pam::diff[y],id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[pam::link[y]]+pam::diff[y])-len],-1));
				y=pam::fail[pam::link[y]];
				orzzjt2+=_log[top];
			}
			else
			{
				y=pam::fail[y];
			}
	if(sam::a[x][0])
		dfs4(sam::a[x][0]);
	if(sam::a[x][1])
		dfs4(sam::a[x][1]);
	top--;
}
void dfs5(int x)
{
	for(auto v:h[x])
		if(v.x>=0&&v.x!=sam::len[x])
			s+=ss[v.y][v.x%v.y]*v.z;
	orzzjt+=sq;
	for(int i=1;i<=sq;i++)
		ss[i][sam::len[x]%i]+=sam::s[x];
	ss2[sam::len[x]]+=sam::s[x];
	
	
	for(auto v:h[x])
		if(v.x>=0&&v.x==sam::len[x])
			s+=ss[v.y][v.x%v.y]*v.z;
			
			
	for(auto v:e[x])
		for(int y=pos3[v];y>1;)
			if(pam::diff[y]<=sq)
			{
				y=pam::fail[pam::link[y]];
			}
			else
			{
				s+=ss2[sam::len[x]-pam::len[y]];
				y=pam::fail[y];
			}
			
	if(sam::a[x][0])
		dfs5(sam::a[x][0]);
	if(sam::a[x][1])
		dfs5(sam::a[x][1]);
	
		
	for(int i=1;i<=sq;i++)
		ss[i][sam::len[x]%i]-=sam::s[x];
	ss2[sam::len[x]]-=sam::s[x];
}
ll calc(int x)
{
	s=0;
	s2=0;
	bfs(x);
	pam::init();
	dfs(x,0);
	for(int i=1;i<=sam::cnt;i++)
	{
		e[i].clear();
		h[i].clear();
	}
	for(int i=1;i<=tail;i++)
		e[pos[q[i]]].push_back(q[i]);
	dfs4(1);
//	for(int i=1;i<=sam::cnt;i++)
//		sort(h[i].begin(),h[i].end());
	dfs5(1);
	return s;
}
int c[N],c2[N];
int t;
vector<pii> g2;
void solve(int x)
{
	dfs1(x,0);
	totsz=sz[x];
	rtsz=0x7fffffff;
	dfs2(x,0);
	x=rt;
	dfs3(x,0);
	int t=0;
	sq=sqrt(totsz);
//	sq=0;
	ans+=calc(x);
	ans+=s2;
	for(auto v:g[x])
		if(!b[v.first])
		{
			b[v.first]=1;
			c[++t]=v.first;
			c2[t]=v.second;
		}
	g2=g[x];
	g[x].clear();
	for(int i=1;i<=t;i++)
	{
		b[c[i]]=0;
		g[x].clear();
		g[x].push_back(pii(c[i],c2[i]));
		ans-=calc(x);
		b[c[i]]=1;
	}
	g[x]=g2;
	for(int i=1;i<=t;i++)
		b[c[i]]=0;
	b[x]=1;
	for(auto v:g[x])
		if(!b[v.first])
			solve(v.first);
}
int main()
{
	open("string");
	scanf("%d",&n);
	for(int i=1;i<=n;i++)
		for(int j=1,k=0;j<=n;j<<=1,k++)
			_log[i]=k;
	int _sqrt=sqrt(n);
	for(int i=1;i<=_sqrt;i++)
	{
		ss[i]=new ll[i];
		for(int j=0;j<i;j++)
			ss[i][j]=0;
	}
	int x,y,z;
	for(int i=1;i<n;i++)
	{
		scanf("%d%d%d",&x,&y,&z);
		g[x].push_back(pii(y,z));
		g[y].push_back(pii(x,z));
	}
	solve(1);
//	assert(ans%2==0);
//	ans/=2;
	printf("%lld\n",ans);
//	printf("%d\n",orzzjt);
//	printf("%d\n",orzzjt2);
	return 0;
}
posted @ 2019-01-08 20:20  ywwyww  阅读(511)  评论(0编辑  收藏  举报