矩乘优化学习笔记

矩阵乘法方式,左边的行乘上右边的列,最终答案的行数与左边相等,列数与右边相等

左行右列

矩阵乘法必须在左矩阵列数与右矩阵行数相同时才可以进行

矩阵乘法满足结合律,不满足一般的交换律。

板子:

struct MT{
	int c[7][7],n,m;
	MT(){
		n=m=0;
		memset(c,0x3f,sizeof(c));
	}
	void I(){
		memset(c,0x3f,sizeof(c));
		for(int i=1;i<=n;i++)c[i][i]=0;
	}
	MT friend operator*(MT a,MT b){
		MT c;
		c.n=a.n,c.m=b.m;
		for(int i=1;i<=a.n;i++){
			for(int j=1;j<=b.m;j++){
				for(int k=1;k<=a.m;k++)c.c[i][j]=min(c.c[i][j],a.c[i][k]+b.c[k][j]);
			}
		}
		return c;
	}
};

常见优化

  1. 循环展开,直接将矩阵乘法展开
  2. 缩短查询路径,也是优化矩阵
  3. 矩阵加速递推的快速幂,唯一一个优化了时间复杂度的

应用

矩阵加速递推

致敬传奇斐波那契。

可以用矩阵存下对下一步有影响的值,然后通过各种换算得到下一步时的这个值

由于我们是直接调用原矩阵的元素,所以一定要注意目前的状态是否确定

当然这样是远远不够的,由于矩阵乘法符合交换律,直接快速幂即可

最有意思的应该是 Another kind of Fibonacci

众所周知,斐波那契数列:F(0) = 1, F(1) = 1, F(N) = F(N - 1) + F(N - 2) (N >= 2)。现在我们定义另一种斐波那契数列:A(0) = 1, A(1) = 1, A(N) = X * A(N - 1) + Y * A(N - 2) (N >= 2)。我们想要计算S(N),S(N) = A(0)2 +A(1)2+……+A(n)2。

这里需要把新得到的数的平方和乘积都得到,需要推导式子,拆掉新得到的数字,考虑乘积的增加量,得到最终答案。

代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=10007;
int t,n,x,y;
struct MT{
	int n,m,c[20][20];
	MT(){
		n=m=0;
		memset(c,0,sizeof(c));
	}
	void I(){
		memset(c,0,sizeof(c));
		for(int i=1;i<=n;i++)c[i][i]=1;
	}
	void clear(){
		memset(c,0,sizeof(c));
	}
	MT friend operator*(MT a,MT b){
		MT c;
		c.n=a.n,c.m=b.m;
		for(int i=1;i<=a.n;i++){
			for(int j=1;j<=b.m;j++){
				for(int k=1;k<=a.m;k++){
					c.c[i][j]+=(a.c[i][k]*b.c[k][j])%mod;
					c.c[i][j]%=mod;
				}
			}
		}
		return c;
	}
	void input(){
		for(int i=1;i<=n;i++){
			for(int j=1;j<=m;j++)cin>>c[i][j];
		}
	}
}base,be;
void ksm(MT a,int b){
	while(b){
		if(b&1)be=be*a;
		a=a*a;
		b>>=1;
	}
}
signed main(){
	while(cin>>n>>x>>y){
		x%=mod;
		y%=mod;
		be.n=1,be.m=4;
		be.c[1][1]=1,be.c[1][2]=1,be.c[1][3]=1,be.c[1][4]=1;
		base.n=base.m=4;
		base.c[1][1]=1;
		base.c[2][1]=1;
		base.c[2][2]=(x*x)%mod;
		base.c[3][2]=(y*y)%mod;
		base.c[4][2]=(2*x*y)%mod;
		base.c[2][3]=1;
		base.c[2][4]=x;
		base.c[4][4]=y;
		ksm(base,n);
		cout<<be.c[1][1]<<endl;
//		cout<<be.c[1][1]<<' '<<be.c[1][2]<<' '<<be.c[1][3]<<' '<<be.c[1][4]<<endl;
	}
	return 0;
}

矩阵表达修改

和oi-wiki上的例题一样,大魔法师,先预处理出矩阵,在线段数里面放矩阵即可,还是比较水的题目。

