【洛谷P3449】PAL-Palindromes

题目

题目链接:https://www.luogu.com.cn/problem/P3449
Johnny 喜欢玩文字游戏。
他写下了 \(n\) 个回文串,随后将这些串两两组合,合并成一个新串。容易看出,一共会有 \(n^2\) 个新串。
两个串组合时顺序是任意的,即 ab 可以组合成 abba,另外自己和自己组合也是允许的。
现在他想知道这些新串中有多少个回文串,你能帮帮他吗?

思路

不难发现,一个长度较短的串\(s1\)与长度较长的串\(s2\)拼接起来能是一个回文串,必须满足以下两点:

  1. \(s1\)是将\(s2\)颠倒过来的串\(s2'\)的前缀。
  2. \(s1\)长度为\(len1\)\(s2'\)长度为\(len2\),那么\(s2'\)\([len1+1,len2]\)所构成的子串必须是一个回文串,其实也就是\(s2\)\([1,len2-len1]\)必须是回文串。

那么我们把所有的串用\(vector\)存起来,然后全部插入一棵\(Trie\)里,再枚举每一个串,将它倒过来在\(Trie\)里寻找。
如果找到某一个时刻有一个串与这个倒过来的串相匹配,那么说明这两个串满足了条件\((1)\)。如果再满足剩余部分是一个回文串,那么这两个串就对答案有贡献。
判断是否是回文串用\(hash\)就可以了。正序做一遍\(hash\),倒序做一遍\(hash\),然后判断一个区间是否回文就把这个区间的前半部分的正序\(hash\)值与后半部分的倒序\(hash\)值相比较即可。
由于这道题给出的串全部回文,所以每一个可行方案倒过来也是一种可行方案,所以\(ans\)要乘\(2\),又因为我们没有计算自己与自己匹配,所以最终答案是\(2ans+n\)
时间复杂度\(O(n)\)

\(update:\)其实可以不用倒过来的。。。因为给出的串都是回文串,倒过来的前缀其实还是原来的前缀。。。当时傻掉了。

代码

#include <string>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;

const int N=2000010;
const ull base=131;
int n,maxn,len[N];
ll ans;
char Ch;
ull hash[N][3],power[N];
vector<string> ch;
string s;

struct Trie
{
	int tot,trie[N][27];
	bool end[N];
	
	void find(int i,int j)
	{
		int p=1;
		for (;j>=0;j--)
		{
			if (end[p]==1)
			{
				int mid=(1+j+1)>>1,f=j&1;
				if (hash[mid][1]==hash[mid+f][2]-hash[j+2][2]*power[j+2-mid-f]) ans++;
			}
			if (trie[p][ch[i][j]-'a'+1]) p=trie[p][ch[i][j]-'a'+1];
				else return;
		}
	}
	
	void update(int i)
	{
		int p=1;
		for (register int j=0;j<len[i];j++)
		{
			if (!trie[p][ch[i][j]-'a'+1])
				trie[p][ch[i][j]-'a'+1]=++tot;
			p=trie[p][ch[i][j]-'a'+1];
		}
		end[p]=1;
	}
}trie;

int main()
{
	scanf("%d",&n);
	ch.push_back("WYC AK IOI");
	for (register int i=1;i<=n;i++)
	{
		scanf("%d",&len[i]);
		s="";
		for (register int j=1;j<=len[i];j++)
		{
			while (Ch=getchar()) if (Ch>='a' && Ch<='z') break;
			s+=Ch;
		}
		ch.push_back(s);
		maxn=max(maxn,len[i]);
	}
	power[0]=1;
	for (register int i=1;i<=maxn;i++)
		power[i]=power[i-1]*base;
	trie.tot=1;
	for (register int i=1;i<=n;i++)
		trie.update(i);
	for (int i=1;i<=n;i++)
	{
		hash[0][1]=hash[len[i]+1][2]=0;
		for (register int j=1;j<=len[i];j++)
			hash[j][1]=hash[j-1][1]*base+ch[i][j-1]-'a'+1;
		for (register int j=len[i];j>=1;j--)
			hash[j][2]=hash[j+1][2]*base+ch[i][j-1]-'a'+1;
		trie.find(i,len[i]-1);
	}
	printf("%lld\n",ans*2+n);
	return 0;
}
posted @ 2020-01-24 10:05  stoorz  阅读(160)  评论(0编辑  收藏  举报