【Deep Learning】Deep Learning Framework Matrix Multiplication Speed Comparison (numpy & torch & mlx)

✨ Experimental Environment

SoC: Apple M1 Pro (10-core CPU, 16-core GPU)

RAM: 16 GB

numpy: 2.2.6

torch: 2.8.0

mlx: 0.28.0
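
Before running the benchmark it is worth confirming that both GPU backends are actually usable. A minimal sanity check along these lines (this snippet is an addition for illustration, not part of the original post's code):

import numpy as np
import torch
import mlx.core as mx

# library versions used in this experiment
print("numpy:", np.__version__)
print("torch:", torch.__version__)
print("mlx:", mx.__version__)

# torch needs the Metal Performance Shaders (MPS) backend for the 'mps' device
print("MPS available:", torch.backends.mps.is_available())

# mlx targets the Apple GPU by default on Apple silicon
print("mlx default device:", mx.default_device())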


✨ FP16 & FP32

| Framework | Hardware platform | Speed | Main reason |
| --- | --- | --- | --- |
| PyTorch | GPU | FP16 significantly faster than FP32 | Modern GPUs (e.g. via Tensor Cores) provide hardware acceleration for FP16. |
| MLX | Apple Silicon (GPU/NPU) | FP16 significantly faster than FP32 | Optimized for Apple chips, making full use of their parallel compute capability. |
| NumPy | CPU | No significant difference | CPUs have no dedicated FP16 hardware acceleration and rely on conventional compute units. |

NumPy is not really a deep learning framework and is included only for reference.
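
The gap in the table can be spot-checked without running the full benchmark below. A minimal sketch using torch on the MPS backend (the matrix size 4096 and repeat count 10 are arbitrary choices for illustration, not values from the experiment):

import time
import torch

def time_matmul(dtype, N=4096, reps=10):
    # two square matrices on the Apple GPU via the MPS backend
    a = torch.rand(N, N, dtype=dtype, device='mps')
    b = torch.rand(N, N, dtype=dtype, device='mps')
    _ = a @ b                  # warm-up run
    torch.mps.synchronize()    # wait for queued GPU work to finish

    start = time.perf_counter()
    for _ in range(reps):
        _ = a @ b
    torch.mps.synchronize()    # MPS dispatch is asynchronous
    return time.perf_counter() - start

print("fp16:", time_matmul(torch.float16))
print("fp32:", time_matmul(torch.float32))

On hardware with FP16 acceleration the first number should come out clearly smaller.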


✨ Experimental Results

Matrix multiplication was benchmarked with torch and mlx in FP16, and with numpy in FP32.

The results are as follows:

mlx FP16: 3m 35s

torch FP16: 6m 0s

numpy FP32: 12m 33s


I had originally planned to test both FP16 and FP32 for all of mlx, torch, and numpy, but the runs without hardware acceleration turned out to be impractically slow: the mlx FP32 run had still not finished after 13h 38m.
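
For context, an N x N matrix multiplication costs roughly 2N^3 floating-point operations, and since the loop runs sizes N = 1, 2, ..., 2^15 = 32768, the largest size dominates the total (each halving of N cuts the cost by 8x). That allows a back-of-the-envelope effective-throughput estimate from the measured times, in the spirit of the roofline reference at the end (the arithmetic below is an addition, using only the timings reported above):

# 10 repetitions of an N x N matmul for each N = 2**0 .. 2**15
flops = sum(10 * 2 * (2**i) ** 3 for i in range(16))

# measured wall-clock times from the runs above, in seconds
for name, seconds in [("mlx fp16", 3 * 60 + 35),
                      ("torch fp16", 6 * 60),
                      ("numpy fp32", 12 * 60 + 33)]:
    print(f"{name}: ~{flops / seconds / 1e12:.1f} TFLOPS effective")

This puts mlx FP16 at roughly 3 to 4 TFLOPS on the M1 Pro's 16-core GPU.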



✨ Code Implementation

import gc
import time
import numpy as np
import mlx.core as mx
import torch


def matrix_multiplication_mlx(dtype=mx.float16):
    R = 10
    duration = 0

    for i in range(16):
        N = 2**i

        a = mx.random.uniform(-1.0, 1.0, [N, N], dtype=dtype)
        b = mx.random.uniform(-1.0, 1.0, [N, N], dtype=dtype)
        # MLX is lazy, so force the inputs to materialize before timing
        mx.eval(a, b)
        # warm up the GPU so the first timed run is not an outlier
        mx.eval(a @ b)

        for _ in range(R):
            start_time = time.perf_counter()
            c = a @ b
            mx.eval(c)  # block until the multiplication has actually run
            duration += time.perf_counter() - start_time

        del a, b, c
        gc.collect()
    mx.clear_cache()

    return duration


def matrix_multiplication_torch(dtype=torch.float16):
    R = 10
    duration = 0

    for i in range(16):
        N = 2**i

        a = torch.rand(N, N, dtype=dtype, device='mps')
        b = torch.rand(N, N, dtype=dtype, device='mps')
        # warm up the GPU and wait for the warm-up to complete
        _ = a @ b
        torch.mps.synchronize()

        for _ in range(R):
            start_time = time.perf_counter()
            c = a @ b
            torch.mps.synchronize()  # MPS dispatch is asynchronous
            duration += time.perf_counter() - start_time

        del a, b, c
        gc.collect()
    torch.mps.empty_cache()

    return duration


def matrix_multiplication_numpy(dtype=np.float16):
    R = 10
    duration = 0

    for i in range(16):
        N = 2**i

        # create the NumPy arrays (np.random.rand has no dtype argument)
        a = np.random.rand(N, N).astype(dtype)
        b = np.random.rand(N, N).astype(dtype)

        # NumPy runs synchronously on the CPU, so there is no warm-up
        # or device synchronization step before timing

        for _ in range(R):
            start_time = time.perf_counter()
            c = a @ b
            duration += time.perf_counter() - start_time

        del a, b, c
        gc.collect()

    return duration


def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = seconds % 60

    if hours > 0:
        return f"{hours}h {minutes}m {seconds:.0f}s"
    elif minutes > 0:
        return f"{minutes}m {seconds:.0f}s"
    else:
        return f"{seconds:.0f}s"


matrix_mul_mlx_fp16 = matrix_multiplication_mlx(dtype=mx.float16)
print(f"matrix multiplication(mx.float16) cost: {format_time(matrix_mul_mlx_fp16)} using mlx")

matrix_mul_torch_fp16 = matrix_multiplication_torch(dtype=torch.float16)
print(f"matrix multiplication(torch.float16) cost: {format_time(matrix_mul_torch_fp16)} using torch")

matrix_mul_numpy_fp32 = matrix_multiplication_numpy(dtype=np.float32)
print(f"matrix multiplication(np.float32) cost: {format_time(matrix_mul_numpy_fp32)} using numpy")

matrix_mul_mlx_fp32 = matrix_multiplication_mlx(dtype=mx.float32)
print(f"matrix multiplication(mx.float32) cost: {format_time(matrix_mul_mlx_fp32)} using mlx")

matrix_mul_torch_fp32 = matrix_multiplication_torch(dtype=torch.float32)
print(f"matrix multiplication(torch.float32) cost: {format_time(matrix_mul_torch_fp32)} using torch")

matrix_mul_numpy_fp16 = matrix_multiplication_numpy(dtype=np.float16)
print(f"matrix multiplication(np.float16) cost: {format_time(matrix_mul_numpy_fp16)} using numpy")

✨ References

https://wangkuiyi.github.io/roofline.html


⭐ Please credit the source when reposting

Author: 双份浓缩馥芮白

Original link: https://www.cnblogs.com/Flat-White/p/19046521

All rights reserved. Please credit the source when reposting.
