【*矩阵运算】你不得不会的线性代数/点乘和矩阵乘法的区别/如何加速运算和不保留中间结果(防止爆内存MLE)
【前置知识】矩阵乘法、点乘的计算以及点积和叉积的分辨

1.矩阵点乘就是矩阵乘法?
//未完,我去写机器学习代码去了;
1.思路两个:
第一个:保存中间结果,temp然后模拟自然矩阵计算流程,先计算\(Q\times K^T\)再计算\(W\cdot temp\),再计算\(temp\times V\);
但是这样做的话由于nxn十分巨大,最多需要多达\(10^{8}*lli(8bytes=64bits)\)的大小,约等于\(8*10^8B/1024\times 1024B=762MB\);由于空间限制为:512 MB,因此一定会爆;
但是d就小得多,因此我们需要尝试一次性把全部结果都计算出来,不保留中间结果;
也就是思路2;
以下代码被注释掉的是思路1,新的是思路2;
代码:
#include<iostream>
#include<vector>
typedef long long int lli;
using namespace std;
int n,d;
void QxKT(vector<vector<lli>> &res,vector<vector<int>>& Q,vector<vector<int>>& K){
for(int i = 0; i < n;i++){
for(int j = 0;j < n;j++){
for(int k = 0;k < d;k++){
res[i][j] += Q[i][k]*K[j][k];
}
}
}
}
void Wxtemp(vector<vector<lli>> &res,vector<int>&W){
for(int i =0; i < n ;i++){
for(int j = 0;j < n;j++){
res[i][j] = res[i][j]*W[i];
}
}
}
void tempxV(vector<vector<lli>>& Res,vector<vector<lli>>& temp,vector<vector<int>>& V){
for(int i = 0;i < n;i++){
for(int j = 0; j < n;j++){
for(int k = 0;k < d;k++){
Res[i][k] += temp[i][j]*V[j][k];
}
}
}
}
int main(){
cin >> n >> d;
vector<vector<int>> Q(n+1,vector<int>(d+1));
vector<vector<int>> K(n+1,vector<int>(d+1));
vector<vector<int>> V(n+1,vector<int>(d+1));
vector<int> W(n+1);
vector<vector<lli>> Res(n+1,vector<lli>(d+1));
for(int i = 0;i < n;i++){
for(int j = 0;j < d;j++){
cin >> Q[i][j];
}
}
for(int i = 0;i < n;i++){
for(int j = 0;j < d;j++){
cin >> K[i][j];
}
}
for(int i = 0;i < n;i++){
for(int j = 0;j < d;j++){
cin >> V[i][j];
}
}
for(int i = 0;i < n;i++){
cin >> W[i];
}
// vector<vector<lli>> temp(n+1,vector<lli>(n+1));//中间矩阵nxn
// QxKT(temp,Q,K);
// Wxtemp(temp,W);
// tempxV(Res,temp,V);
for(int i = 0; i < n; i++){
for(int j = 0; j < n; j++){
long long qk = 0;
for(int t = 0; t < d; t++){
qk += Q[i][t] * K[j][t];
}
qk *= W[i];
for(int k = 0; k < d; k++){
Res[i][k] += qk * V[j][k];
}
}
}
for(int i = 0;i < n;i++){
for(int j = 0;j < d;j++){
cout << Res[i][j] << " ";
}
cout << endl;
}
}

浙公网安备 33010602011771号