003.稀疏矩阵乘法

稀疏矩阵乘法需要用到CSR格式
用数组\(rptr[]\)记录每行首个非零元素在表中的存储下标
根据矩阵乘法\(c_{ij}=\displaystyle{\sum_{k=1}^{n}a_{ik}b_{kj}}\),可知第一个数组表中元素\((i,k,a_{ik})\)的列号\(k\)对应第二个数组表中元素\((k,j,b_{kj})\)
先遍历\(a\)的所有行,在每一行中,\(a\)的若干非零元素在表中的下标区间为\([rptr[i],rptr[i+1])\),对于这个区间的每一个值\((i,k,a_{ik})\)都去乘以\(b\)中第\(k\)行的非零元素\((k,j,b_{kj})\)
\(b\)中这些非零元素的下标又可以由\([rptr[k],rptr[k+1])\)锁定,\((i,k,a_{ik})\)\((k,j,b_{kj})\)的相乘结果会贡献给元素\(c_{ij}\),由于每次遍历\(i\)为定值,故只需要一个数组\(ctemp[]\)记录列贡献的和

TSMatrix Mul(TSMatrix& a, TSMatrix b)
{
    TSMatrix res;
    res.nums = 0;
    res.rows = a.rows;
    res.cols = b.cols;

    for (int i = 0; i < a.rows; i++)
    {
        memset(ctemp, 0, sizeof ctemp);
        int ar1 = a.rptr[i], ar2 = a.rptr[i + 1];
        for (int j = ar1; j < ar2; j++)
        {
            int k = a.data[j].c;
            int br1 = b.rptr[k], br2 = b.rptr[k + 1];
            for (int s = br1; s < br2; s++)
            {
                int col = b.data[s].c;
                ctemp[col] += a.data[j].d * b.data[s].d;
            }
        }
        for (int j = 0; j < b.cols; j++)
        {
            if (ctemp[j])
            {
                res.data[res.nums++] = {i, j, ctemp[j]};
            }
        }
    }
    
    // 计算结果矩阵的行指针数组
    memset(num, 0, sizeof num);
    for (int i = 0; i < res.nums; i++) {
        num[res.data[i].r]++;
    }
    res.rptr[0] = 0;
    for (int i = 1; i <= res.rows; i++) {
        res.rptr[i] = res.rptr[i - 1] + num[i - 1];
    }
    
    return res;
}
posted @ 2025-06-02 14:44  _P_D_X  阅读(33)  评论(0)    收藏  举报