pytorch基础
In [1]:
import torch
In [2]:
# 查询torch版本
print(torch.__version__)
1.11.0+cpu
In [3]:
x = torch.rand(5,3)
y = torch.rand(5,3)
In [4]:
# 矩阵相加
z = torch.add(x,y)
z
Out[4]:
tensor([[0.5727, 0.6683, 1.4929],
[1.5323, 1.7757, 0.3782],
[1.5451, 0.4119, 1.0949],
[0.6613, 0.9304, 1.2568],
[1.4187, 0.8155, 0.9888]])
In [5]:
# 矩阵索引
print(z[1,2],z[:,2],z[1,:])
tensor(0.3782) tensor([1.4929, 0.3782, 1.0949, 1.2568, 0.9888]) tensor([1.5323, 1.7757, 0.3782])
In [6]:
# view矩阵变换
print(z.view([3,5]))
tensor([[0.5727, 0.6683, 1.4929, 1.5323, 1.7757],
[0.3782, 1.5451, 0.4119, 1.0949, 0.6613],
[0.9304, 1.2568, 1.4187, 0.8155, 0.9888]])
In [7]:
# 自动计算维度
print(z.view(-1,5))
tensor([[0.5727, 0.6683, 1.4929, 1.5323, 1.7757],
[0.3782, 1.5451, 0.4119, 1.0949, 0.6613],
[0.9304, 1.2568, 1.4187, 0.8155, 0.9888]])
浙公网安备 33010602011771号