代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=998244353;
int n,op,l,r,v,m;
int read(){
	char c=getchar();
	int x=0;
	while(c<'0'||c>'9')c=getchar();
	while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
	return x;
}
int add(int x,int y){
	int ans=x+y;
	if(ans>=mod)ans-=mod;
	return ans;
}
struct MT{
	int n,m,c[5][5];
	MT(){
		n=m=0;
	}
	MT (int _n,int _m){
		n=_n;
		m=_m;
		for(int i=1;i<=n;i++){
			for(int j=1;j<=m;j++)c[i][j]=0;
		}
	}
	void I(){
		for(int i=1;i<=n;i++){
			for(int j=1;j<=m;j++)c[i][j]=0;
		}
		for(int i=1;i<=n;i++)c[i][i]=1;
	}
	void input(){
		for(int j=1;j<=3;j++)c[1][j]=read();
		c[1][4]=1;
	}
	MT friend operator*(MT a,MT b){
		MT c(a.n,b.m);
		for(int i=1;i<=a.n;i++){
			for(int j=1;j<=b.m;j++){
				for(int k=1;k<=a.m;k++){
					c.c[i][j]+=(a.c[i][k]*b.c[k][j])%mod;
				}
				c.c[i][j]%=mod;
			}
		}
		return c;
	}
	MT friend operator+(MT a,MT b){
		MT c;
		c.n=a.n;
		c.m=b.m;
		c.c[1][1]=add(a.c[1][1],b.c[1][1]);
		c.c[1][2]=add(a.c[1][2],b.c[1][2]);
		c.c[1][3]=add(a.c[1][3],b.c[1][3]);
		c.c[1][4]=add(a.c[1][4],b.c[1][4]);
		c.c[2][1]=add(a.c[2][1],b.c[2][1]);
		c.c[2][2]=add(a.c[2][2],b.c[2][2]);
		c.c[2][3]=add(a.c[2][3],b.c[2][3]);
		c.c[2][4]=add(a.c[2][4],b.c[2][4]);
		c.c[3][1]=add(a.c[3][1],b.c[3][1]);
		c.c[3][2]=add(a.c[3][2],b.c[3][2]);
		c.c[3][3]=add(a.c[3][3],b.c[3][3]);
		c.c[3][4]=add(a.c[3][4],b.c[3][4]);
		c.c[4][1]=add(a.c[4][1],b.c[4][1]);
		c.c[4][2]=add(a.c[4][2],b.c[4][2]);
		c.c[4][3]=add(a.c[4][3],b.c[4][3]);
		c.c[4][4]=add(a.c[4][4],b.c[4][4]);
		return c;
	}
	void print(){
		for(int j=1;j<=3;j++)printf("%lld ",c[1][j]);
		puts("");
	}
}q[10];
struct ST{
	MT c[1000005],tag[1000005];
	#define ls p<<1
	#define rs p<<1|1
	void pushup(int p){
		c[p]=c[ls]+c[rs];
	}
	void build(int p,int l,int r){
		c[p].n=1;
		c[p].m=4;
		tag[p].n=tag[p].m=4;
		tag[p].I();
		if(l==r)return c[p].input();
		int mid=l+r>>1;
		build(ls,l,mid),build(rs,mid+1,r);
		pushup(p);
	}
	void Tag(int p,MT v){
		c[p]=c[p]*v;
		tag[p]=tag[p]*v;
	}
	void pushdown(int p){
		Tag(ls,tag[p]);
		Tag(rs,tag[p]);
		tag[p].I();
	}
	void change(int p,int l,int r,int L,int R,MT v){
		if(l>=L&&r<=R)return Tag(p,v);
		pushdown(p);
		int mid=l+r>>1;
		if(mid>=L)change(ls,l,mid,L,R,v);
		if(mid<R)change(rs,mid+1,r,L,R,v);
		pushup(p);
	}
	MT query(int p,int l,int r,int L,int R){
		if(l>=L&&r<=R)return c[p];
		pushdown(p);
		int mid=l+r>>1;
		if(mid>=L&&mid<R)return query(ls,l,mid,L,R)+query(rs,mid+1,r,L,R);
		if(mid>=L)return query(ls,l,mid,L,R);
		return query(rs,mid+1,r,L,R);
	}
}seg;
signed main(){
	q[1].n=q[1].m=4;
	q[1].I();
	q[1].c[2][1]=1;
	q[2].n=q[2].m=4;
	q[2].I();
	q[2].c[3][2]=1;
	q[3].n=q[3].m=4;
	q[3].I();
	q[3].c[1][3]=1;
	cin>>n;
	seg.build(1,1,n);
	cin>>m;
	while(m--){
		op=read(),l=read(),r=read();
		if(op<=3)seg.change(1,1,n,l,r,q[op]);
		else if(op==7)seg.query(1,1,n,l,r).print();
		else {
			v=read();
			MT tmp(4,4);
			tmp.I();
			if(op==4)tmp.c[4][1]=v;
			if(op==5)tmp.c[2][2]=v;
			if(op==6)tmp.c[3][3]=0,tmp.c[4][3]=v;
			seg.change(1,1,n,l,r,tmp);
		}
	}
	return 0;
}

一系列的图上路径问题

虽然oi-wiki上的内容很多,但是实际上都差不多,重要的是关注每一条路径走一遍,可以通过矩阵倍增处理。

就是通过这种方式来固定走的边数,再check一下即可。

不管是判环还是什么都可以

例题:

Gremlin的繁殖

代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,t,k,y,g[1005],h[1005];
struct MT{
	int c[105][105];
	MT(){
		memset(c,0x3f,sizeof(c));
	}
	MT friend operator*(MT a,MT b){
		MT c;
		for(int i=1;i<=n;i++){
			for(int j=1;j<=n;j++){
				for(int k=1;k<=n;k++)c.c[i][j]=min(c.c[i][j],a.c[i][k]+b.c[k][j]);
			}
		}
		return c;
	}
	bool check(){
		for(int i=1;i<=n;i++){
			for(int j=1;j<=n;j++){
				if(c[i][j]<=t)return true;
			}
		}
		return false;
	}
}st[51],be,tmp,tmp2;
signed main(){
	cin>>n>>t;
	for(int i=1;i<=n;i++){
		cin>>k>>y;
		for(int j=1;j<=k;j++)cin>>g[j];
		for(int j=1;j<=k;j++)cin>>h[j];
		for(int j=1;j<=k;j++){
			st[0].c[i][g[j]]=min(h[j]+y,st[0].c[i][g[j]]);
		}
	}
	for(int i=1;i<=50;i++)st[i]=st[i-1]*st[i-1];
	for(int i=1;i<=n;i++)be.c[i][i]=0;
	int ans=0;
	for(int i=50;i>=0;i--){
		tmp=be*st[i];
		if(tmp.check()){
			ans+=(1ll<<i);
			be=tmp;
		}
	}
	cout<<ans;
	return 0;
}
posted @ 2025-11-28 09:42  huhangqi  阅读(0)  评论(0)    收藏  举报
/*
*/