容器中 PyTorch 的 CPU 速度很慢,原因找到了

  • 容器中 PyTorch 的 CPU 速度很慢,原因找到了
# import numpy as np
# import time
# import torch
# import os

# os.environ['OMP_NUM_THREADS'] = '8'  # 增加线程数
# torch.set_num_threads(8)  # 设置PyTorch线程数

# # 创建测试矩阵
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# print(f"使用设备: {device}")
# B = torch.randn(2000, 2000).to(device)  # 移动到所选设备(有GPU时为GPU,否则为CPU)

# # 预热GPU
# if device == 'cuda':
#     _ = torch.svd(torch.randn(100,100).to(device))

# # 计时SVD
# start = time.time()
# with torch.no_grad():  # 禁用梯度计算
#     x = torch.svd(B)
# end = time.time()

# print(f"SVD耗时: {end - start:.2f}秒")

import time
import numpy as np
import torch
import os

def diagnose_torch_svd(size=1000):
    """Compare the CPU time of one reduced SVD in NumPy and in PyTorch.

    Prints the current PyTorch / OpenMP / MKL thread settings, times a single
    reduced SVD of the same random ``size`` x ``size`` float64 matrix with
    each library, and prints tuning hints when PyTorch takes more than twice
    as long as NumPy.

    Args:
        size: Number of rows and columns of the random test matrix. Defaults
            to 1000, the value that used to be hard-coded.

    Raises:
        ValueError: If ``size`` is smaller than 1.
    """
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size!r}")

    print("=== PyTorch SVD性能诊断 ===")

    # 1. Report the threading configuration that is currently in effect.
    print("\n1. 系统配置:")
    print(f"PyTorch线程数: {torch.get_num_threads()}")
    print(f"PyTorch interop线程数: {torch.get_num_interop_threads()}")
    print(f"OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS', '未设置')}")
    print(f"MKL_NUM_THREADS: {os.environ.get('MKL_NUM_THREADS', '未设置')}")

    # 2. Build identical inputs for both libraries; the tensor wraps a copy
    #    so the two benchmarks never share a buffer.
    print("\n2. 性能测试:")
    numpy_array = np.random.randn(size, size).astype(np.float64)
    torch_tensor = torch.from_numpy(numpy_array.copy())

    # 3. NumPy baseline. perf_counter() is monotonic and high resolution,
    #    unlike time.time(), which can jump with wall-clock adjustments.
    start = time.perf_counter()
    np.linalg.svd(numpy_array, full_matrices=False)
    numpy_time = time.perf_counter() - start

    # 4. PyTorch measurement. torch.linalg.svd replaces the deprecated
    #    torch.svd and, like NumPy, returns (U, S, Vh) for the reduced SVD.
    start = time.perf_counter()
    torch.linalg.svd(torch_tensor, full_matrices=False)
    torch_time = time.perf_counter() - start

    # Guard the ratio: a zero baseline would otherwise raise ZeroDivisionError.
    ratio = torch_time / numpy_time if numpy_time > 0 else float("inf")

    print(f"NumPy SVD: {numpy_time:.4f}s")
    print(f"PyTorch SVD: {torch_time:.4f}s")
    print(f"速度比: {ratio:.2f}x")

    # 5. Tuning hints, shown only when PyTorch is more than 2x slower.
    print("\n3. 优化建议:")
    if torch_time > 2 * numpy_time:
        print("⚠️  PyTorch SVD明显慢于NumPy,建议:")
        print("   - 设置 torch.set_num_threads(4)")
        print("   - 检查环境变量 OMP_NUM_THREADS 和 MKL_NUM_THREADS")
        print("   - 考虑对大型矩阵使用NumPy后端")
    else:
        print("✅ PyTorch SVD性能正常")

# Baseline run with the default thread settings (currently disabled).
# diagnose_torch_svd()

# Re-run the benchmark after applying the thread-count tuning.
print("", "=" * 50, sep="\n")
print("应用优化后:")

# Apply the tuning. The environment-variable variant is kept disabled below.
# NOTE(review): these variables are presumably read when the threading runtime
# initialises, so setting them after `import torch` may have no effect — confirm.
# os.environ['OMP_NUM_THREADS'] = '4'
# os.environ['MKL_NUM_THREADS'] = '4'
torch.set_num_threads(8)

diagnose_torch_svd()
posted @ 2025-11-05 16:57  bregman  阅读(5)  评论(0)    收藏  举报