【学习笔记】BM(Berlekamp-Massey)算法

就是那个流氓一样的构造递推式的算法= =

所以做一些恐怖多项式题的时候可以试试 BM+肉眼观察x

BM算法通过增量法构造 每次考虑拟合下一个点

设我们现在的递推式为\(R_c\) 初始时\(R_0\)为空

考虑在数列末尾加入\(a_i\) 当前\(|R_c|=m\)

\(\Delta_i = a_i - \sum_{j=1}^m a_{i-j} R_{c,j}\)

如果\(\Delta_i=0\)则直接拟合到这个点了 直接看下一个点

否则,设当前\(Fail_c\)的位置为\(i\)\(Fail_c=i\)

考虑补充一个\(R_{\Delta}\)使得\(R_{c+1}=R_c+R_\Delta\)可以拟合到当前的点

\(R_\Delta\)需要满足

  1. \(\forall |R_\Delta|+1\le k <i, \sum_{j=1}^{|R_\Delta|} a_{k-j}R_{\Delta j} = 0\)
  2. \(\sum_{j=1}^{|R_{\Delta}|} a_{i-j} R_{\Delta j}=\Delta_i\)

我们再看一下\(R_{c-1}\)都满足了哪些条件

​ 1.\(\forall |R_{c-1}+1|\le k <Fail_{c-1}, a_k -\sum_{j=1}^{|R_{c-1}|} a_{k-j}R_{c-1, j} = 0\)

​ 2.\(a_{Fail_{c-1}} -\sum_{j=1}^{R_{c-1}} a_{Fail_{c-1}-j} R_{c-1,j}=\Delta_{Fail_{c-1}}\)

发现这两个柿子其实很像

我们给出如下构造

\(t=\frac{\Delta_i}{\Delta_{Fail_{c-1}}}\)

\(R_\Delta = \{0,0,\dots,0,t,-t\cdot R_{c-1,1} , -t\cdot R_{c-1,2},\dots\}\)

其中开头是\(i-Fail_{c-1} -1\)个0,\(t\)后面跟的是对应的\(|R_{c-1}|\)个值

考虑证明正确性

  1. \(\forall |R_\Delta|+1\le k <i\) 贡献是\(t\cdot (a_k -\sum_{j=1}^{|R_{c-1}|} a_{k-j}R_{c-1, j})=t\cdot 0 =0\) 细节和底下的2类似 可以先看下面的x
  2. \(t\)是第\(i-Fail_{c-1}\)项 对于第\(i\)项的贡献是 \(t\cdot a_{Fail_{c-1}}\) 考虑后面的\(|R_{c-1}|\)个值 对于答案的贡献是\(\sum_{j=1}^{|R_{c-1}|} -t \cdot R_{c-1,j} a_{Fail_{c-1}-j}\) 总贡献就是\(t \cdot (a_{Fail_{c-1}} -\sum_{j=1}^{R_{c-1}} a_{Fail_{c-1}-j} R_{c-1,j})=t\cdot \Delta_{Fail_{c-1}}=\Delta_i\) (转化见上面满足的条件)

所以\(R_\Delta\)符合要求 至于求最短的递推式呢 我们再额外枚举一个\(id\) 找到\(i-Fail_{id}+|R_{id}|\)最短的就可以了

真的流氓算法x

