bzoj 3160 万径人踪灭 —— FFT

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3160

求出关于一个位置有多少对对称字母,如果 i 位置有 f[i] 对,对答案的贡献是 2^f[i] - 1;

然后减去连续的,用 manachar 求出回文长度,每个位置作为边界都是一种不合法情况;

求对称,首先把字符串中间穿插字符 '$',于是字符串的长度变成2倍;

考虑一对字母 s[x],s[y],如果 s[x] = s[y],其对称中心是 (x+y)/2;

放在加入字符后的字符串中,对称中心就是 x+y;

所以可以看出卷积了:f[i] = ∑(0<=j<=i) (s[j]==s[i-j]),其中 i 视为新字符串中的位置,j 和 i-j 视为原字符串中的位置;

注意卷积和 manachar 算的个数都要包括自己成对,否则判断挺麻烦...

这里卷积的两个多项式其实是一样的,所以只要用 FFT 算出一个,然后自己乘起来即可;

做下一步的时候注意清空,别忘了清空 n~lim 部分的值;

处理 bin 的边界是 n 而非 n-1,因为最多可能有 n 对。

(学习了 manachar 的简洁写法)

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef double db;
int const xn=(1<<19),mod=1e9+7;
db const Pi=acos(-1.0);
int n,rev[xn],lim=1,l,len[xn],bin[xn],c[xn];
char ch[xn];
struct com{db x,y;}a[xn],b[xn],aa[xn];
com operator + (com a,com b){return (com){a.x+b.x,a.y+b.y};}
com operator - (com a,com b){return (com){a.x-b.x,a.y-b.y};}
com operator * (com a,com b){return (com){a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y};}
int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
void fft(com *a,int tp)
{
  for(int i=0;i<lim;i++)
    if(i<rev[i])swap(a[i],a[rev[i]]);
  for(int mid=1;mid<lim;mid<<=1)
    {
      com wn=(com){cos(Pi/mid),tp*sin(Pi/mid)};
      for(int j=0,len=(mid<<1);j<lim;j+=len)
    {
      com w=(com){1,0};
      for(int k=0;k<mid;k++,w=w*wn)
        {
          com x=a[j+k],y=w*a[j+mid+k];
          a[j+k]=x+y; a[j+mid+k]=x-y;
        }
    }
    }
}
void solve()
{
  for(int i=0;i<n;i++)a[i].x=(ch[i]=='a');
  fft(a,1);
  for(int i=0;i<lim;i++)b[i]=a[i]*a[i];
  for(int i=0;i<n;i++)a[i].x=(ch[i]=='b'),a[i].y=0;//y=0
  for(int i=n;i<lim;i++)a[i].x=0,a[i].y=0;//!!
  fft(a,1);
  for(int i=0;i<lim;i++)b[i]=b[i]+a[i]*a[i];
  fft(b,-1);
  for(int i=0;i<n+n;i++)c[i]=(c[i]+(int)(b[i].x/lim+0.5))%mod;
}
char s[xn];
int manachar()//+i self
{
  int mx=0,id=0,ret=0; s[0]='$';
  for(int i=1;i<=n+n;i++)
    if(i%2==0)s[i]='$';
    else s[i]=ch[i>>1];
  for(int i=1;i<=n+n;i++)
    {
      if(i<mx)len[i]=min(mx-i,len[id*2-i]);
      while(i-len[i]>=0&&i+len[i]<=n+n&&s[i-len[i]]==s[i+len[i]])len[i]++;
      if(i+len[i]>mx)mx=i+len[i],id=i;
      ret=upt(ret+len[i]/2);
    }
  return ret;
}
int main()
{
  scanf("%s",ch); n=strlen(ch);
  while(lim<=n+n)lim<<=1,l++;//
  for(int i=0;i<lim;i++)
    rev[i]=((rev[i>>1]>>1)|((i&1)<<(l-1)));
  bin[0]=1;
  for(int i=1;i<=n;i++)bin[i]=upt(bin[i-1]+bin[i-1]);
  solve();
  int ans=0;
  for(int i=0;i<n+n;i++)ans=upt(ans+bin[(c[i]+1)>>1]-1);//+1 -1
  printf("%d\n",upt(ans-manachar()));
  return 0;
}

 

posted @ 2018-11-26 19:46  Zinn  阅读(113)  评论(0编辑  收藏  举报