Loading

C 语言矩阵乘法示例

通过命令行输入矩阵大小，矩阵元素取随机值。

//#include <ctime>
#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <chrono>
#include <iostream>
#include <vector>

using namespace std;

// Prints a column-major matrix (m[col][row]) to stdout, one matrix row per line.
void printMatrix(double** m, int rows, int cols);
// Allocates a rows x cols column-major matrix filled with pseudo-random values (r % range).
double** new2DMatrix(int rows, int cols, long seed, long range);
// Releases a matrix returned by new2DMatrix().
void freeMatrix(double** m);

/**
 * Allocates a rows x cols matrix of doubles filled with pseudo-random
 * integral values in [0, range).
 *
 * Storage is column-major: m[col][row]. `m` holds `cols` column pointers that
 * all point into one contiguous buffer of rows * cols doubles owned by m[0];
 * release the result with freeMatrix().
 *
 * Values come from rand(), so the caller controls reproducibility by calling
 * srand() beforehand (rand() pairs with srand(); POSIX random() does not).
 *
 * @param rows  number of rows (entries per column); must be > 0
 * @param cols  number of columns; must be > 0
 * @param seed  currently unused (the generator is seeded by the caller)
 * @param range exclusive upper bound of the generated values; must be > 0
 * @return the new matrix, or nullptr if an argument is invalid or allocation fails
 */
double** new2DMatrix(int rows, int cols, long seed, long range) {
    (void) seed;
    // Zero/negative sizes would make the m[0] write below out of bounds,
    // and range is used as a modulus.
    if (rows <= 0 || cols <= 0 || range <= 0) {
        return nullptr;
    }

    const size_t nRows = static_cast<size_t>(rows);
    const size_t nCols = static_cast<size_t>(cols);
    // Reject rows * cols overflow up front; calloc itself rejects the * sizeof part.
    if (nRows > static_cast<size_t>(-1) / nCols) {
        return nullptr;
    }

    auto** m = static_cast<double**>(calloc(nCols, sizeof(double*)));
    if (m == nullptr) {
        return nullptr;
    }
    m[0] = static_cast<double*>(calloc(nRows * nCols, sizeof(double)));
    if (m[0] == nullptr) {
        free(m);  // do not leak the pointer table
        return nullptr;
    }

    // Column pointers: column i starts `rows` doubles after column i - 1.
    for (int col = 1; col < cols; ++col) {
        m[col] = m[col - 1] + rows;
    }

    for (int col = 0; col < cols; ++col) {
        for (int row = 0; row < rows; ++row) {
            m[col][row] = static_cast<double>(rand() % range);
        }
    }
    return m;
}

/**
 * Releases a matrix created by new2DMatrix().
 *
 * @param m matrix returned by new2DMatrix(), or nullptr (then this is a no-op)
 */
void freeMatrix(double** m) {
    // Check the table itself before touching m[0]; the original order
    // dereferenced a null pointer.
    if (m == nullptr) {
        return;
    }
    // m[0] owns the contiguous element buffer; the other column pointers only
    // alias into it, so it is the only element pointer to free.
    free(m[0]);  // free(nullptr) is a no-op
    free(m);
}

/**
 * Writes a column-major matrix (m[col][row]) to stdout, one matrix row per
 * line; every entry is printed with one decimal place followed by ", ".
 * A blank line follows the last row.
 *
 * @param m    column pointers, as produced by new2DMatrix()
 * @param rows number of rows to print
 * @param cols number of columns to print
 */
void printMatrix(double** m, int rows, int cols) {
    int row = 0;
    while (row < rows) {
        // Walk across the columns to emit one logical row.
        for (int col = 0; col < cols; ++col) {
            const double value = m[col][row];
            printf("%.1lf, ", value);
        }
        puts("");
        ++row;
    }
    puts("");
}

/**
 * Benchmarks a naive C = A * B on random matrices and prints the elapsed time
 * of the multiplication, in microseconds with two decimals, to stdout.
 *
 * All matrices are column-major (m[col][row]), as produced by new2DMatrix().
 *
 * @param rowsA, colsA dimensions of A
 * @param rowsB, colsB dimensions of B; colsA must equal rowsB
 * @param seed         passed to srand() once before A, B and C are generated
 * @param range        element values are drawn from [0, range); must be > 0
 * @return 0 on success; -1 on mismatched/invalid dimensions, invalid range,
 *         or allocation failure
 */
