BM 算法学习笔记

\(\text{Berlekamp-Massey}\) 算法

考虑维护这个序列 \(a\) 前缀的递推序列 \(f\)

不妨假设当前考虑到 \(a_1,a_2,a_3\dots a_n\),当前的递推序列为 \(f_1,f_2,f_3\dots f_m\),那么接下来分为两种情况:

  • \(a_n= \sum _{i=1}^m f_i a_{n-i}\),那么 \(a_n\) 是符合这一个递推序列的。
  • \(a_n \neq \sum_{i=1}^m f_i a_{n-i}\),考虑记下当前的差值 \(s=a_n-\sum_{i=1}^m f_ia_{n-i}\),我们称次为发生一个冲突。

找到上一个冲突的位置 \(p\),不妨记当时的差值为 \(t\),有 \(a_p - \sum_{i=1}^{m'} f'_i a_{p-i}=t\)

考虑如果我们能够找到一个递推序列 \(g\) 使得计算出来 \(a_n\) 的值位为 \(1\),其余为 \(0\),那么我们就可以令 \(f \leftarrow f+sg\)

考虑这一个冲突的位置 \(p\) 有什么用,我们可以构造一个递推序列 \(F=\{0,0,0,\dots,0,\dfrac{s}{t},-\dfrac{s}{t} \times f' \}\),且 \(|F|=n-m+|f'|\),不难发现我们可以让 \(F\) 成为 \(g\),显然 \(a_n=1,a_{n-1}=a_{m-1}\dots\)

关于最短的话感性理解一下好了。

代码如下。

#include<bits/stdc++.h>

#define fi first
#define se second
#define vc vector
#define db double
#define END exit(0)
#define pb push_back
#define mk make_pair
#define ll long long
#define PI pair<int,int>
#define ull unsigned long long
#define all(x) x.begin(), x.end()
#define err cerr << " -_- " << endl
#define debug cerr << "------------------------" << endl

//useful:
//#define cout cerr
//#define OccDreamer
#define int long long

using namespace std;

namespace IO{
	inline int read(){
		int X=0, W=0;char ch=getchar();
		while(!isdigit(ch)) W|=ch=='-', ch=getchar();
		while(isdigit(ch)) X=(X<<1)+(X<<3)+(ch^48), ch=getchar();
		return W?-X:X;
	}
	inline void write(ll x){
		if(x<0) x=-x, putchar('-');
		if(x>9) write(x/10);
		putchar(x%10+'0');
	}
	inline void sprint(ll x){write(x);putchar(32);}
	inline void eprint(ll x){write(x);putchar(10);}
}using namespace IO;

const int MAXN = 5003;
const int mod = 998244353;

int n, m, tot[MAXN<<1];
int P[MAXN<<1], f[MAXN<<1][MAXN], a[MAXN<<4], dp[MAXN<<4], las, t;
int p[MAXN<<4], q[MAXN<<4], r[MAXN<<4];

inline ll Quickpow(ll x, ll y){
	ll z=1;
	while(y){
		if(y&1) z=z*x%mod;
		x=x*x%mod; y>>=1;
	}
	return z;
}

inline void NTT(int *x, int limit, int op){
	ll g, w, A, B;
	for(int i=0;i<limit;++i) if(i<dp[i]) swap(x[i],x[dp[i]]);
	for(int L=1;L<limit;L<<=1){
		g=Quickpow(114514,(mod-1)/(L<<1));
		if(op==-1) g=Quickpow(g,mod-2);
		for(int R=L<<1, pos=0;pos<limit;pos+=R){
			w=1;
			for(int k=0;k<L;++k, w=w*g%mod){
				A=x[pos+k], B=1ll*x[pos+k+L]*w%mod;
				x[pos+k]=(A+B)%mod; x[pos+k+L]=(A-B+mod)%mod;
			}	
		} 
	}
	if(op==-1){
		ll inv=Quickpow(limit,mod-2);
		for(int i=0;i<limit;++i) x[i]=1ll*x[i]*inv%mod;
	}
	return ;
}

inline int Solve(int k, int s){
	n=s;
	int limit=1; while(limit<=2*n) limit<<=1;
	for(int i=1;i<limit;++i) dp[i]=dp[i-(i&-i)]+limit/((i&-i)<<1);
	for(int i=0;i<n;++i) p[i]=P[i+1];
	for(int i=1;i<=n;++i) q[i]=mod-a[i]; q[0]++;
	for(int i=0;i<=n;++i) r[i]=q[i];
	NTT(r,limit,1); NTT(p,limit,1);
	for(int i=0;i<limit;++i) p[i]=1ll*p[i]*r[i]%mod;
	NTT(p,limit,-1);
	for(int i=n;i<limit;++i) p[i]=0;
	while(k){
		for(int i=0;i<=n;++i) r[i]=(i&1)?(mod-q[i]):q[i];
		for(int i=n+1;i<limit;++i) r[i]=0;
		NTT(p,limit,1); NTT(q,limit,1); NTT(r,limit,1);
		for(int i=0;i<limit;++i) p[i]=1ll*p[i]*r[i]%mod, q[i]=1ll*q[i]*r[i]%mod;
		NTT(p,limit,-1); NTT(q,limit,-1);
		for(int i=0;i<=n;++i) q[i]=q[i<<1];
		for(int i=0;i<=n;++i) p[i]=p[i<<1|(k&1)];
		for(int i=n+1;i<limit;++i) p[i]=q[i]=0;
		k>>=1;
	}
	return 1ll*p[0]*Quickpow(q[0],mod-2)%mod;
}

inline void BM(){
	las=0; t=0; int res=0, inv;
	for(int i=1;i<=n;++i){
		res=0;
		for(int j=1;j<=tot[i-1];++j)
			res+=1ll*f[i-1][j]*P[i-j]%mod, res%=mod;
		res=mod+P[i]-res; res%=mod;
		if(res==0){
			tot[i]=tot[i-1];
			for(int j=1;j<=tot[i];++j) f[i][j]=f[i-1][j];
			continue;
		}
		if(!las){
			tot[i]=i;
			las=i; t=res;
			continue;
		}
		inv=Quickpow(t,mod-2)*res%mod;
		tot[i]=i-las+tot[las-1]; f[i][i-las]=inv;
		for(int j=i-las+1;j<=tot[i];++j) f[i][j]=1ll*(mod-inv)*f[las-1][j-i+las]%mod;
		for(int j=1;j<=tot[i];++j) f[i][j]+=f[i-1][j], f[i][j]%=mod;
		las=i, t=res;
	}
	cerr << tot[n] << endl;
	for(int i=1;i<=tot[n];++i) sprint(a[i]=f[n][i]);
	putchar(10); eprint(Solve(m,tot[n]));
	return ;
}

signed main(){
	n=read(), m=read();
	for(int i=1;i<=n;++i) P[i]=read(); BM();
	return 0;
}
posted @ 2023-04-05 15:48  OccDreamer  阅读(102)  评论(0)    收藏  举报