LOJ 3058 「HNOI2019」白兔之舞——单位根反演+MTT

题目:https://loj.ac/problem/3058

先考虑 n=1 怎么做。令 a 表示输入的 w[1][1] 。

  \( ans_t = \sum\limits_{i=0}^{L}C_{L}^{i} a^i [ k|(i-t) ] \)

     \(= \frac{1}{k}\sum\limits_{i=0}^{L}C_{L}^{i} a^i \sum\limits_{j=0}^{k-1} w_{k}^{j*(i-t)} \)

     \(= \frac{1}{k}\sum\limits_{j=0}^{k-1}w_{k}^{-j*t} \sum\limits_{i=0}^{L}C_{L}^{i} a^i w_{k}^{i*j} \)

     \(= \frac{1}{k}\sum\limits_{j=0}^{k-1}w_{k}^{-j*t} (1+a*w_{k}^{j})^{L} \)

  这样是 k2 的,就不会了……

  考虑卷积。把 -j*t 拆成只和 j 有关的与只和 t 或者 t+j 、t-j 有关的。

  注意到 \( j*t = C_{j+t}^{2} - C_{t}^{2} - C_{j}^{2} \) 。考虑 j*t 表示从 j 个里选一个、再从 t 个里选一个;表示成从 (j+t) 里选两个,再减去不合法的,即从 j 个里选了两个或从 t 个里选了两个。

  \( ans_t = \frac{1}{k}\sum\limits_{j=0}^{k}w_{k}^{-\binom{j+t}{2}+\binom{j}{2}+\binom{t}{2}} (1+a*w_{k}^{j})^{L} \)

  最后那个部分只和 j 有关。所以令 \( c_j = (1+a*w_{k}^{j})^{L} \)

     \(= \frac{ w_{k}^{\binom{t}{2}} }{k}\sum\limits_{j=0}^{k-1}w_{k}^{\binom{i}{2}} c_j * w_{k}^{-\binom{j+t}{2}} \)

  然后可以卷积。

  如果 n>1 ,用矩阵表示 “从 x 用 i 步走到 y ”的方案!仍然要乘组合数。

  也就是除了 \( c_j = ( I + A*w_{k}^{j} )^{L} [x,y] \) 之外都没变。其中 A 是输入的矩阵,[x,y] 表示取矩阵的第 x 行第 y 列的值作为 \( c_j \) 。

  给矩阵乘一个数字,是给其每个位置都乘。

  不开 long double 会变成 0 分。

  复习 MTT 。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define db long double
#define ll long long
using namespace std;
const int N=(1<<18)+5; const db pi2=acos(-1)*2;
int n,k,L,x,y,mod,G,bs,len,r[N],c[N],f[N],g[N],wn[N];
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

