张量剪裁
a = torch.rand(3 ,4) * 10 #范围在0~10的3*4张量 a = a.clamp(2, 5) #2 <= a <= 5
Tensor的索引与数据筛选
| torch.where(condition, x, y) | 按照条件从x和y中选出满足条件的元素组成新的tensor |
| torch.gather(input, dim, index, out=None) | 在指定维度上按照索引赋值输出tensor |
| torch.index_select(input, dim, index, out=None) | 按照指定索引输出tensor |
| torch.masked_select(input,mask, out=None) | 按照mask输出tensor,输出为向量 |
| torch.take(input, indices) | 将输入看成1D-tensor,按照索引得到输出tensor |
| torch.nonzero(input, out=None) | 输出非0元素的坐标 |
torch.where(condition, x, y)
使用condition做二值化
import torch a = torch.rand(4, 4) b = torch.rand(4, 4) print(a,b) # 当符合条件时使用a成员,否则使用对应index的b成员 out = torch.where(a>0.5, a, b) print(out)
结果
tensor([[0.3003, 0.1368, 0.0918, 0.8426], [0.5309, 0.3278, 0.1919, 0.9897], [0.1696, 0.7081, 0.3244, 0.3783], [0.1589, 0.0064, 0.0574, 0.3292]]) tensor([[0.1048, 0.6781, 0.0046, 0.9271], [0.6896, 0.8376, 0.3256, 0.4376], [0.0561, 0.3164, 0.9954, 0.8089], [0.4044, 0.2543, 0.3587, 0.1339]]) tensor([[0.1048, 0.6781, 0.0046, 0.8426], [0.5309, 0.8376, 0.3256, 0.9897], [0.0561, 0.7081, 0.9954, 0.8089], [0.4044, 0.2543, 0.3587, 0.1339]])
torch.index_select(input, dim, index, out=None)
import torch a = torch.rand(4, 4) print(a) # 使用a中0、3、2行构成新的tensor out = torch.index_select(a, dim=0, index=torch.tensor([0, 3, 2])) print(out) # 3*4shape
torch.gather(input, dim, index, out=None)
import torch a = torch.linspace(1, 16, 16).view(4, 4) print(a)
# index对应的是列的维度 out = torch.gather(a,dim=0, index=torch.tensor([[0,1,1,1], [0,1,2,2], [0,1,3,3]])) print(out) #3*4
输出

torch.masked_select(input,mask, out=None)
import torch a = torch.linspace(1, 16, 16).view(4, 4) print(a) mask = torch.gt(a,8) print(mask) out = torch.masked_select(a,mask) print(out)

torch.take(input, indices)
import torch a = torch.linspace(1, 16, 16).view(4, 4) b = torch.take(a,index=torch.tensor([0,15,13,10])) print(b)

torch.nonzero(input, out=None)
import torch a = torch.tensor([[0,1,2,0],[2,3,0,1]]) out = torch.nonzero(a) print(out)

Tensor的组合和拼接
| torch.cat(seq, dim=0, out=None) | 按照已经存在的维度进行拼接 |
| torch.stack(seq, dim=O, out=None) | 按照新的维度进行拼接 |
| torch.gather(input, dim, index, out=None) | 在指定维度上按照索引赋值输出tensor |
torch.cat(seq, dim=0, out=None)
import torch
a = torch.zeros((2,4))
b = torch.ones((2,4))
out0 = torch.cat((a,b),dim=0)
out1 = torch.cat((a,b),dim=1)
print(out0)
print(out1)

torch.stack(seq, dim=O, out=None)
import torch a = torch.linspace(1,6,6).view(2,3) b = torch.linspace(7,12,6).view(2,3) out = torch.stack((a,b),dim=0) print(out) #2*2*3 out = torch.stack((a,b),dim=1) print(out) #2*2*3

Tensor的切片
| torch.chunk(tensor, chunks, dim=O) | 按照某个维度平均分块(最后一个可能小于平均值) |
| torch.split(tensor, split_size_or_sections, dim=0) | (更常用)按照某个维度依照第二个参数给出的list或者int进行分割tensor |
torch.chunk(tensor, chunks, dim=O)
import torch a = torch.rand((3,4)) print(a) out = torch.chunk(a,2,dim=0) print(out[0], out[0].shape) print(out[1], out[1].shape)

torch.split(tensor, split_size_or_sections, dim=0)
int split
import torch a = torch.rand((10,4)) out = torch.split(a,3,dim=0) for i in out: print(i,i.shape)

list split
out = torch.split(a,[1,3,6],dim=0) for i in out: print(i,i.shape)

Tensor的变形操作
| torch.reshape(input, shape) | |
| torch.t(input): | 只针对2D tensor转置 |
| torch.transpose(input, dim0, dim1): | 交换两个维度 |
| torch.squeeze(input, dim=None, out=None): | 去除那些维度大小为1的维度 |
| torch.unbind(tensor, dim=0): | 去除某个维度,返回tuple |
| torch.unsqueeze(input, dim, out=None): | 在指定位置添加维度 |
| torch.flip(input, dims): | 按照给定维度翻转张量 |
| torch.rot90(input, k, dims): | 按照指定维度和旋转次数进行张量旋转,逆时针90度旋转 |
torch.transpose(input, dim0, dim1):
import torch a = torch.rand(1,2,3) out = torch.transpose(a,0,1) print(out,out.shape)

torch.squeeze(input, dim=None, out=None):
import torch a = torch.rand(1,2,3) out = torch.squeeze(a) print(out,out.shape)

torch.unsqueeze(input, dim, out=None):
import torch a = torch.rand(1,2,3) out = torch.unsqueeze(a, -1) print(out,out.shape)

torch.unbind(tensor, dim=0):
import torch a = torch.rand(1,2,3) out = torch.unbind(a, dim=2) print(out)

Tensor的填充操作
torch.full()
torch.full((2,3),3.14) # tensor([[ 3.14,3.14,3.14],[3.14,3.14,3.14]])
Tensor的频谱操作
时域信号转频域信号
| torch.fft(input, signal_ndim, normalized=False) |
| torch.ifft(input, signal_ndim, normalized=False) |
| torch.rfft(input, signal_ndim, normalized=False, onesided=True) |
| torch.irfft(input, signal_ndim, normalized=False, onesided=True) |
| torch.stft(signa, frame_length, hop,...) |
浙公网安备 33010602011771号