003.稀疏矩阵乘法
稀疏矩阵乘法需要用到CSR格式
用数组\(rptr[]\)记录每行首个非零元素在表中的存储下标
根据矩阵乘法\(c_{ij}=\displaystyle{\sum_{k=1}^{n}a_{ik}b_{kj}}\),可知第一个数组表中元素\((i,k,a_{ik})\)的列号\(k\)对应第二个数组表中元素\((k,j,b_{kj})\)
先遍历\(a\)的所有行,在每一行中,\(a\)的若干非零元素在表中的下标区间为\([rptr[i],rptr[i+1])\),对于这个区间的每一个值\((i,k,a_{ik})\)都去乘以\(b\)中第\(k\)行的非零元素\((k,j,b_{kj})\)
\(b\)中这些非零元素的下标又可以由\([rptr[k],rptr[k+1])\)锁定,\((i,k,a_{ik})\)与\((k,j,b_{kj})\)的相乘结果会贡献给元素\(c_{ij}\),由于每次遍历\(i\)为定值,故只需要一个数组\(ctemp[]\)记录列贡献的和
TSMatrix Mul(TSMatrix& a, TSMatrix b)
{
TSMatrix res;
res.nums = 0;
res.rows = a.rows;
res.cols = b.cols;
for (int i = 0; i < a.rows; i++)
{
memset(ctemp, 0, sizeof ctemp);
int ar1 = a.rptr[i], ar2 = a.rptr[i + 1];
for (int j = ar1; j < ar2; j++)
{
int k = a.data[j].c;
int br1 = b.rptr[k], br2 = b.rptr[k + 1];
for (int s = br1; s < br2; s++)
{
int col = b.data[s].c;
ctemp[col] += a.data[j].d * b.data[s].d;
}
}
for (int j = 0; j < b.cols; j++)
{
if (ctemp[j])
{
res.data[res.nums++] = {i, j, ctemp[j]};
}
}
}
// 计算结果矩阵的行指针数组
memset(num, 0, sizeof num);
for (int i = 0; i < res.nums; i++) {
num[res.data[i].r]++;
}
res.rptr[0] = 0;
for (int i = 1; i <= res.rows; i++) {
res.rptr[i] = res.rptr[i - 1] + num[i - 1];
}
return res;
}

浙公网安备 33010602011771号