张量剪裁

a = torch.rand(3 ,4) * 10 #范围在0~10的3*4张量
a = a.clamp(2, 5) #2 <= a <= 5

Tensor的索引与数据筛选

torch.where(condition, x, y) 按照条件从x和y中选出满足条件的元素组成新的tensor
torch.gather(input, dim, index, out=None) 在指定维度上按照索引赋值输出tensor
torch.index_select(input, dim, index, out=None) 按照指定索引输出tensor
torch.masked_select(input,mask, out=None) 按照mask输出tensor,输出为向量
torch.take(input, indices) 将输入看成1D-tensor,按照索引得到输出tensor
torch.nonzero(input, out=None) 输出非0元素的坐标

torch.where(condition, x, y)

使用condition做二值化

import torch
a = torch.rand(4, 4)
b = torch.rand(4, 4)
print(a,b)
# 当符合条件时使用a成员,否则使用对应index的b成员
out = torch.where(a>0.5, a, b)
print(out)

结果

tensor([[0.3003, 0.1368, 0.0918, 0.8426],
        [0.5309, 0.3278, 0.1919, 0.9897],
        [0.1696, 0.7081, 0.3244, 0.3783],
        [0.1589, 0.0064, 0.0574, 0.3292]]) 
tensor([[0.1048, 0.6781, 0.0046, 0.9271],
        [0.6896, 0.8376, 0.3256, 0.4376],
        [0.0561, 0.3164, 0.9954, 0.8089],
        [0.4044, 0.2543, 0.3587, 0.1339]])
tensor([[0.1048, 0.6781, 0.0046, 0.8426],
        [0.5309, 0.8376, 0.3256, 0.9897],
        [0.0561, 0.7081, 0.9954, 0.8089],
        [0.4044, 0.2543, 0.3587, 0.1339]])

torch.index_select(input, dim, index, out=None)

import torch
a = torch.rand(4, 4)
print(a)
# 使用a中0、3、2行构成新的tensor
out = torch.index_select(a, dim=0, index=torch.tensor([0, 3, 2]))
print(out) # 3*4shape

torch.gather(input, dim, index, out=None)

import torch
a = torch.linspace(1, 16, 16).view(4, 4)
print(a)
# index对应的是列的维度 out
= torch.gather(a,dim=0, index=torch.tensor([[0,1,1,1], [0,1,2,2], [0,1,3,3]])) print(out) #3*4

输出

torch.masked_select(input,mask, out=None)

import torch
a = torch.linspace(1, 16, 16).view(4, 4)
print(a)
mask = torch.gt(a,8)
print(mask)
out = torch.masked_select(a,mask)
print(out)

 torch.take(input, indices)

import torch
a = torch.linspace(1, 16, 16).view(4, 4)
b = torch.take(a,index=torch.tensor([0,15,13,10]))
print(b)

torch.nonzero(input, out=None)

import torch
a = torch.tensor([[0,1,2,0],[2,3,0,1]])
out = torch.nonzero(a)
print(out)

 Tensor的组合和拼接

torch.cat(seq, dim=0, out=None) 按照已经存在的维度进行拼接
torch.stack(seq, dim=O, out=None) 按照新的维度进行拼接
torch.gather(input, dim, index, out=None) 在指定维度上按照索引赋值输出tensor

torch.cat(seq, dim=0, out=None)

import torch
a = torch.zeros((2,4))
b = torch.ones((2,4))
out0 = torch.cat((a,b),dim=0)
out1 = torch.cat((a,b),dim=1)
print(out0)
print(out1)

 torch.stack(seq, dim=O, out=None)

import torch
a = torch.linspace(1,6,6).view(2,3)
b = torch.linspace(7,12,6).view(2,3)
out = torch.stack((a,b),dim=0)
print(out) #2*2*3
out = torch.stack((a,b),dim=1)
print(out) #2*2*3

Tensor的切片

torch.chunk(tensor, chunks, dim=O) 按照某个维度平均分块(最后一个可能小于平均值)
torch.split(tensor, split_size_or_sections, dim=0) (更常用)按照某个维度依照第二个参数给出的list或者int进行分割tensor

torch.chunk(tensor, chunks, dim=O)

import torch
a = torch.rand((3,4))
print(a)
out = torch.chunk(a,2,dim=0)
print(out[0], out[0].shape)
print(out[1], out[1].shape)

torch.split(tensor, split_size_or_sections, dim=0)

int split

import torch
a = torch.rand((10,4))
out = torch.split(a,3,dim=0)
for i in out:
    print(i,i.shape)

 list split

out = torch.split(a,[1,3,6],dim=0)
for i in out:
    print(i,i.shape)

 Tensor的变形操作

torch.reshape(input, shape)  
torch.t(input): 只针对2D tensor转置
torch.transpose(input, dim0, dim1): 交换两个维度
torch.squeeze(input, dim=None, out=None): 去除那些维度大小为1的维度
torch.unbind(tensor, dim=0): 去除某个维度,返回tuple
torch.unsqueeze(input, dim, out=None): 指定位置添加维度
torch.flip(input, dims): 按照给定维度翻转张量
torch.rot90(input, k, dims): 按照指定维度和旋转次数进行张量旋转,逆时针90度旋转

torch.transpose(input, dim0, dim1):

import torch
a = torch.rand(1,2,3)
out = torch.transpose(a,0,1)
print(out,out.shape)

 torch.squeeze(input, dim=None, out=None):

import torch
a = torch.rand(1,2,3)
out = torch.squeeze(a)
print(out,out.shape)

 torch.unsqueeze(input, dim, out=None):

import torch
a = torch.rand(1,2,3)
out = torch.unsqueeze(a, -1)
print(out,out.shape)

 torch.unbind(tensor, dim=0):

import torch
a = torch.rand(1,2,3)
out = torch.unbind(a, dim=2)
print(out)

 Tensor的填充操作

torch.full()

torch.full((2,3),3.14)
# tensor([[ 3.14,3.14,3.14],[3.14,3.14,3.14]])

 Tensor的频谱操作

时域信号转频域信号

torch.fft(input, signal_ndim, normalized=False)
torch.ifft(input, signal_ndim, normalized=False)
torch.rfft(input, signal_ndim, normalized=False, onesided=True)
torch.irfft(input, signal_ndim, normalized=False, onesided=True)
torch.stft(signa, frame_length, hop,...)