转置原理口胡

转置原理口胡


抄自:cy的WC2020课件、rqy的uoj博客和个人博客

本文没有任何严谨证明,基本都是博主自己口胡的。

也不保证不会出锅,因为博主是垃圾。

线性算法

见cy的WC2020课件。

转化

有一个DAG,点数为\(n+k+m\),满足所有边形如u v a\((n<u<v)\)

DAG分为三部分:前\(n\)个点(\([1,n]\))是输入,中间\(k\)个(\([n+1,n+k]\))是中间变量,后\(m\)个(\([n+k+1,n+k+m] \))是输出。

在这个DAG上运行如下算法:

  • 输入\(n\)个点的权值。
  • 按编号顺序遍历后\(k+m\)个点,将每个点的权值设为\(w_v=\sum_{(u,v,a)\in E}a\times w_u\)
  • 输出\(m\)个点的权值。

可以发现这个模型可以实现任何线性算法。

容易发现,\(\forall i>n\)\(w_i\)\(w_1,w_2,\cdots,w_n\)的线性组合。

所以存在一个\(n\times m\)的矩阵\(A\)\(A_{i,j}\)表示输入\(w_i\)到输出\(w_{n+k+j}\)的贡献。

现在将\([n+k+1,n+k+m]\)看做输入,\([1,n]\)看做输出,边全部反向,并保持边权不变。

反向运行上述算法,容易发现,\(\forall i\le n+k\)\(w_i\)\(w_{n+k+1},w_{n+k+2},\cdots,w_{n+k+m}\)的线性组合。

所以存在一个\(m\times n\)的矩阵\(B\)\(B_{i,j}\)表示输入\(w_{n+k+i}\)到输出\(w_j\)的贡献。

上述两个矩阵的“贡献”其实就是路径权值和(路径的权值是所有边权积)所以显然有\(A_{i,j}=B_{j,i}\)\(A=B^T\)

原来的dag对应一个线性算法,新的dag对应另外一个线性算法,它们的计算次数完全相同。(至于加法次数的略微差别是因为第一次加法可以直接赋值,这个可以忽略)

实际的线性算法是可以复用空间的,但这里我懒得写了。最后可以对应到cy的PPT中的构造。

多点求值

给定多项式\(F(x)=\sum_{i=0}^{n-1}f_ix^i\)

\(ans_i=F(q_i),i\in[0,m)\)

问题是

\[\begin{bmatrix} f_0&f_1&\cdots&f_{n-1} \end{bmatrix} \times \begin{bmatrix} q_0^0&q_1^0&\cdots&q_{m-1}^0\\ q_0^1&q_1^1&\cdots&q_{m-1}^1\\ \vdots&\vdots&\ddots&\vdots\\ q_0^{n-1}&q_1^{n-1}&\cdots&q_{m-1}^{n-1}\\ \end{bmatrix} = \begin{bmatrix} ans_0&ans_1&\cdots&ans_{m-1} \end{bmatrix} \]

看成\(uA=v\),考虑问题\(u'A^T=v'\),它的输入是\(u'\)输出是\(v'\)。即:

\[\begin{bmatrix} g_0&g_1&\cdots&g_{m-1} \end{bmatrix} \times \begin{bmatrix} q_0^0&q_0^1&\cdots&q_0^{n-1}\\ q_1^0&q_1^1&\cdots&q_1^{n-1}\\ \vdots&\vdots&\ddots&\vdots\\ q_{m-1}^0&q_{m-1}^1&\cdots&q_{m-1}^{n-1}\\ \end{bmatrix} = \begin{bmatrix} b_0&b_1&\cdots&b_{n-1} \end{bmatrix} \]

可以发现求出\(\sum_{j=0}^{m-1}\frac{g[j]}{1-xq_j}\)就完事儿了,这个很好做,分治就完事了

新问题的求解过程如下

  • 新问题以\(g_0,g_1,\cdots,g_{m-1}\)作为输入,\(b_0,b_1,\cdots,b_{n-1}\)作为输出
  • \(q\)始终是常量,不参与转置
  • 线段树上维护两个信息:\(P_x,Q_x\),分别表示这个节点对应区间的\(\sum\frac{g[j]}{1-xq_j}\)的分子和分母
  • 由于\(q\)为常量\(Q\)可以直接确定,只需向上求\(P_x=P_{ls}Q_{rs}+P_{rs}Q_{ls}\)
  • 求出线段树根节点的\(P_1Q_1^{-1}\)\(0\)\(n-1\)项系数即为答案

