LG5487 【模板】线性递推+BM算法

【模板】线性递推+BM算法

给出一个数列 \(P\)\(0\) 开始的前 \(n\) 项,求序列 \(P\)\(\bmod~998244353\) 下的最短线性递推式,并在 \(\bmod~ 998244353\) 下输出 \(P_m\)

\(m\leq 10^9,1\leq n\leq 10000\),保证递推式最长不超过 \(5000\)

Berlekamp-Massey 算法

Berlekamp-Massey 算法,常简称为 BM 算法,是用来求解一个数列的最短线性递推式的算法。

BM 算法可以在 \(O(n^2)\) 的时间内求解一个长度为 \(n\) 的数列的最短线性递推式。

基本定义

对于数列 \(A=\{a_1,a_2,\dots,a_n\}\),我们定义数列 \(R=\{r_1,r_2,\dots,r_m\}\) 为其线性递推式当且仅当

\[\forall i\in [m+1,n],~a_i=\sum_{j=1}^mr_ja_{i-j} \]

注意,无论你习惯从左到右写数列还是从右到左,这里的数列和线性递推式的位置对应关系是反着的。

所有可能的线性递推式 \(R\) 中长度 \(m\) 最小的叫做 \(A\)最短线性递推式

算法流程

假设我们已经求得了 \(\{a_1,a_2,\dots,a_{i-1}\}\) 的最短线性递推式 \(\{r_1,r_2,\dots,r_m\}\),那么如何求得 \(\{a_1,a_2,\dots,a_i\}\) 的最短线性递推式?

定义 \(\{a_1,a_2,\dots,a_{i-1}\}\) 的最短线性递推式 \(\{r_1,r_2,\dots,r_m\}\)当前递推式,记递推式被更改的次数为 \(cnt\),第 \(i\) 次更改后的递推式为 \(R_i\),那么当前递推式应当为 \(R_{cnt}\)。特别地,定义 \(R_0=\varnothing\)

我们对每个版本的 \(R\),记一个表示差异量的数组 \(\Delta_i\),满足

\[\Delta_i=a_i-\sum_{j=1}^mr_ja_{i-j} \]

显然若 \(\Delta_i=0\),那么当前递推式就是 \(\{a_1,a_2,\dots,a_i\}\) 的最短线性递推式。

否则我们认为 \(R_{cnt}\)\(a_i\) 处出错了,令 \(fail_i\)\(R_i\) 最早的出错位置,则有 \(fail_{cnt}=i\)。考虑对 \(R_{cnt}\) 进行修改,使其变为 \(R_{cnt+1}\),并在 \(a_i\) 处同样成立。

若当前 \(cnt=0\),说明 \(a_i\) 是第一个非零元素,直接将 \(R_1\) 置为 \(\{ \underbrace{0,0,\dots,0}_{i} \}\) 即可,因为不可能用之前的数递推出 \(a_i\)

