张量的集合操作
点击查看代码
# -*- coding: utf-8 -*-
# @Author  : 钱力
# @Time    : 2024/7/26 14:24
import torch

# 合并操作
A = torch.arange(0, 16).view(2, 8)
B = 10 * A
C = torch.cat([A, B], dim=1)  # 将矩阵根据特定维度进行缝合
print(C)
D = torch.stack([A, B], dim=1)  # 通过增加维度来融合矩阵,这种融合方式一般是时间序列采用
print(D)

# 切分操作
print('=====================================================')
a = torch.arange(10).reshape(5, 2)
print(torch.chunk(a, 2))  # 根据索引进行切分
a = torch.arange(10).reshape(5, 2)
print(torch.split(a, 2))  # 根据长度进行切分
a = torch.arange(10).reshape(5, 2)
print(torch.split(a, [3, 1, 1]))  # 根据长度进行切分

# 现有张量沿着值为1的维度扩展到新的维度n,输出重复n次
a = torch.tensor([[[1, 2, 3], [4, 5, 6]]])
print(a.size())
print(a)
a = a.expand(2, 2, 3)  # 仅限于size=1的维度
print(a.size())
print(a)

# 改变张量的维度
a = torch.arange(9).reshape(3, 3)
print('a:', a)
b = a.permute(1, 0)  # 维度转换,但不改变索引方式
print('b:', b)
print(b.stride())  # 张量的索引方式
print(b.is_contiguous())  # 是否连续,视图索引和内存索引是否一致
c = b.contiguous()  # 强制转换为一致
print(c.stride())
print(c.is_contiguous())
# a 和 b共享内存,但c不是
print('ptr of storage of a', a.untyped_storage().data_ptr())
print('ptr of storage of b', b.untyped_storage().data_ptr())
print('ptr of storage of c', c.untyped_storage().data_ptr())

# reshape和view区别
a = torch.arange(9).reshape(3, 3)
b = a.permute(1, 0)
print(b.reshape(9))
# print(b.view(9))  # 如果视图索引和内存索引不一致,就会报错
print(b.contiguous().view(9))

posted on 2024-07-26 15:58  凯申物流——  阅读(16)  评论(0)    收藏  举报