torch.meshgrid函数
在此记录下torch.meshgrid的用法,该函数常常用于生成二维的网格:
|
1
2
3
4
5
6
7
8
9
10
11
|
>>> x = torch.tensor([1, 2, 3])>>> y = torch.tensor([4, 5, 6])>>> grid_x, grid_y = torch.meshgrid(x, y)>>> grid_xtensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])>>> grid_ytensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]]) |
另一个例子:
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
>>> import torch>>> h = 6>>> w = 10>>> ys,xs = torch.meshgrid(torch.arange(h), torch.arange(w))>>> xs.shapetorch.Size([6, 10])>>> ys.shapetorch.Size([6, 10])>>> xstensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])>>> ystensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]])>>> xys = torch.stack([xs, ys], dim=-1)>>> xys.shapetorch.Size([6, 10, 2]) |
需要注意的点:
1. torch.meshgrid函数的输入是若干个(N个)一维Tensor或者若干个标量。
2. torch.meshgrid函数的输出有N个,每个输出都是N维的。
3. torch.meshgrid函数的每个输出tensor的shape都为(d1,d2,d3...dN)(d1,d2,d3...dN),其中didi为第i个输入向量的长度。
4. torch.meshgrid函数的每个输出有什么不同?答:为该输出对应输入向量在其他维度舒展开的结果。
5. torch的meshgrid实现和numpy的meshgrid实现有所不同,后者“可能”能够更直接地获取我们需要的东西,而torch的meshgrid调用后可能还需要做一个转置。
本文来自博客园,作者:海_纳百川,转载请注明原文链接:https://www.cnblogs.com/chentiao/p/16683675.html,如有侵权联系删除

浙公网安备 33010602011771号