pytorch(2)----基本数据类型与模块

Tensor

  张量,包含单一数据类型元素的矩阵。

 

 

基本内容

1、数据初始化及数据类型转换

2、组合: torch.cat()  按照某一个维度进行拼接,总维度数目不变

                 torch.stack()  按照制定维度进行叠加,新增维度

3、分块:torch.chunk() 指定分块数量

                torch.split() 指定每块数量

4、 索引、torch.masked_select

5、维度变化:.view().resize()  .reshape()

6、.unsqueeze()、.squeeze()

7、Tensor 的排序: .sort()

8、Tensor 的广播机制

示例代码

1、

# 数据类型 及转换
# 默认的数据类型为 torch.FloatTensor
# x 实际为torch.FloatTensor
x = torch.Tensor(2,2)  #2行2列、未初始化
x = torch.rand(2,2)  #均匀分布 [[0,1)
print(x)
print(x.dtype)

# 使用 int()  double()  float() 等直接进行数据类型转换
b = x.double()
print(b.dtype)
print(type(b))

#使用 .type()函数进行 类型转换
c = x.type(torch.IntTensor)
print(c.dtype)
print(c)

#使用 .type_as()函数转换类型更加 方便
d = x.type_as(c)
print(d.dtype)
print(d)

#其他初始化方法
print('\n其他初始化方法')
# 直接给值
c1 = torch.Tensor([[2,3,4],[1,4,5]])
print(c1.shape)
print(c1.dtype)
# ones() eye() zeros()
c2 =torch.eye(5)
print(c2)
print(c2.dtype)
#randn()  标准正太分布 随机数
c3 = torch.randn(4,4)
# torch.arange(start,end, step)  生成一维向量 [start,end)
c4 = torch.arange(1,6,2)
print(c4)

#元素个数
print(c1.numel())

2、

# 组合
# torch.cat()  按照某一个维度进行拼接,总维度数目不变
print("torch.cat()----")
a = torch.Tensor([[1,2],[3,4]])
print(a)

b = torch.Tensor([[5,6],[7,8]])
print(b)
#  按照第一维进行拼接
c1 = torch.cat([a,b],0)
print(c)

# 按照第二维进行拼接
c2 = torch.cat([a,b],1)
print(c)


# 组合
# torch.stack()  按照制定维度进行叠加,新增维度
print("\n torch.stack()----")
print(a)
print(b)
c3 = torch.stack([a,b],0)
print(c3)
print(c3.shape)
c4 = torch.stack([a,b],1)
print(c4)
print(c4.shape)

3、

# 分块
# torch.chunk() 制定分块数量
a = torch.Tensor([[1,2,3],[4,5,6]])
print(a)
print(a.shape)
# 沿着第0维,分成2块
print( torch.chunk(a,2,0) )
print( torch.chunk(a,2,1) ) #分配不均时,前面的个数多于后面

# 分块
# torch.split() 指定每块数量
print(a)
print(a.shape)
# 沿着第0维,每块的个数为2
print( torch.split(a,2,0) )
print( torch.split(a,2,1) )

4、

# 索引
a = torch.Tensor([[0,1],[6,7]])
# 按下标索引
print(a[0])
print(a[0,1])

# 比较 true 为1,false 为0
print(a>0)

# 选择符合条件的元素返回
print(torch.masked_select(a,a>0))
print(a[a>0])

5、

# 维度变化
a = torch.arange(1,5)
print(a)
print(a.view(2,2))

print(a.resize(4,1))
print(a)

print(a.reshape(2,2))

# 原地操作
print(a.resize_(4,1))
print(a)

6、

# 增加维度
# .unsqueeze
a = torch.arange(1,4)
print(a)
print(a.shape)

# 将第0维变成1
b1 = a.unsqueeze(0)
print(b1)
print(b1.shape)

#将第1维变成1
b2 = a.unsqueeze(1)
print(b2)
print(b2.shape)


# 减少维度
c1 = b2.squeeze(1)
print(c1)
print(c1.shape)

7、

# Tensor 的排序
# 按照第0维进行排序,True为降序,false为升序
a = torch.randn(3,3)
print(a)

b=a.sort(0,True)
print("排序结果\n",b[0])
print(a)
print("排序结果索引\n",b[1])
# 按照第1维进行排序,True为降序,false为升序
b1=a.sort(1,True)
print("排序结果\n",b1[0])
print("排序结果索引\n",b1[1])

# .max  .min
print(a)
c = a.max(0)   # 按照第0维: 选出每一列的最大值
print(c[0])
print(c[1])

8、

# Tensor 的广播机制
# 条件: 任一个tensor 至少有一个维度,且从尾到头部遍历整个tensor维度时
a = torch.ones(3,1,2)
b = torch.ones(2,1)
c = a+b
print(c)
print(c.shape)

d = torch.ones(2,3)
#c2 = a+d #error

 

 

 

 

 

     

posted on 2020-02-13 19:09  feihu_h  阅读(389)  评论(0编辑  收藏  举报

导航