BM 算法学习笔记
\(\text{Berlekamp-Massey}\) 算法
考虑维护这个序列 \(a\) 前缀的递推序列 \(f\)。
不妨假设当前考虑到 \(a_1,a_2,a_3\dots a_n\),当前的递推序列为 \(f_1,f_2,f_3\dots f_m\),那么接下来分为两种情况:
- \(a_n= \sum _{i=1}^m f_i a_{n-i}\),那么 \(a_n\) 是符合这一个递推序列的。
- \(a_n \neq \sum_{i=1}^m f_i a_{n-i}\),考虑记下当前的差值 \(s=a_n-\sum_{i=1}^m f_ia_{n-i}\),我们称次为发生一个冲突。
找到上一个冲突的位置 \(p\),不妨记当时的差值为 \(t\),有 \(a_p - \sum_{i=1}^{m'} f'_i a_{p-i}=t\)。
考虑如果我们能够找到一个递推序列 \(g\) 使得计算出来 \(a_n\) 的值位为 \(1\),其余为 \(0\),那么我们就可以令 \(f \leftarrow f+sg\)。
考虑这一个冲突的位置 \(p\) 有什么用,我们可以构造一个递推序列 \(F=\{0,0,0,\dots,0,\dfrac{s}{t},-\dfrac{s}{t} \times f' \}\),且 \(|F|=n-m+|f'|\),不难发现我们可以让 \(F\) 成为 \(g\),显然 \(a_n=1,a_{n-1}=a_{m-1}\dots\)。
关于最短的话感性理解一下好了。
代码如下。
#include<bits/stdc++.h>
#define fi first
#define se second
#define vc vector
#define db double
#define END exit(0)
#define pb push_back
#define mk make_pair
#define ll long long
#define PI pair<int,int>
#define ull unsigned long long
#define all(x) x.begin(), x.end()
#define err cerr << " -_- " << endl
#define debug cerr << "------------------------" << endl
//useful:
//#define cout cerr
//#define OccDreamer
#define int long long
using namespace std;
namespace IO{
inline int read(){
int X=0, W=0;char ch=getchar();
while(!isdigit(ch)) W|=ch=='-', ch=getchar();
while(isdigit(ch)) X=(X<<1)+(X<<3)+(ch^48), ch=getchar();
return W?-X:X;
}
inline void write(ll x){
if(x<0) x=-x, putchar('-');
if(x>9) write(x/10);
putchar(x%10+'0');
}
inline void sprint(ll x){write(x);putchar(32);}
inline void eprint(ll x){write(x);putchar(10);}
}using namespace IO;
const int MAXN = 5003;
const int mod = 998244353;
int n, m, tot[MAXN<<1];
int P[MAXN<<1], f[MAXN<<1][MAXN], a[MAXN<<4], dp[MAXN<<4], las, t;
int p[MAXN<<4], q[MAXN<<4], r[MAXN<<4];
inline ll Quickpow(ll x, ll y){
ll z=1;
while(y){
if(y&1) z=z*x%mod;
x=x*x%mod; y>>=1;
}
return z;
}
inline void NTT(int *x, int limit, int op){
ll g, w, A, B;
for(int i=0;i<limit;++i) if(i<dp[i]) swap(x[i],x[dp[i]]);
for(int L=1;L<limit;L<<=1){
g=Quickpow(114514,(mod-1)/(L<<1));
if(op==-1) g=Quickpow(g,mod-2);
for(int R=L<<1, pos=0;pos<limit;pos+=R){
w=1;
for(int k=0;k<L;++k, w=w*g%mod){
A=x[pos+k], B=1ll*x[pos+k+L]*w%mod;
x[pos+k]=(A+B)%mod; x[pos+k+L]=(A-B+mod)%mod;
}
}
}
if(op==-1){
ll inv=Quickpow(limit,mod-2);
for(int i=0;i<limit;++i) x[i]=1ll*x[i]*inv%mod;
}
return ;
}
inline int Solve(int k, int s){
n=s;
int limit=1; while(limit<=2*n) limit<<=1;
for(int i=1;i<limit;++i) dp[i]=dp[i-(i&-i)]+limit/((i&-i)<<1);
for(int i=0;i<n;++i) p[i]=P[i+1];
for(int i=1;i<=n;++i) q[i]=mod-a[i]; q[0]++;
for(int i=0;i<=n;++i) r[i]=q[i];
NTT(r,limit,1); NTT(p,limit,1);
for(int i=0;i<limit;++i) p[i]=1ll*p[i]*r[i]%mod;
NTT(p,limit,-1);
for(int i=n;i<limit;++i) p[i]=0;
while(k){
for(int i=0;i<=n;++i) r[i]=(i&1)?(mod-q[i]):q[i];
for(int i=n+1;i<limit;++i) r[i]=0;
NTT(p,limit,1); NTT(q,limit,1); NTT(r,limit,1);
for(int i=0;i<limit;++i) p[i]=1ll*p[i]*r[i]%mod, q[i]=1ll*q[i]*r[i]%mod;
NTT(p,limit,-1); NTT(q,limit,-1);
for(int i=0;i<=n;++i) q[i]=q[i<<1];
for(int i=0;i<=n;++i) p[i]=p[i<<1|(k&1)];
for(int i=n+1;i<limit;++i) p[i]=q[i]=0;
k>>=1;
}
return 1ll*p[0]*Quickpow(q[0],mod-2)%mod;
}
inline void BM(){
las=0; t=0; int res=0, inv;
for(int i=1;i<=n;++i){
res=0;
for(int j=1;j<=tot[i-1];++j)
res+=1ll*f[i-1][j]*P[i-j]%mod, res%=mod;
res=mod+P[i]-res; res%=mod;
if(res==0){
tot[i]=tot[i-1];
for(int j=1;j<=tot[i];++j) f[i][j]=f[i-1][j];
continue;
}
if(!las){
tot[i]=i;
las=i; t=res;
continue;
}
inv=Quickpow(t,mod-2)*res%mod;
tot[i]=i-las+tot[las-1]; f[i][i-las]=inv;
for(int j=i-las+1;j<=tot[i];++j) f[i][j]=1ll*(mod-inv)*f[las-1][j-i+las]%mod;
for(int j=1;j<=tot[i];++j) f[i][j]+=f[i-1][j], f[i][j]%=mod;
las=i, t=res;
}
cerr << tot[n] << endl;
for(int i=1;i<=tot[n];++i) sprint(a[i]=f[n][i]);
putchar(10); eprint(Solve(m,tot[n]));
return ;
}
signed main(){
n=read(), m=read();
for(int i=1;i<=n;++i) P[i]=read(); BM();
return 0;
}

浙公网安备 33010602011771号