矩乘优化学习笔记
矩阵乘法方式,左边的行乘上右边的列,最终答案的行数与左边相等,列数与右边相等
左行右列
矩阵乘法必须在左矩阵列数与右矩阵行数相同时才可以进行
矩阵乘法满足结合律,不满足一般的交换律。
板子:
struct MT{
int c[7][7],n,m;
MT(){
n=m=0;
memset(c,0x3f,sizeof(c));
}
void I(){
memset(c,0x3f,sizeof(c));
for(int i=1;i<=n;i++)c[i][i]=0;
}
MT friend operator*(MT a,MT b){
MT c;
c.n=a.n,c.m=b.m;
for(int i=1;i<=a.n;i++){
for(int j=1;j<=b.m;j++){
for(int k=1;k<=a.m;k++)c.c[i][j]=min(c.c[i][j],a.c[i][k]+b.c[k][j]);
}
}
return c;
}
};
常见优化
- 循环展开,直接将矩阵乘法展开
- 缩短查询路径,也是优化矩阵
- 矩阵加速递推的快速幂,唯一一个优化了时间复杂度的
应用
矩阵加速递推
致敬传奇斐波那契。
可以用矩阵存下对下一步有影响的值,然后通过各种换算得到下一步时的这个值
由于我们是直接调用原矩阵的元素,所以一定要注意目前的状态是否确定
当然这样是远远不够的,由于矩阵乘法符合交换律,直接快速幂即可
最有意思的应该是 Another kind of Fibonacci
众所周知,斐波那契数列:F(0) = 1, F(1) = 1, F(N) = F(N - 1) + F(N - 2) (N >= 2)。现在我们定义另一种斐波那契数列:A(0) = 1, A(1) = 1, A(N) = X * A(N - 1) + Y * A(N - 2) (N >= 2)。我们想要计算S(N),S(N) = A(0)2 +A(1)2+……+A(n)2。
这里需要把新得到的数的平方和乘积都得到,需要推导式子,拆掉新得到的数字,考虑乘积的增加量,得到最终答案。
代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=10007;
int t,n,x,y;
struct MT{
int n,m,c[20][20];
MT(){
n=m=0;
memset(c,0,sizeof(c));
}
void I(){
memset(c,0,sizeof(c));
for(int i=1;i<=n;i++)c[i][i]=1;
}
void clear(){
memset(c,0,sizeof(c));
}
MT friend operator*(MT a,MT b){
MT c;
c.n=a.n,c.m=b.m;
for(int i=1;i<=a.n;i++){
for(int j=1;j<=b.m;j++){
for(int k=1;k<=a.m;k++){
c.c[i][j]+=(a.c[i][k]*b.c[k][j])%mod;
c.c[i][j]%=mod;
}
}
}
return c;
}
void input(){
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++)cin>>c[i][j];
}
}
}base,be;
void ksm(MT a,int b){
while(b){
if(b&1)be=be*a;
a=a*a;
b>>=1;
}
}
signed main(){
while(cin>>n>>x>>y){
x%=mod;
y%=mod;
be.n=1,be.m=4;
be.c[1][1]=1,be.c[1][2]=1,be.c[1][3]=1,be.c[1][4]=1;
base.n=base.m=4;
base.c[1][1]=1;
base.c[2][1]=1;
base.c[2][2]=(x*x)%mod;
base.c[3][2]=(y*y)%mod;
base.c[4][2]=(2*x*y)%mod;
base.c[2][3]=1;
base.c[2][4]=x;
base.c[4][4]=y;
ksm(base,n);
cout<<be.c[1][1]<<endl;
// cout<<be.c[1][1]<<' '<<be.c[1][2]<<' '<<be.c[1][3]<<' '<<be.c[1][4]<<endl;
}
return 0;
}
矩阵表达修改
和oi-wiki上的例题一样,大魔法师,先预处理出矩阵,在线段数里面放矩阵即可,还是比较水的题目。
代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=998244353;
int n,op,l,r,v,m;
int read(){
char c=getchar();
int x=0;
while(c<'0'||c>'9')c=getchar();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x;
}
int add(int x,int y){
int ans=x+y;
if(ans>=mod)ans-=mod;
return ans;
}
struct MT{
int n,m,c[5][5];
MT(){
n=m=0;
}
MT (int _n,int _m){
n=_n;
m=_m;
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++)c[i][j]=0;
}
}
void I(){
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++)c[i][j]=0;
}
for(int i=1;i<=n;i++)c[i][i]=1;
}
void input(){
for(int j=1;j<=3;j++)c[1][j]=read();
c[1][4]=1;
}
MT friend operator*(MT a,MT b){
MT c(a.n,b.m);
for(int i=1;i<=a.n;i++){
for(int j=1;j<=b.m;j++){
for(int k=1;k<=a.m;k++){
c.c[i][j]+=(a.c[i][k]*b.c[k][j])%mod;
}
c.c[i][j]%=mod;
}
}
return c;
}
MT friend operator+(MT a,MT b){
MT c;
c.n=a.n;
c.m=b.m;
c.c[1][1]=add(a.c[1][1],b.c[1][1]);
c.c[1][2]=add(a.c[1][2],b.c[1][2]);
c.c[1][3]=add(a.c[1][3],b.c[1][3]);
c.c[1][4]=add(a.c[1][4],b.c[1][4]);
c.c[2][1]=add(a.c[2][1],b.c[2][1]);
c.c[2][2]=add(a.c[2][2],b.c[2][2]);
c.c[2][3]=add(a.c[2][3],b.c[2][3]);
c.c[2][4]=add(a.c[2][4],b.c[2][4]);
c.c[3][1]=add(a.c[3][1],b.c[3][1]);
c.c[3][2]=add(a.c[3][2],b.c[3][2]);
c.c[3][3]=add(a.c[3][3],b.c[3][3]);
c.c[3][4]=add(a.c[3][4],b.c[3][4]);
c.c[4][1]=add(a.c[4][1],b.c[4][1]);
c.c[4][2]=add(a.c[4][2],b.c[4][2]);
c.c[4][3]=add(a.c[4][3],b.c[4][3]);
c.c[4][4]=add(a.c[4][4],b.c[4][4]);
return c;
}
void print(){
for(int j=1;j<=3;j++)printf("%lld ",c[1][j]);
puts("");
}
}q[10];
struct ST{
MT c[1000005],tag[1000005];
#define ls p<<1
#define rs p<<1|1
void pushup(int p){
c[p]=c[ls]+c[rs];
}
void build(int p,int l,int r){
c[p].n=1;
c[p].m=4;
tag[p].n=tag[p].m=4;
tag[p].I();
if(l==r)return c[p].input();
int mid=l+r>>1;
build(ls,l,mid),build(rs,mid+1,r);
pushup(p);
}
void Tag(int p,MT v){
c[p]=c[p]*v;
tag[p]=tag[p]*v;
}
void pushdown(int p){
Tag(ls,tag[p]);
Tag(rs,tag[p]);
tag[p].I();
}
void change(int p,int l,int r,int L,int R,MT v){
if(l>=L&&r<=R)return Tag(p,v);
pushdown(p);
int mid=l+r>>1;
if(mid>=L)change(ls,l,mid,L,R,v);
if(mid<R)change(rs,mid+1,r,L,R,v);
pushup(p);
}
MT query(int p,int l,int r,int L,int R){
if(l>=L&&r<=R)return c[p];
pushdown(p);
int mid=l+r>>1;
if(mid>=L&&mid<R)return query(ls,l,mid,L,R)+query(rs,mid+1,r,L,R);
if(mid>=L)return query(ls,l,mid,L,R);
return query(rs,mid+1,r,L,R);
}
}seg;
signed main(){
q[1].n=q[1].m=4;
q[1].I();
q[1].c[2][1]=1;
q[2].n=q[2].m=4;
q[2].I();
q[2].c[3][2]=1;
q[3].n=q[3].m=4;
q[3].I();
q[3].c[1][3]=1;
cin>>n;
seg.build(1,1,n);
cin>>m;
while(m--){
op=read(),l=read(),r=read();
if(op<=3)seg.change(1,1,n,l,r,q[op]);
else if(op==7)seg.query(1,1,n,l,r).print();
else {
v=read();
MT tmp(4,4);
tmp.I();
if(op==4)tmp.c[4][1]=v;
if(op==5)tmp.c[2][2]=v;
if(op==6)tmp.c[3][3]=0,tmp.c[4][3]=v;
seg.change(1,1,n,l,r,tmp);
}
}
return 0;
}
一系列的图上路径问题
虽然oi-wiki上的内容很多,但是实际上都差不多,重要的是关注每一条路径走一遍,可以通过矩阵倍增处理。
就是通过这种方式来固定走的边数,再check一下即可。
不管是判环还是什么都可以
例题:
代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,t,k,y,g[1005],h[1005];
struct MT{
int c[105][105];
MT(){
memset(c,0x3f,sizeof(c));
}
MT friend operator*(MT a,MT b){
MT c;
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
for(int k=1;k<=n;k++)c.c[i][j]=min(c.c[i][j],a.c[i][k]+b.c[k][j]);
}
}
return c;
}
bool check(){
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
if(c[i][j]<=t)return true;
}
}
return false;
}
}st[51],be,tmp,tmp2;
signed main(){
cin>>n>>t;
for(int i=1;i<=n;i++){
cin>>k>>y;
for(int j=1;j<=k;j++)cin>>g[j];
for(int j=1;j<=k;j++)cin>>h[j];
for(int j=1;j<=k;j++){
st[0].c[i][g[j]]=min(h[j]+y,st[0].c[i][g[j]]);
}
}
for(int i=1;i<=50;i++)st[i]=st[i-1]*st[i-1];
for(int i=1;i<=n;i++)be.c[i][i]=0;
int ans=0;
for(int i=50;i>=0;i--){
tmp=be*st[i];
if(tmp.check()){
ans+=(1ll<<i);
be=tmp;
}
}
cout<<ans;
return 0;
}

浙公网安备 33010602011771号