否,即 \(cnt>0\),则考虑用 \(R_{cnt}\) 之前失败的递推式将这个 \(\Delta_i\) 加回去(\(a=\sum+\Delta\))。我们希望得到数列 \(R'=\{r'_1,r'_2,...,r'_{m'}\}\),使得

  1. \[\forall k\in [m'+1,i-1],~\sum_{j=1}^{m'}r'_ja_{k-j}=0 \]

  2. \[\sum_{j=1}^{m'}r'_ja_{i-j}=\Delta_i \]

如果能够找到这样的数列 \(R'\),那么令\(R_{cnt+1}=R_{cnt}+R'\)即可。这里加号表示各位对应相加。

在之前失败的递推式中任选一个 \(R_p\),尝试在它的基础上修改,在 \(i\) 的位置上构造出 \(\Delta_{fail_p}\)(这里的 \(\Delta\) 是对应 \(R_p\) 版本的),记得到的结果为 \(R_p'\),那么

\[R'=\frac{\Delta_i}{\Delta_{fail_p}}R'_p \]

考虑如何构造 \(R'_p\) 。将 \(R_p\) 的元素全部变成它的相反数,再在前面补上一个 \(1\) , \(\Delta_{fail_p}\) 就到 \(fail_p+1\) 位置上来了。

\[a_{fail_p}-\sum_{i=1}^{m_p}R_{p,i}a_{fail_p-i}=\Delta_{p,fail_p} \]

前面再补上 \(i-fail_p-1\)\(0\)\(\Delta_{fail_p}\) 就到 \(i\) 位置上来了。于是

\[R'_p=\{\underbrace{0,0,\dots,0}_{i-fail_p-1},1\}(-R_p) \]

这里乘号表示顺次连接。

又因为 \(R_p\)\(fail_p\) 前的 \(\Delta=0\),所以我们构造出来的 \(R'\) 是满足第一条约束的。

为了保证得到的递推式长度最短,我们需要选取恰当的 \(R_p\)。容易看出,得到的 \(R_{cnt+1}\) 的长度为 \(\max(i-fail_p+m_p,m)\)。于是记录 \(m_p-fail_p\) 最短的递推式作为 \(R_p\)

至此我们完成了 BM 算法的理论部分,在最坏情况下,我们可能需要对数列进行 \(O(n)\) 次修改,因此该算法的时间复杂度为 \(O(n^2)\)

经验之谈

用 BM 得到的最短递推式长度最好要明显小于 \(n\) 的一半,否则需要再打些表。

为什么?因为若长度为 \(\frac n 2\),可以看做 \(\frac n 2\) 个变量列出 \(\frac n 2\) 个方程,总能找到解。所以一个随机数列解出的最短递推式长度就是 \(\frac n 2\) 左右。发生了这样的情况说明原数列很可能并没有一定的规律,即递推式大概率对之后的数据不适用。

另外因为计算中涉及除法,所以 BM 在实数域内求解可能有一定的精度误差。

namespace linear{
	typedef vector<int> polynomial;
	void num_trans(polynomial&a,int dir){
	    int lim=a.size();
	    static vector<int> rev,w[2];
	    if(rev.size()!=lim){
	        rev.resize(lim);
	        int len=log2(lim);
	        for(int i=0;i<lim;++i) rev[i]=rev[i>>1]>>1|(i&1)<<(len-1);
	        for(int dir=0;dir<2;++dir){
	            static co int g[2]={3,332748118};
	            w[dir].resize(lim);
	            w[dir][0]=1,w[dir][1]=fpow(g[dir],(mod-1)/lim);
	            for(int i=2;i<lim;++i) w[dir][i]=mul(w[dir][i-1],w[dir][1]);
	        }
	    }
	    for(int i=0;i<lim;++i)if(i<rev[i]) swap(a[i],a[rev[i]]);
	    for(int step=1;step<lim;step<<=1){
	        int quot=lim/(step<<1);
	        for(int i=0;i<lim;i+=step<<1){
	            int j=i+step;
	            for(int k=0;k<step;++k){
	                int t=mul(w[dir][quot*k],a[j+k]);
	                a[j+k]=add(a[i+k],mod-t),a[i+k]=add(a[i+k],t);
	            }
	        }
	    }
	    if(dir){
	        int ilim=fpow(lim,mod-2);
	        for(int i=0;i<lim;++i) a[i]=mul(a[i],ilim);
	    }
	}
	polynomial poly_inv(polynomial a,int n){
	    polynomial b(1,fpow(a[0],mod-2));
	    if(n==1) return b;
	    int lim=2;
	    for(;lim<n;lim<<=1){
	        polynomial a1(a.begin(),a.begin()+lim);
	        a1.resize(lim<<1),num_trans(a1,0);
	        b.resize(lim<<1),num_trans(b,0);
	        for(int i=0;i<lim<<1;++i) b[i]=mul(add(2,mod-mul(a1[i],b[i])),b[i]);
	        num_trans(b,1),b.resize(lim);
	    }
	    a.resize(lim<<1),num_trans(a,0);
	    b.resize(lim<<1),num_trans(b,0);
	    for(int i=0;i<lim<<1;++i) b[i]=mul(add(2,mod-mul(a[i],b[i])),b[i]);
	    num_trans(b,1),b.resize(n);
	    return b;
	}
	polynomial operator/(polynomial f,polynomial g){
	    int n=f.size()-1,m=g.size()-1;
	    reverse(g.begin(),g.end()),g.resize(n-m+1),g=poly_inv(g,n-m+1);
	    reverse(f.begin(),f.end()),f.resize(n-m+1);
	    int lim=1<<int(ceil(log2((n-m)<<1|1)));
	    f.resize(lim),num_trans(f,0);
	    g.resize(lim),num_trans(g,0);
	    for(int i=0;i<lim;++i) f[i]=mul(f[i],g[i]);
	    num_trans(f,1),f.resize(n-m+1);
	    return reverse(f.begin(),f.end()),f;
	}
	polynomial operator%(polynomial f,polynomial g){
	    int n=f.size()-1,m=g.size()-1;
	    polynomial q=f/g;
	    int lim=1<<int(ceil(log2(n+1)));
	    q.resize(lim),num_trans(q,0);
	    g.resize(lim),num_trans(g,0);
	    for(int i=0;i<lim;++i) q[i]=mul(q[i],g[i]);
	    num_trans(q,1);
	    for(int i=0;i<m;++i) f[i]=add(f[i],mod-q[i]);
	    return f.resize(m),f;
	}
	
	int n,k;
	void mul_mod(polynomial&a,polynomial b,co polynomial&p){
	    static co int lim=1<<int(ceil(log2(2*k-1)));
	    a.resize(lim),b.resize(lim);
	    num_trans(a,0),num_trans(b,0);
	    for(int i=0;i<lim;++i) a[i]=mul(a[i],b[i]);
	    num_trans(a,1),a.resize(2*k-1);
	    a=a%p;
	}
	void main(int _n,int _k,co vector<int>&_a,co vector<int>&_f){
		n=_n,k=_k;
	    polynomial a(k),f(k);
	    for(int i=1;i<=k;++i) a[k-i]=mod-_a[i];
	    a.push_back(1);
	    for(int i=0;i<k;++i) f[i]=_f[i];
	    polynomial rmd(1,1),tmp(2);tmp[1]=1;
	    for(;n;n>>=1,mul_mod(tmp,tmp,a))
	        if(n&1) mul_mod(rmd,tmp,a);
	    int ans=0;
	    for(int i=0;i<k;++i) ans=add(ans,mul(rmd[i],f[i]));
	    printf("%d\n",ans);
	}
}

vector<int> ber_ma(vector<int> f){
	vector<int> lst,cur;
	int lsfa,lsdel;
	for(int i=0;i<(int)f.size();++i){
		int del=f[i];
		for(int j=1;j<(int)cur.size();++j)
			del=add(del,mod-mul(cur[j],f[i-j]));
		if(!del) continue;
		if(!cur.size()){
			cur.resize(i+1),lsfa=i,lsdel=del;
			continue;
		}
		int alph=mul(del,fpow(lsdel,mod-2));
		vector<int> nw(i-lsfa);
		nw.push_back(alph);
		for(int j=1;j<(int)lst.size();++j)
			nw.push_back(mul(alph,mod-lst[j]));
		if(nw.size()<cur.size()) nw.resize(cur.size());
		for(int j=1;j<(int)cur.size();++j)
			nw[j]=add(nw[j],cur[j]);
		if(i-lsfa+(int)lst.size()>=(int)cur.size())
			lst=cur,lsfa=i,lsdel=del;
		cur=nw;
	}
	return cur;
}

int main(){
	int n=read<int>(),m=read<int>();
	vector<int> f(n);
	for(int i=0;i<n;++i) read(f[i]);
	vector<int> a=ber_ma(f);
	for(int i=1;i<(int)a.size();++i) printf("%d ",a[i]);
	puts("");
	if(m<=n) {printf("%d\n",f[m]);return 0;}
	linear::main(m,a.size()-1,a,f);
	return 0;
}

线性递推式是 base 1 的,用 vector 存的话代码有点奇怪。

posted on 2019-08-29 16:53  autoint  阅读(694)  评论(0编辑  收藏  举报

导航