代码容我咕咕一下(。
昨天上午学完这个算法就病倒了...我有理由怀疑是算法的问题x
代码是洛谷的板子w

//Love and Freedom.
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<vector>
#define ll long long
#define inf 20021225
#define mdn 998244353
#define N 30100
#define G 3
using namespace std;
int read()
{
	int s=0,t=1; char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')	t=-1; ch=getchar();}
	while(ch>='0' && ch<='9')	s=s*10+ch-'0',ch=getchar();
	return s*t;
}
void upd(int &x,int y){x+=x+y>=mdn?y-mdn:y;}
int ksm(int bs,int mi)
{
	int ans=1;
	while(mi)
	{
		if(mi&1)	ans=1ll*ans*bs%mdn;
		bs=1ll*bs*bs%mdn,mi>>=1;
	}
	return ans;
}
int inv,r[N*4],n,k;
void ntt(int *a,int lim,int l,int f)
{
	for(int i=0;i<lim;i++)	
		r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	inv=ksm(lim,mdn-2);
	for(int i=0;i<lim;i++)	if(r[i]>i)
		swap(a[r[i]],a[i]);
	for(int k=2,mid=1;k<=lim;k<<=1,mid<<=1)
	{
		int Wn=ksm(G,(mdn-1)/k);
		if(f)	Wn=ksm(Wn,mdn-2);
		for(int i=0,w=1;i<lim;i+=k,w=1)
			for(int j=0;j<mid;j++,w=1ll*w*Wn%mdn)
			{
				int x=a[i+j],y=1ll*w*a[i+j+mid]%mdn;
				a[i+j]=(x+y)%mdn; a[i+j+mid]=(x-y+mdn)%mdn;
			}
	}
	if(f)	for(int i=0;i<lim;i++)	a[i]=1ll*a[i]*inv%mdn;
}
int f[N*4],g[N*4],h[N*4];
void poly_inv(int *a,int *g,int n)
{
	if(n==1){g[0]=ksm(a[0],mdn-2); return;}
	int mid=(n+1)>>1; poly_inv(a,g,mid);
	int lim=1,l=0;
	while(lim<(n<<1))	lim<<=1,l++;
	for(int i=0;i<n;i++)	h[i]=a[i];
	for(int i=n;i<lim;i++)	h[i]=0;
	ntt(h,lim,l,0); ntt(g,lim,l,0);
	for(int i=0;i<lim;i++)
		g[i]=1ll*(mdn+2-1ll*h[i]*g[i]%mdn+mdn)%mdn*g[i]%mdn;
	ntt(g,lim,l,1);
	for(int i=n;i<lim;i++)	g[i]=0;
}
int ff[N*4],fd[N*4],rm[N*4],q[N*4],rf[N*4],irg[N*4];
int st[N],xs[N],sg[N*4],ret[N*4],bs[N*4],a[N*4];
void poly_mod(int *a,int lim,int l)
{
	int mi=(k<<1); while(a[--mi]==0); if(mi<k)return;
    for(int i=0;i<lim;i++)	rf[i]=0;
	for(int i=0;i<=mi;i++)	rf[i]=a[i];
    reverse(rf,rf+mi+1);
	for(int i=mi-k+1;i<=mi;i++)	rf[i]=0;
	ntt(rf,lim,l,0);
    for(int i=0;i<lim;i++)	q[i]=1ll*rf[i]*irg[i]%mdn;
	ntt(q,lim,l,1);
    for(int i=mi-k+1;i<=lim;i++)	q[i]=0;
	reverse(q,q+mi-k+1);ntt(q,lim,l,0);
    for(int i=0;i<lim;i++)	q[i]=1ll*q[i]*sg[i]%mdn;
	ntt(q,lim,l,1);
    for(int i=0;i<k;i++)	a[i]=(a[i]+mdn-q[i])%mdn;
	for(int i=k;i<=mi;i++)	a[i]=0;
}
vector<int> coe[N]; int delta[N],fail[N],bas[N],tot;
void solve(int len,int *a,int *res)
{
	int cur=0;
	for(int i=1;i<=len;i++)
	{
		int tmp=a[i];
		for(int j=0;j<coe[cur].size();j++)
			upd(tmp,mdn-1ll*coe[cur][j]*a[i-j-1]%mdn);
		delta[i]=tmp; if(!tmp)	continue;
		fail[cur]=i;
		if(!cur){coe[++cur].resize(i); delta[i]=a[i]; continue;}
		int id=cur-1,nlen=coe[id].size()-fail[id]+i;
		for(int j=0;j<cur;j++)
			if(i+coe[j].size()-fail[j]<nlen)	nlen=i+coe[j].size()-fail[j],id=j;
		coe[cur+1]=coe[cur],cur++;
		while(coe[cur].size()<nlen)	coe[cur].push_back(0);
		int t=1ll*delta[i]*ksm(delta[fail[id]],mdn-2)%mdn;
		upd(coe[cur][i-fail[id]-1],t);
		for(int j=0;j<coe[id].size();j++)
			upd(coe[cur][i-fail[id]+j],mdn-1ll*t*coe[id][j]%mdn);
	}
	tot=coe[cur].size();
	for(int i=0;i<coe[cur].size();i++)
		bas[i+1]=coe[cur][i];
}
int w[N],remd[N],cur[N];
int main()
{
    int n=read(),m=read();
    for(int i=1;i<=n;i++)	st[i-1]=a[i]=read();
    solve(n,a,bas);
    for(int i=1;i<=tot;i++)	printf("%d ",bas[i]);
    for(int i=1;i<=tot;i++)	sg[tot-i]=(mdn-bas[i])%mdn; sg[tot]=1;
    int l=0,lim=1; k=tot;
	while(lim<=k)	lim<<=1,l++;
	for(int i=0;i<=tot;i++)	ret[i]=sg[i];
    for(int i=0;i<=tot;i++)	rf[i]=sg[i];
	reverse(rf,rf+tot+1); poly_inv(rf,irg,lim);
    for(int i=0;i<=tot;i++)	rf[i]=0; lim<<=1,l++;
    memset(a,0,sizeof(a));
	ntt(sg,lim,l,0); ntt(irg,lim,l,0); a[1]=1; bs[0]=1;
    while(m)
    {
        if(m&1)
        {
            ntt(bs,lim,l,0); ntt(a,lim,l,0);
            for(int i=0;i<lim;i++)	bs[i]=1ll*bs[i]*a[i]%mdn;
            ntt(bs,lim,l,1); ntt(a,lim,l,1); poly_mod(bs,lim,l);
        }
		ntt(a,lim,l,0);
		for(int i=0;i<lim;i++)	a[i]=1ll*a[i]*a[i]%mdn;
        ntt(a,lim,l,1);
		poly_mod(a,lim,l);
		m>>=1;
    }
    int ans=0;
	for(int i=0;i<tot;i++)	ans=1ll*(ans+1ll*bs[i]*st[i]%mdn)%mdn;
	printf("\n%d\n",ans);
    return 0;
}
posted @ 2020-03-24 10:11  寒雨微凝  阅读(553)  评论(0编辑  收藏  举报