int matrixMul(size_t rowsA, size_t colsA, size_t rowsB, size_t colsB, long seed, long range) {
    // A (rowsA x colsA) * B (rowsB x colsB) is defined only when colsA == rowsB.
    // The inner loop reads B at row k < colsA, so a wrong check reads out of bounds.
    if (colsA != rowsB) {
        puts("the cols number of matrixA and the rows number of matrixB not match!\n");
        return -1;
    }

    // new2DMatrix() takes int dimensions; refuse sizes it cannot represent.
    const size_t intMax = static_cast<size_t>(INT_MAX);
    if (rowsA == 0 || colsA == 0 || colsB == 0 ||
        rowsA > intMax || colsA > intMax || colsB > intMax) {
        puts("matrix dimensions must be in [1, INT_MAX]!");
        return -1;
    }
    // range is used as a modulus when generating values.
    if (range <= 0) {
        puts("range must be positive!");
        return -1;
    }

    const size_t rowsC = rowsA;
    const size_t colsC = colsB;

//    printf("rowsA: %zu, colsA: %zu, rowsB: %zu, colsB: %zu, seed: %ld, range: %ld\n\n", rowsA, colsA, rowsB, colsB, seed, range);

    srand(static_cast<unsigned>(seed));

    double** a = new2DMatrix(static_cast<int>(rowsA), static_cast<int>(colsA), seed, range);
    double** b = new2DMatrix(static_cast<int>(rowsB), static_cast<int>(colsB), seed, range);
    // C's initial (random) contents are fully overwritten by the product below.
    double** c = new2DMatrix(static_cast<int>(rowsC), static_cast<int>(colsC), seed, range);

    // Free only what was actually allocated.
    auto release = [](double** m) {
        if (m != nullptr) {
            freeMatrix(m);
        }
    };

    if (a == nullptr || b == nullptr || c == nullptr) {
        puts("failed to allocate matrices!");
        release(a);
        release(b);
        release(c);
        return -1;
    }

    // Time only the multiplication itself.
    auto begin = std::chrono::high_resolution_clock::now();

    // c[i][j] = C(row j, col i) = sum_k A(row j, col k) * B(row k, col i).
    for (size_t j = 0; j < rowsC; j++) {
        for (size_t i = 0; i < colsC; i++) {
            double tmp = 0.0;
            for (size_t k = 0; k < colsA; k++) {
                tmp += b[i][k] * a[k][j];
            }
            c[i][j] = tmp;
        }
    }

    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);

    // Nanoseconds -> microseconds.
    printf("%.2lf\n", elapsed.count() / 1000.0);

//    printMatrix(a, rowsA, colsA);
//    printMatrix(b, rowsB, colsB);
//    printMatrix(c, rowsC, colsC);

    release(a);
    release(b);
    release(c);

    return 0;
}

/**
 * Parses a complete base-10 integer from a command-line argument.
 *
 * @param text argument text
 * @param out  receives the parsed value on success
 * @return false if the text is empty, has trailing characters, or is out of
 *         range for long; true otherwise
 */
static bool parseLongArg(const char* text, long* out) {
    errno = 0;
    char* end = nullptr;
    const long value = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE) {
        return false;
    }
    *out = value;
    return true;
}

/**
 * Entry point.
 *
 * With no arguments: prints the list of square sizes 1..9 and 10, 20, 40, ...,
 * 1280, then runs matrixMul once per size (seed 1024, range 10).
 * Otherwise expects: mul <rowsA> <colsA> <rowsB> <colsB> <seed> <range>.
 */
int main(int argc, char** argv) {
    // No arguments: run the built-in size sweep.
    if (argc == 1) {
        vector<int> rows;
        for (int i = 1; i < 10; ++i) {
            rows.push_back(i);
        }
        for (int i = 10; i < 2000; i *= 2) {
            rows.push_back(i);
        }
//        rows.push_back(5000);
//        rows.push_back(10000);
        for (int i : rows) {
            printf("%d\n", i);
        }

        for (int i : rows) {
            matrixMul(i, i, i, i, 1024, 10);
        }
        return 0;
    }

    if (argc != 7) {
        puts("Wrong parameters! \nUse like this: \n\tmul <rowsA> <colsA> "\
        "<rowsB> <colsB> <seed> <range>\n\tmul 3 3 3 3 1024 10");
        return -1;
    }

    // argv[1..6] = rowsA, colsA, rowsB, colsB, seed, range.
    long values[6];
    for (int idx = 0; idx < 6; ++idx) {
        if (!parseLongArg(argv[idx + 1], &values[idx])) {
            printf("Invalid integer argument: %s\n", argv[idx + 1]);
            return -1;
        }
    }

    const long rowsA = values[0];
    const long colsA = values[1];
    const long rowsB = values[2];
    const long colsB = values[3];
    const long seed = values[4];
    const long range = values[5];

    // Negative dimensions would wrap to huge values when converted to size_t,
    // and range is used as a modulus, so all of them must be positive.
    if (rowsA <= 0 || colsA <= 0 || rowsB <= 0 || colsB <= 0 || range <= 0) {
        puts("Matrix dimensions and range must be positive integers!");
        return -1;
    }

    return matrixMul(static_cast<size_t>(rowsA), static_cast<size_t>(colsA),
                     static_cast<size_t>(rowsB), static_cast<size_t>(colsB),
                     seed, range);
}

posted @ 2021-03-31 14:30  konosubaakua  阅读(266)  评论(0)    收藏  举报