torch一些API
定义张量
X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
张量转置
X.t()
定义全0、全1张量
torch.zeros((6,8))
torch.ones((6, 8))
单位张量
torch.eye(n , m , out) //n:行数 m:列数(默认为None) out:输出类型(默认为None)
example
torch.eye(3)
>>tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
张量数据复制
torch.repeat()
x = torch.tensor([1, 2, 3])
x.repeat(4, 2), x.repeat(4, 2).shape, x.repeat(4, 2, 1).shape, x.repeat(2)
>>(tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]]),
torch.Size([4, 6]),
torch.Size([4, 2, 3]),
tensor([1, 2, 3, 1, 2, 3]))
torch.repeat_interleave()
x = torch.tensor([1, 2, 3])
x.repeat_interleave(2)
>>tensor([1, 1, 2, 2, 3, 3])
y = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(y, 2)
>>tensor([1, 1, 2, 2, 3, 3, 4, 4])
torch.repeat_interleave(y, 3, dim=1)
>>tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
>>tensor([[1, 2],
[3, 4],
[3, 4]])
批量矩阵乘法
torch.bmm:对应通道上的张量进行乘法运算
只在三维时可用
a = torch.randn((2,2,5))
b = torch.randn((2,5,3))
c = torch.bmm(a,b)
c,c.shape
>>(tensor([[[ 0.0485, -0.5363, 3.2399],
[ 1.1522, -0.4762, 1.7790]],
[[ 0.3548, 0.2388, 0.6923],
[ 2.5418, -1.7123, 7.4894]]]),
torch.Size([2, 2, 3]))
张量维度增加
unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度
example:在weights的1处增加维度,在values末尾增加维度
weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
weights.unsqueeze(1).shape,values.unsqueeze(-1).shape
>>(torch.Size([2, 1, 10]), torch.Size([2, 10, 1]))