PyTorch中的张量存储方式以及view()和reshape()操作

1. PyTorch张量的底层存储方式

在PyTorch中,张量的头信息(包括张量名、形状等信息)和实际的数据是分开存储的。在数据区,无论张量的维度、形状如何,都是按照一维、连续的方式进行存储的,可以通过张量的storage()来查看张量的数据区:

import torch
a = torch.arange(9).view(3, 3)
print(f"storage of a: {a.storage()}")

1.1 张量的size属性和stride属性

张量的size指的是张量每一维度的大小,张量的步长属性stride可以理解为从索引中的一个维度跨到下一个维度中间的跨度:

对于第一个维度(即行),从上一行跨越到下一行需要跨越三个元素,而第二个维度(即列),从前一列跨越到(同一行的)下一列只需跨越一个元素,故astride(3,1)

1.2 连续性条件

张量是连续的是指张量满足以下条件:

stride[i] = stride[i + 1] * size[i + 1]

2. view()和reshape()

张量的view()和`reshape()的作用都是将张量转换为指定形状,然而,view()要求张量在装换前后满足连续型条件

b = a.permute(1, 0)	# b是a的转置,此时a和b共享存储区
print(f"size of a: {a.size()}, size of b: {b.size()}")
print(f"stride of a: {a.stride()}, stride of b: {b.stride()}")	# b不满足连续型条件
print(f"storage_ptr of a: {a.storage().data_ptr()}, storage_ptr of b: {b.storage().data_ptr()}")

此时,若对b使用view(),则会因不满足连续性条件而报错:

b.view(1, 9)

解决方法是先使用contiguous()将其转换到满足连续性条件,该函数会开辟新的内存空间,对数据进行深拷贝:

b = b.contiguous()
print(f"storage_ptr of a: {a.storage().data_ptr()}, storage_ptr of b: {b.storage().data_ptr()}")	# a, b的data_ptr不同
b = b.view(1, 9)
print(b)

reshape()也可以达到上述的效果:

  • 当张量满足连续性条件时,等价于b.view()
  • 当张量不满足连续性条件时,等价于b.contiguous().view()
c = a.permute(1, 0)
c = c.reshape(1, 9)
print(c)

本文代码详见代码

posted @ 2024-01-10 14:11  Derrick97  阅读(288)  评论(0)    收藏  举报