# 题解

$$\begin{aligned}f[i][j] &= \sum_{k=0}^{j-1} f[i-1][k]\cdot f[i][j-k-1]\\f[i][0] &= 1\\f[0][i] &= 0\quad(i>0)\\f[1][i] &= 1\end{aligned}$$

$$\begin{aligned}f_i(x) &= x f_{i-1}(x)f_i(x) + 1\\\Rightarrow\quad f_i(x) &= \frac{1}{1-xf_{i-1}(x)}\end{aligned}$$

设 $f_i(x)=\frac{A_i(x)}{B_i(x)}$，代入上式得

$$\begin{aligned}\frac{A_i(x)}{B_i(x)} &= \frac 1 {1-x\frac{A_{i-1}(x)}{B_{i-1}(x)}}\\&=\frac{B_{i-1}(x)}{B_{i-1}(x) - xA_{i-1}(x)}\end{aligned}$$

$$A_i(x) = B_{i-1}(x)$$

$$B_i(x) = B_{i-1}(x) - xA_{i-1}(x)$$

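$$\begin{bmatrix}0 & 1\\-x & 1 \end{bmatrix}\begin{pmatrix}A_{i-1}(x)\\B_{i-1}(x)\end{pmatrix}=\begin{pmatrix}A_i(x)\\B_i(x)\end{pmatrix}$$

$$\begin{pmatrix}A_0(x)\\B_0(x)\end{pmatrix}=\begin{pmatrix}1\\1\end{pmatrix}\quad(f_0(x)=1),\qquad f_m(x)=\frac{A_m(x)}{B_m(x)}$$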

# 代码

#pragma GCC optimize("Ofast","inline")
#include <bits/stdc++.h>
// Zero-fill a fixed-size array (memset over sizeof(x) bytes).
#define clr(x) memset(x,0,sizeof (x))
// Inclusive ascending / descending integer loops.
#define For(i,a,b) for (int i=a;i<=b;i++)
#define Fod(i,b,a) for (int i=b;i>=a;i--)
// Short aliases for common STL members/functions.
#define pb push_back
#define mp make_pair
#define fi first
#define se second
// Constant sum of character codes; not referenced in the code visible here.
// NOTE(review): names beginning with '_' + uppercase letter are reserved in C++.
#define _SEED_ ('C'+'L'+'Y'+'A'+'K'+'I'+'O'+'I')
// Debug printers. NOTE: outvec/outarr expand to several statements with no
// do { } while (0) wrapper, so they are unsafe as the body of an unbraced if/for.
#define outval(x) printf(#x" = %d\n",x)
#define outvec(x) printf("vec "#x" = ");for (auto _v : x)printf("%d ",_v);puts("")
#define outtag(x) puts("----------"#x"----------")
#define outarr(a,L,R) printf(#a"[%d...%d] = ",L,R);\
For(_v2,L,R)printf("%d ",a[_v2]);puts("");
using namespace std;
// 64-bit type used for intermediate products before reducing modulo mod.
typedef long long LL;
// Read a decimal integer from stdin.
// Non-digit characters before the first digit are skipped; if any of them is
// '-', the result is negated. Returns 0 if input ends before any digit.
LL read(){
	LL x=0;
	bool neg=false;
	// int, not char: keeps EOF distinguishable and gives isdigit() a valid argument.
	int ch=getchar();
	while (ch!=EOF&&!isdigit(ch)){
		if (ch=='-')
			neg=true;
		ch=getchar();
	}
	while (ch!=EOF&&isdigit(ch)){
		x=x*10+(ch-'0');
		ch=getchar();
	}
	return neg?-x:x;
}
// N: capacity of the transform / scratch arrays; mod: NTT-friendly prime.
const int N=1<<20,mod=998244353;
// x = (x + y) mod `mod`, for x, y already in [0, mod).
void Add(int &x,int y){
	if ((x+=y)>=mod)
		x-=mod;
}
// x = (x - y) mod `mod`, for x, y already in [0, mod).
void Del(int &x,int y){
	if ((x-=y)<0)
		x+=mod;
}
// Modular power x^y mod `mod` by square-and-multiply.
// Intended for 0 <= x < mod and non-negative y.
int Pow(int x,int y){
	int result=1;
	int base=x;
	while (y!=0){
		if (y&1)
			result=(LL)result*base%mod;
		base=(LL)base*base%mod;
		y>>=1;
	}
	return result;
}
namespace poly{
	// Twiddle table (w[i] = root^i for the current length) and bit-reversal table.
	int w[N],R[N];
	// Scratch buffers for Mul().
	int a[N],b[N];
	// Prepare a length-n NTT (n = 2^d): fill the bit-reversal permutation R and
	// the forward twiddles w[i] = 3^(i*(mod-1)/n) mod `mod`.
	void prework(int n,int d){
		For(i,0,n-1)
			// For n == 1 (d == 0) there are no bits to reverse; shifting by
			// d-1 = -1 would be undefined behaviour.
			R[i]=d?((R[i>>1]>>1)|((i&1)<<(d-1))):0;
		w[0]=1,w[1]=Pow(3,(mod-1)/n);
		For(i,2,n-1)
			w[i]=(LL)w[i-1]*w[1]%mod;
	}
	// In-place iterative number-theoretic transform of length n, using the
	// twiddles currently in w[] (forward after prework(), inverse once w[] has
	// been rebuilt from w[1]^-1). No 1/n scaling is applied here.
	void FFT(int a[],int n){
		For(i,0,n-1)
			if (i<R[i])
				swap(a[i],a[R[i]]);
		// d: half-length of the current butterfly; w[t*j] is the (2d)-th root to the j.
		for (int t=n>>1,d=1;d<n;d<<=1,t>>=1)
			for (int i=0;i<n;i+=d<<1)
				for (int j=0;j<d;j++){
					int tmp=(LL)w[t*j]*a[i+j+d]%mod;
					a[i+j+d]=(a[i+j]-tmp+mod)%mod;
					a[i+j]=(a[i+j]+tmp)%mod;
				}
	}
	// Product of two coefficient vectors modulo mod. Trailing zero coefficients
	// are trimmed (a zero product yields an empty vector).
	// Overwrites w[], R[], a[] and b[]; rerun prework() before using FFT again.
	vector <int> Mul(const vector <int> &A,const vector <int> &B){
		int n,d;
		for (n=1,d=0;n<(int)(A.size()+B.size());n<<=1,d++);
		prework(n,d);
		For(i,0,n-1)
			a[i]=b[i]=0;
		For(i,0,(int)A.size()-1)
			a[i]=A[i];
		For(i,0,(int)B.size()-1)
			b[i]=B[i];
		FFT(a,n),FFT(b,n);
		For(i,0,n-1)
			a[i]=(LL)a[i]*b[i]%mod;
		// Inverse transform: rebuild w[] from the inverse root, then scale by 1/n.
		w[1]=Pow(w[1],mod-2);
		For(i,2,n-1)
			w[i]=(LL)w[i-1]*w[1]%mod;
		FFT(a,n);
		int inv=Pow(n,mod-2);
		vector <int> ans;
		ans.reserve(n);
		For(i,0,n-1)
			ans.pb((LL)a[i]*inv%mod);
		while (!ans.empty()&&!ans.back())
			ans.pop_back();
		return ans;
	}
	// Power-series inverse: returns the first n+1 coefficients of 1/a(x).
	// Requires a non-empty with a[0] invertible modulo mod. Newton step
	// B <- 2B - a*B^2, keeping d+1 coefficients at each stage.
	vector <int> Get_Inv(const vector <int> &a,int n){
		vector <int> A,B,tmp;
		B.pb(Pow(a[0],mod-2));
		for (int d=1;d<=n*2;d<<=1){
			// Extend A to d+1 coefficients: take a's next coefficient while it
			// has one, otherwise pad with 0.
			while ((int)A.size()<=d)
				A.pb(A.size()<a.size()?a[A.size()]:0);
			tmp=Mul(A,Mul(B,B));
			tmp.resize(d+1,0);
			B.resize(d+1,0);
			For(i,0,d)
				B[i]=(2LL*B[i]-tmp[i]+mod)%mod;
		}
		B.resize(n+1,0);
		return B;
	}
}
using poly::FFT;
using poly::Mul;
using poly::Get_Inv;
struct Mat{
int v[2][2];
Mat(){}
Mat(int x){
clr(v);
For(i,0,1)
v[i][i]=x;
}
Mat(int v00,int v01,int v10,int v11){
v[0][0]=v00,v[0][1]=v01;
v[1][0]=v10,v[1][1]=v11;
}
friend Mat operator * (Mat A,Mat B){
Mat ans(0);
For(i,0,1)
For(j,0,1)
For(k,0,1)
ans.v[i][j]=((LL)A.v[i][k]*B.v[k][j]+ans.v[i][j])%mod;
return ans;
}
};
// Matrix power x^y by repeated squaring; y is expected to be non-negative.
Mat Pow(Mat x,int y){
	Mat result(1);
	while (y!=0){
		if (y&1)
			result=result*x;
		x=x*x;
		y>>=1;
	}
	return result;
}
// Problem parameters; solve() takes its own n, m arguments that shadow these.
int n,m;
// Transform length (a power of two) and its log2, set by solve().
int k,d;
// Values of A_m and B_m at the k-th roots of unity, turned into (k-scaled)
// coefficients by the inverse transform in solve().
int a[N],b[N];
// Coefficient vectors of A_m(x) and B_m(x), truncated to degree n in solve().
vector <int> A,B;
int solve(int n,int m){
if (n<m)
return 0;
for (k=1,d=0;k<=n;k<<=1,d++);
poly::prework(k,d);
For(i,0,k-1){
Mat res=Pow(Mat(0,1,(mod-poly::w[i])%mod,1),m)*Mat(1,0,1,0);
a[i]=res.v[0][0],b[i]=res.v[1][0];
}
poly::w[1]=Pow(poly::w[1],mod-2);
For(i,2,k-1)
poly::w[i]=(LL)poly::w[i-1]*poly::w[1]%mod;
FFT(a,k),FFT(b,k);
A.clear(),B.clear();
For(i,0,n)
A.pb(a[i]),B.pb(b[i]);
A=Mul(A,Get_Inv(B,n));
A.resize(n+1,0);
return A[n];
}
int main(){
cout<<solve(n,m)<<endl;
return 0;
}


posted @ 2019-03-12 19:47 -zhouzhendong- 阅读(...) 评论(...) 编辑 收藏