BZOJ3160 万径人踪灭

本文版权归ljh2000和博客园共有,欢迎转载,但须保留此声明,并给出原文链接,谢谢合作。

 

 

本文作者:ljh2000
作者博客:http://www.cnblogs.com/ljh2000-jump/
转载请注明出处,侵权必究,保留最终解释权!

 

题目链接:BZOJ3160

 

正解:FFT+manacher

解题报告:

  参考博客:戳这里

   题目求的是一个字符串的不连续回文子序列个数。

  考虑用所有的回文子序列个数$-$连续回文子序列就是答案。

  求连续回文子序列的个数只需要跑一遍$manacher$,然后得到以每个点为对称中心的$p$数组之后,可以直接统计出答案。

  回文子序列的个数似乎不好考虑,我们不妨考虑以每个地方(包括间隔)为对称点的回文子序列个数。

  我们如果知道了两边对应位置相等的个数有$x$个,根据二项式定理$C(n,1)+C(n,2)+C(n,3)+…+C(n,n)=2^n-1$,所以答案就是$2^x-1$。

  而$a$、$b$是彼此独立的,所以我们可以分别考虑$a$和$b$。

  我们设出一个多项式,若这一位是$a$那么系数就是$1$,容易发现把这个多项式平方之后,$i$项对应的系数就是以$i$为对称中心的相等的$a$的个数。

  因为我一直写的是递归版的$FFT$,然后被卡常了...

  拖了一个非递归版的$FFT$就愉快地$AC$了。

 

 

//It is made by ljh2000
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <ctime>
#include <vector>
#include <queue>
#include <map>
#include <set>
#include <string>
#include <complex>
using namespace std;
typedef long long LL;
typedef complex<double> C;
const int MOD = 1000000007;
const double pi = acos(-1);
const int MAXN = 300011;
int n,L,f[MAXN],mx,pos,m,p[MAXN];
char ch[MAXN],s[MAXN];
C a[MAXN],b[MAXN],aa[MAXN],bb[MAXN];
int ans[MAXN],tot,out,er[MAXN],R[MAXN];
//ans[i]表示以i为对称中心的两边的对称字符数量(包含i)
inline int getint(){
    int w=0,q=0; char c=getchar(); while((c<'0'||c>'9') && c!='-') c=getchar();
    if(c=='-') q=1,c=getchar(); while (c>='0'&&c<='9') w=w*10+c-'0',c=getchar(); return q?-w:w;
}

inline LL fast_pow(LL x,LL y){
	LL r=1;
	while(y>0) {
		if(y&1) r*=x,r%=MOD;
		x*=x; x%=MOD;
		y>>=1;
	}
	return r;
}

inline void fft(C *a,int n,int f){
    for(int i=0;i<n;i++) if(i<R[i]) swap(a[i],a[R[i]]);//交换位置
    for(int i=1;i<n;i<<=1){//待合并区间长度
        C wn(cos(pi/i),sin(f*pi/i)),x,y;//这里就不用再*2了,因为合并后的区间长度是i的两倍
        for(int j=0;j<n;j+=i<<1){//起始位置
            C w(1,0);
            for(int k=0;k<i;k++,w*=wn){//第k个
                x=a[j+k];y=w*a[j+i+k];
                a[j+k]=x+y;
                a[j+i+k]=x-y;
            }
        }
    }
}

inline LL manacher(){
	pos=0; mx=0; s[0]='%'; s[1]='#'; m=1;
	for(int i=0;i<n;i++) s[++m]=ch[i],s[++m]='#';
	for(int i=1;i<=m;i++) {
		if(i<mx) p[i]=min(p[2*pos-i],mx-i); else p[i]=1;
		for(;i+p[i]<=m/*!!!*/ && s[i+p[i]]==s[i-p[i]];p[i]++);
		if(i+p[i]>mx) { mx=i+p[i]; pos=i; }
		tot+=p[i]/2;
		tot%=MOD;//一个回文串的贡献
	}
	return tot;
}

inline void work(){
	scanf("%s",ch); n=strlen(ch); int N=n<<1,ll=0;
	for(int i=0;i<=N;i++) er[i]=fast_pow(2,i);

	for(int i=0;i<n;i++) if(ch[i]=='a') a[i]=b[i]=1;
	for(L=1;L<=N;L<<=1) ll++; for(int i=0;i<L;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(ll-1));
	fft(a,L,1); fft(b,L,1); for(int i=0;i<L;i++) a[i]*=b[i];
	fft(a,L,-1); for(int i=0;i<N;i++) ans[i]=(int)(a[i].real()/L+0.5);

	for(int i=0;i<n;i++) if(ch[i]=='b') aa[i]=bb[i]=1;
	fft(aa,L,1); fft(bb,L,1); for(int i=0;i<L;i++) aa[i]*=bb[i]; fft(aa,L,-1);
	for(int i=0;i<N;i++) ans[i]+=(int)(aa[i].real()/L+0.5);

	for(int i=0;i<N;i++) ans[i]=er[(ans[i]+1)/2]-1;
	for(int i=0;i<N;i++) out+=ans[i],out%=MOD;

	out-=manacher(); out+=MOD; out%=MOD;
	printf("%d",out);
}

int main()
{
    work();
    return 0;
}

  

posted @ 2017-02-24 11:58  ljh_2000  阅读(207)  评论(0编辑  收藏  举报