【*矩阵运算】你不得不会的线性代数/点乘和矩阵乘法的区别/如何加速运算和不保留中间结果(防止爆内存MLE)

【前置知识】矩阵乘法、点乘的计算以及点积和叉积的分辨

0ba8a323f2a9eb6c8492a2c95ce2e4ab

1.矩阵点乘就是矩阵乘法?
//未完,我去写机器学习代码去了;

1.思路两个:
第一个:保存中间结果,temp然后模拟自然矩阵计算流程,先计算\(Q\times K^T\)再计算\(W\cdot temp\),再计算\(temp\times V\);
但是这样做的话由于nxn十分巨大,最多需要多达\(10^{8}*lli(8bytes=64bits)\)的大小,约等于\(8*10^8B/1024\times 1024B=762MB\);由于空间限制为:512 MB,因此一定会爆;

但是d就小得多,因此我们需要尝试一次性把全部结果都计算出来,不保留中间结果;
也就是思路2;

以下代码被注释掉的是思路1,新的是思路2;
代码:

#include<iostream>
#include<vector>
typedef long long int lli;
using namespace std;
int n,d;

void QxKT(vector<vector<lli>> &res,vector<vector<int>>& Q,vector<vector<int>>& K){
    for(int i = 0; i < n;i++){
        for(int j = 0;j < n;j++){
            for(int k = 0;k < d;k++){
                res[i][j] += Q[i][k]*K[j][k];
            }
        }
    }
}
void Wxtemp(vector<vector<lli>> &res,vector<int>&W){
    for(int i =0; i < n ;i++){
        for(int j = 0;j < n;j++){
            res[i][j] = res[i][j]*W[i];
        }
    }
}

void tempxV(vector<vector<lli>>& Res,vector<vector<lli>>& temp,vector<vector<int>>& V){
    for(int i = 0;i < n;i++){
        for(int j = 0; j < n;j++){
            for(int k = 0;k < d;k++){
                Res[i][k] += temp[i][j]*V[j][k];
            }
        }
    }
}


int main(){
    cin >> n >> d;
    vector<vector<int>> Q(n+1,vector<int>(d+1));
    vector<vector<int>> K(n+1,vector<int>(d+1));
    vector<vector<int>> V(n+1,vector<int>(d+1));
    vector<int> W(n+1);
    vector<vector<lli>> Res(n+1,vector<lli>(d+1));

    for(int i = 0;i < n;i++){
        for(int j = 0;j < d;j++){
            cin >> Q[i][j];
        }
    }
    for(int i = 0;i < n;i++){
        for(int j = 0;j < d;j++){
            cin >> K[i][j];
        }
    }
    for(int i = 0;i < n;i++){
        for(int j = 0;j < d;j++){
            cin >> V[i][j];
        }
    }

    for(int i = 0;i < n;i++){
        cin >> W[i];
    }

    // vector<vector<lli>> temp(n+1,vector<lli>(n+1));//中间矩阵nxn
    // QxKT(temp,Q,K);
    // Wxtemp(temp,W);
    // tempxV(Res,temp,V);
    for(int i = 0; i < n; i++){
        for(int j = 0; j < n; j++){
            long long qk = 0;
            for(int t = 0; t < d; t++){
                qk += Q[i][t] * K[j][t];
            }
            qk *= W[i];
            for(int k = 0; k < d; k++){
                Res[i][k] += qk * V[j][k];
            }
        }
    }

    for(int i = 0;i < n;i++){
        for(int j = 0;j < d;j++){
            cout << Res[i][j] << " ";
        }
        cout << endl;
    }
}
posted @ 2025-12-02 16:15  q_z_chen  阅读(0)  评论(0)    收藏  举报