torch一些API

定义张量

X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])

张量转置

X.t()

定义全0、全1张量

torch.zeros((6,8))

torch.ones((6, 8))

单位张量

torch.eye(n , m , out) //n:行数 m:列数(默认为None) out:输出类型(默认为None)

example

torch.eye(3)

>>tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

张量数据复制

torch.repeat()

x = torch.tensor([1, 2, 3])
x.repeat(4, 2), x.repeat(4, 2).shape, x.repeat(4, 2, 1).shape, x.repeat(2)

>>(tensor([[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3],
         [1, 2, 3, 1, 2, 3]]),
 torch.Size([4, 6]),
 torch.Size([4, 2, 3]),
 tensor([1, 2, 3, 1, 2, 3]))

torch.repeat_interleave()

x = torch.tensor([1, 2, 3])
x.repeat_interleave(2)

>>tensor([1, 1, 2, 2, 3, 3])
y = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(y, 2)

>>tensor([1, 1, 2, 2, 3, 3, 4, 4])

torch.repeat_interleave(y, 3, dim=1)

>>tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])

torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)

>>tensor([[1, 2],
        [3, 4],
        [3, 4]])

批量矩阵乘法

torch.bmm:对应通道上的张量进行乘法运算

只在三维时可用

a = torch.randn((2,2,5))
b = torch.randn((2,5,3))
c = torch.bmm(a,b)
c,c.shape

>>(tensor([[[ 0.0485, -0.5363,  3.2399],
          [ 1.1522, -0.4762,  1.7790]],
 
         [[ 0.3548,  0.2388,  0.6923],
          [ 2.5418, -1.7123,  7.4894]]]),
 torch.Size([2, 2, 3]))

张量维度增加

unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度

example:在weights的1处增加维度,在values末尾增加维度

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
weights.unsqueeze(1).shape,values.unsqueeze(-1).shape

>>(torch.Size([2, 1, 10]), torch.Size([2, 10, 1]))
posted @ 2022-05-10 17:58  老裴菌  阅读(240)  评论(0)    收藏  举报