struct cpl{
  db x,y;
  cpl(db x=0,db y=0):x(x),y(y) {}
  cpl operator+ (const cpl &b)const{return cpl(x+b.x,y+b.y);}
  cpl operator- (const cpl &b)const{return cpl(x-b.x,y-b.y);}
  cpl operator* (const cpl &b)const
  {return cpl(x*b.x-y*b.y,x*b.y+y*b.x);}
  cpl operator/ (const db &b)const{return cpl(x/b,y/b);}
}Ta[N],Tb[N],P[N],Q[N];
cpl cnj(cpl a){return cpl(a.x,-a.y);}
struct Mtr{
  int a[3][3];
  Mtr(){memset(a,0,sizeof a);}
  Mtr operator* (const Mtr &b)const
  {
    Mtr c;
    for(int i=0;i<n;i++)
      for(int k=0;k<n;k++)
    for(int j=0;j<n;j++)
      c.a[i][j]=(c.a[i][j]+(ll)a[i][k]*b.a[k][j])%mod;
    return c;
  }
  Mtr operator* (const int &b)const
  {
    Mtr c;
    for(int i=0;i<n;i++)
      for(int j=0;j<n;j++)c.a[i][j]=(ll)a[i][j]*b%mod;
    return c;
  }
  Mtr operator+ (const Mtr &b)const
  {
    Mtr c;
    for(int i=0;i<n;i++)
      for(int j=0;j<n;j++)c.a[i][j]=upt(a[i][j]+b.a[i][j]);
    return c;
  }
}A,tA,tA2,I;
namespace get_G{
  int p[35],tot;
  void solve()
  {
    int k=mod-1;
    for(int i=2;i*i<=k;i++)
      if(k%i==0){ p[++tot]=(mod-1)/i; while(k%i==0)k/=i;}
    for(int i=2;;i++)
      {
    bool fg=0;
    for(int j=1;j<=tot;j++)
      if(pw(i,p[j])==1){fg=1;break;}
    if(!fg){G=i;break;}
      }
  }
}
void fft(cpl *a,bool fx)
{
  for(int i=0;i<len;i++)
    if(i<r[i])swap(a[i],a[r[i]]);
  for(int R=2;R<=len;R<<=1)
    {
      cpl wn=cpl(cos(pi2/R),fx?-sin(pi2/R):sin(pi2/R));
      for(int i=0,m=R>>1;i<len;i+=R)
    {
      cpl w=cpl(1,0);
      for(int j=0;j<m;j++,w=w*wn)
        {
          cpl x=a[i+j],y=w*a[i+m+j];
          a[i+j]=x+y; a[i+m+j]=x-y;
        }
    }
    }
  if(!fx)return;
  for(int i=0;i<len;i++)a[i]=a[i]/len;
}
void MTT(int *a,int *b,int n,int m,int *c)
{
  bs=sqrt(mod); cpl ta,tb,tc,td;
  for(len=1;len<=n+m;len<<=1);
  for(int i=0,j=len>>1;i<len;i++)
    r[i]=(r[i>>1]>>1)+((i&1)?j:0);
  for(int i=0;i<=n;i++) P[i]=cpl(a[i]/bs,a[i]%bs);
  for(int i=0;i<=m;i++) Q[i]=cpl(b[i]/bs,b[i]%bs);
  fft(P,0); fft(Q,0);
  P[len]=P[0]; Q[len]=Q[0];
  for(int i=0;i<len;i++)
    {
      ta=(P[i]+cnj(P[len-i]))*cpl(0.5,0);
      tb=(P[i]-cnj(P[len-i]))*cpl(0,-0.5);
      tc=(Q[i]+cnj(Q[len-i]))*cpl(0.5,0);
      td=(Q[i]-cnj(Q[len-i]))*cpl(0,-0.5);
      Ta[i]=ta*tc+ta*td*cpl(0,1);
      Tb[i]=tb*tc+tb*td*cpl(0,1);
    }
  fft(Ta,1); fft(Tb,1);
  for(int i=0,lm=n+m,bs2=bs*bs;i<=lm;i++)
    {
      int a2=(ll)(Ta[i].x+0.5)%mod;//%mod
      int b2=(ll)(Ta[i].y+0.5)%mod;
      int c2=(ll)(Tb[i].x+0.5)%mod;
      int d2=(ll)(Tb[i].y+0.5)%mod;
      c[i]=((ll)a2*bs2+(ll)(b2+c2)*bs+d2)%mod;
    }
}
int main()
{
  scanf("%d%d%d%d%d%d",&n,&k,&L,&x,&y,&mod);
  get_G::solve(); x--; y--;
  for(int i=0;i<n;i++)
    for(int j=0;j<n;j++)scanf("%d",&A.a[i][j]);
  wn[0]=1; wn[1]=pw(G,(mod-1)/k);
  for(int i=2;i<=k;i++)wn[i]=(ll)wn[i-1]*wn[1]%mod;//<=k
  for(int i=0;i<n;i++)I.a[i][i]=1;//
  for(int i=0;i<k;i++)
    {
      tA=A*wn[i]+I; tA2=I; int tp=L;//tp=L not i!!!!!
      while(tp)
    { if(tp&1)tA2=tA2*tA; tA=tA*tA; tp>>=1;}
      c[i]=tA2.a[x][y];
    }
  for(int i=0;i<k;i++)
    f[k-1-i]=(ll)wn[(ll)i*(i-1)/2%k]*c[i]%mod;
  for(int i=0,lm=2*(k-1);i<=lm;i++)
    g[i]=wn[k-(ll)i*(i-1)/2%k];
  MTT(f,g,k-1,2*(k-1),f); int inv=pw(k,mod-2);
  for(int i=0;i<k;i++)
    {
      int ans=(ll)wn[(ll)i*(i-1)/2%k]*inv%mod;
      ans=(ll)ans*f[k-1+i]%mod; printf("%d\n",ans);
    }
  return 0;
}
View Code

 

posted on 2019-05-22 22:00  Narh  阅读(178)  评论(0编辑  收藏

导航