【深度学习基础】PyTorch Tensor生成方式及复制方式详解

PyTorch Tensor生成方式及复制方法详解

在PyTorch中,Tensor的创建和复制是深度学习开发的基础操作。本文将全面总结Tensor的各种生成方式,并深入分析不同复制方法的区别。


一、Tensor的生成方式

(一)从Python列表/元组创建

import torch
# 直接创建Tensor
t1 = torch.tensor([1, 2, 3]) # 整型Tensor
t2 = torch.tensor([[1.0, 2], [3, 4]]) # 浮点型Tensor

(二)从NumPy数组创建

import numpy as np
arr = np.array([1, 2, 3])
t = torch.from_numpy(arr) # 共享内存

(三)特殊初始化方法

zeros = torch.zeros(2, 3) # 全0矩阵
ones = torch.ones(2, 3) # 全1矩阵
rand = torch.rand(2, 3) # [0,1)均匀分布
randn = torch.randn(2, 3) # 标准正态分布
arange = torch.arange(0, 10, 2) # 0-10步长为2

(四)从现有Tensor创建

x = torch.tensor([1, 2, 3])
x1 = x.new_tensor([4, 5, 6]) # 新Tensor(复制数据)
x2 = torch.zeros_like(x) # 形状相同,全0
x3 = torch.randn_like(x) # 形状相同,随机值

(五)高级初始化方法

eye = torch.eye(3) # 3x3单位矩阵
lin = torch.linspace(0, 1, 5) # 0-1等分5份
log = torch.logspace(0, 2, 3) # 10^0到10^2等分3份

二、复制方法对比

(一) torch.tensor() vs torch.from_numpy()

方法数据源内存共享梯度传递数据类型
torch.tensor()Python数据不共享支持自动推断
torch.from_numpy()NumPy数组共享不支持保持一致
# 示例:内存共享验证
arr = np.array([1, 2, 3])
t = torch.from_numpy(arr)
arr[0] = 99 # 修改NumPy数组
print(t) # tensor([99, 2, 3]),同步变化

(二) .clone() vs .copy_() vs copy.deepcopy()

方法内存共享梯度传递计算图保留使用场景
.clone()不共享保留梯度保留计算图需要梯度回传
.copy_()目标共享不保留破坏计算图高效覆盖数据
copy.deepcopy()不共享不保留不保留完全独立拷贝
# 示例:梯度传递对比
x = torch.tensor([1.], requires_grad=True)
y = x.clone()
z = torch.tensor([2.], requires_grad=True)
z.copy_(x) # 覆盖z的值
y.backward() # 正常回传梯度到x
# z.backward() # 报错!copy_()破坏计算图

(三) 深度拷贝(Deep Copy)

import copy
orig = torch.tensor([1, 2, 3])
deep_copied = copy.deepcopy(orig) # 完全独立拷贝

三、核心区别总结

  1. 内存共享

    • from_numpy() 与NumPy共享内存
    • 视图操作(如view()/切片)共享内存
    • 其他方法均创建独立副本
  2. 梯度处理

    • .clone() 唯一保留梯度计算图
    • copy_() 会破坏目标Tensor的计算图
    • torch.tensor() 创建新计算图
  3. 使用场景

    • 需要梯度回传:使用.clone()
    • 高效数据覆盖:使用.copy_()
    • 完全独立拷贝:使用copy.deepcopy()
    • 与NumPy交互:使用from_numpy()/numpy()

四、最佳实践建议

  1. 优先使用torch.tensor()创建新Tensor
  2. 需要从NumPy导入数据且避免复制时用from_numpy()
  3. 在计算图中复制数据时必须使用.clone()
  4. 需要覆盖现有Tensor数据时使用.copy_()
  5. 调试时注意内存共享可能导致的意外修改
# 正确梯度传递示例
x = torch.tensor([1.], requires_grad=True)
y = x.clone() ** 2 # 保留计算图
y.backward() # 梯度可回传到x
posted @ 2025-08-22 10:47  yfceshi  阅读(17)  评论(0)    收藏  举报