多项式乘法

见cy的PPT

转置回原问题,求解过程

分治的过程\(P_x=P_{ls}Q_{rs}+P_{rs}Q_{ls}\)可以看成是:

P2_ls=P_ls*Q_rs
P2_rs=P_rs*Q_ls
Px=Px+P2_ls
Px=Px+P2_rs

重写之后是:

P2_rs=P2_rs+Px
P2_ls=P2_ls+Px
P_rs=P2_rs *^T Q_ls
P_ls=P2_ls *^T Q_rs
  • \(f_0,f_1,\cdots,f_{n-1}\)作为输入,\(ans_0,ans_1,\cdots,ans_{m-1}\)作为输出
  • 一开始先分治求出所有的\(Q\)\(q\)始终是常量)
  • 计算\(f\times^T Q_1^{-1}\),保留\(m\)\(n+m-2\)项系数作为\(P_1\)
  • 从上到下分治,令\(P_{ls}=P_{x}\times^TQ_{rs},P_{rs}=P_{x}\times^TQ_{ls}\)
  • \(ans_i\)为第\(i\)个叶子的\(P\)中常数项

因为博主水平不行,必须要有\(m\ge n\)以保证复杂度(如果\(m<n\),因为\(P_1\)次数为\(n\)将导致所有的\(P\)次数增加\(n-m\)

#include<bits/stdc++.h>
typedef long long ll;
#define mod 998244353
#define poly std::vector<int>
ll gi(){
	ll x=0,f=1;
	char ch=getchar();
	while(!isdigit(ch))f^=ch=='-',ch=getchar();
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return f?x:-x;
}
std::mt19937 rnd(time(NULL));
#define rand rnd
#define pr std::pair<int,int>
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
template<class T>void cxk(T&a,T b){a=a>b?a:b;}
template<class T>void cnk(T&a,T b){a=a<b?a:b;}
#ifdef mod
int pow(int x,int y){
	int ret=1;
	while(y){
		if(y&1)ret=1ll*ret*x%mod;
		x=1ll*x*x%mod;y>>=1;
	}
	return ret;
}
template<class Ta,class Tb>void inc(Ta&a,Tb b){a=a+b>=mod?a+b-mod:a+b;}
template<class Ta,class Tb>void dec(Ta&a,Tb b){a=a>=b?a-b:a+mod-b;}
template<class Ta,class Tb>int sub(Ta&a,Tb b){return a>=b?a-b:a+mod-b;}
#endif
int coef[65539],qx[65539],Q[262147],rev[131113],A[131113],B[131113],N,lg,ans[65539];
void setN(int n){
	lg=32-__builtin_clz(n),N=1<<lg;
	for(int i=0;i<N;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
}
int getN(int n){return 1<<32-__builtin_clz(n);}
void ntt(int*A,int t){
	for(int i=0;i<N;++i)if(i>rev[i])std::swap(A[i],A[rev[i]]);
	for(int o=1,*qq=Q+o*2;o<N;o<<=1,qq=Q+o*2)
		for(int*p=A;p!=A+N;p+=o<<1)
			for(int i=0;i<o;++i){
				int t=1ll*p[i+o]*qq[i]%mod;
				p[i+o]=sub(p[i],t),inc(p[i],t);
			}
	if(!t){
		std::reverse(A+1,A+N);
		for(int i=0,iv=pow(N,mod-2);i<N;++i)A[i]=1ll*A[i]*iv%mod;
	}
}
poly mul(const poly&x,const poly&y){
	int len=x.size()+y.size()-1;setN(len);
	memset(A,0,N<<2);memset(B,0,N<<2);
	for(int i=0;i<x.size();++i)A[i]=x[i];
	for(int i=0;i<y.size();++i)B[i]=y[i];
	ntt(A,1),ntt(B,1);for(int i=0;i<N;++i)A[i]=1ll*A[i]*B[i]%mod;ntt(A,0);
	poly z(len);for(int i=0;i<len;++i)z[i]=A[i];
	return z;
}
poly mulT(const poly&x,const poly&y){
	int len=x.size();setN(len);
	memset(A,0,N<<2);memset(B,0,N<<2);
	for(int i=0;i<x.size();++i)A[i]=x[i];
	for(int i=0;i<y.size();++i)B[i]=y[i];
	std::reverse(B,B+y.size());
	ntt(A,1),ntt(B,1);for(int i=0;i<N;++i)A[i]=1ll*A[i]*B[i]%mod;ntt(A,0);
	poly z(x.size()-y.size()+1);for(int i=0;i<z.size();++i)z[i]=A[i+y.size()-1];return z;
}
poly getinv(poly x){
	if(x.size()==1)return{pow(x[0],mod-2)};
	int n=x.size(),m=x.size()+1>>1;
	poly y(x.begin(),x.begin()+m),_y;_y=y=getinv(y);
	setN(x.size()*2+2);y.resize(N);x.resize(N);
	ntt(&y[0],1);ntt(&x[0],1);
	for(int i=0;i<N;++i)x[i]=1ll*x[i]*y[i]%mod*y[i]%mod;
	ntt(&x[0],0);
	for(int i=0;i<n;++i)x[i]=((i<m?2ll*_y[i]:0ll)-x[i]+mod)%mod;
	x.resize(n);return x;
}
#define mid ((l+r)>>1)
poly qwq[262147],qaq[262147];
void divide1(int x,int l,int r){
	if(l==r){qwq[x]={1,mod-qx[l]};return;}
	divide1(x<<1,l,mid),divide1(x<<1|1,mid+1,r);
	qwq[x]=mul(qwq[x<<1],qwq[x<<1|1]);
}
void divide2(int x,int l,int r){
	if(l==r){ans[l]=qaq[x][0];return;}
	setN(qaq[x].size());
	{
		poly&a=qaq[x],&y=qwq[x<<1|1],&z=qaq[x<<1];
		memset(A,0,N<<2);memset(B,0,N<<2);
		for(int i=0;i<a.size();++i)A[i]=a[i];
		for(int i=0;i<y.size();++i)B[i]=y[i];
		std::reverse(B,B+y.size());
		ntt(A,1),ntt(B,1);for(int i=0;i<N;++i)B[i]=1ll*A[i]*B[i]%mod;ntt(B,0);
		z.resize(a.size()-y.size()+1);
		for(int i=0;i<z.size();++i)z[i]=B[i+y.size()-1];
	}
	{
		poly&a=qaq[x],&y=qwq[x<<1],&z=qaq[x<<1|1];
		memset(B,0,N<<2);
		for(int i=0;i<y.size();++i)B[i]=y[i];
		std::reverse(B,B+y.size());
		ntt(B,1);for(int i=0;i<N;++i)B[i]=1ll*A[i]*B[i]%mod;ntt(B,0);
		z.resize(a.size()-y.size()+1);
		for(int i=0;i<z.size();++i)z[i]=B[i+y.size()-1];
	}
	divide2(x<<1,l,mid);
	divide2(x<<1|1,mid+1,r);
}
int main(){
#ifdef LOCAL
	freopen("in.in","r",stdin);
	//freopen("out.out","w",stdout);
#endif
	for(int o=1;o<=(1<<17);o<<=1){
		int P=pow(19260817,mod/o);
		Q[o]=1;for(int i=1;i<o;++i)Q[i+o]=1ll*Q[i+o-1]*P%mod;
	}
	int n=gi()+1,m=gi(),_m=m;
	for(int i=0;i<n;++i)coef[i]=gi();
	for(int i=0;i<m;++i)qx[i]=gi();
	cxk(m,n);divide1(1,0,m-1);
	poly sfm=getinv(qwq[1]);std::reverse(all(sfm));
	poly s=mul(poly(coef,coef+n),sfm);
	qaq[1]=poly(s.begin()+m,s.end());
	qaq[1].resize(m+1);
	divide2(1,0,m-1);
	for(int i=0;i<_m;++i)printf("%d\n",ans[i]);
	return 0;
}
posted @ 2020-09-03 15:10  菜狗xzz  阅读(838)  评论(3编辑  收藏  举报