cute 教程 01
本节主要探讨cute layouts,本质上一个Layout
是从坐标(coord)空间到索引(index)空间的映射。
Layout
s 提出了一个针对多维数组访问的通用接口,它隐藏了数据元素是如何存储在内存上的细节。例如,一个row-major的 MxN
的layout和一个col-major的 MxN
layout可以在软件层面上等同。
除此之外,cute也提供 Layout
代数,可以被组合或是操作来构建更复杂的layouts;这可以帮助用户做类似于将数据的layouts划分到线程的layouts上。
基础类型与概念
Integers
cute 主要利用 dynamic (known only at run-time) 和 static (known only at compile-time) 的integers
- dynamic: 就是c++中的普通数据类型,例如
int
size_t
等, 任何可以被std::is_integral<T>
接受的类型都是dynamic integer - static: 是
std::integral_const<Value>
的实例化,这些类型将value编码成static constexpr
的成员,同时也支持将其转为dynamic类型;cute也定义了它自己的与CUDA兼容的static 类型cute::C<value>
并重载了数学运算。cute提供了一些aliaseInt<1>
,Int<2>
,Int<3>
,_1
,_2
...
cute 尝试以同样方式处理static和dynamic integers。在下面的例子中,所有的static和dynamic都可以相互替代。
cute 提供了一些对于integers的traits:
cute::is_integral<T>
: 检测 T 是否是 static 或 dynamic integercute::is_std_integral<T>
: 检测 T 是否为 dynamiccute::is_static<T>
: 检查 T 是否为 空类型; 没有任何动态信息,例如其中没有数据成员...cute::is_constant<N, T>
: 检查 T 是否为 static 并且 值为N
Tuple
tuple是一个有限(0+)的顺序表。 cute::tuple
类似于 std::tuple
的行为,但是可以在host或device上工作,对模板参数有一定限制。
IntTuple
定义:要么是一个integer,要么是一个IntTuple的tuple。
IntTuple
的例子:
int{2}
, dynamic int 2Int<3>{}
, static int 3make_tuple(int{2}, Int<3>{})
: dynamic-2 和 static-3的tuple
cute在很多地方重用了 IntTuple
的概念,包括 Shape
Stride
Step
Coord
.
操作:
rank(IntTuple)
: 元素的个数,注意这里的元素可能是IntTupleget<I>(IntTuple)
: IntTuple中的第I
个元素depth(IntTuple)
: 层级的个数;单个int的depth为0;int的tuple的depth为1...size(IntTuple)
: 所有元素的积
我们可以通过括号表示层级,例如 (3, (6, 2), 8)
shapes 和 strides
就是 IntTuple
Layout
一个 Layout
是 (Shape, Stride)
的tuple,实现了从坐标空间到索引空间的映射。
Tensor
Layout
可以与数据(例如 指针或者数据)组合 来创建一个 Tensor
。通过 Layout
产生的索引可以用于迭代器的下标来检索对应的数据
layout的创建和使用
Layout
是 IntTuple
s 的 pair:
Shape
: 定义 layout 的抽象形状Stride
: 负责将shape内的坐标映射到索引空间
操作:
rank(Layout)
: Layout 有多少 modes,等价于 Layout.shape 的 tuple sizeget<I>(Layout)
: Layout 中 的第 I 个 sub-Layout, 其中I < rank
.depth(Layout)
: depth(Layout.shape)shape(Layout)
:stride(Layout)
size(Layout)
:size(shape(Layout))
cosize(Layout)
:A(size(A)-1) + 1
, 这个表达式是指找到A中最后一个坐标空间的元素,将其进行映射得到索引空间的元素,加1是因为一般是从0可以存储
层次化的访问函数
IntTuple
s 和 Layout
s 可以任意嵌套;为了方便起见,我们对于上面的操作将定义一些使用integers的序列的版本,而不仅仅是一个integer。例如,我们允许 get<I...>(x)
, 这里的 I...
是一个 "C++ 的参数pack", 用于指定 0+ 个模板参数。
get<I0, I1, ..., IN>(x) := get<IN(...(get<I0>(x))...)>
rank<I...>(x) := rank(get<I...>(x))
depth<I...>(x) := depth(get<I...>(x))
shape<I...>(x) := shape(get<I...>(x))
size<I...>(x) := size(get<I...>(x))
构造一个layout
可以包含任何static 或 dynamic int的组合
Layout s8 = make_layout(Int<8>{});
Layout d8 = make_layout(8);
Layout s2xs4 = make_layout(make_shape(Int<2>{}, Int<4>{}));
Layout s2xd4_a = make_layout(make_shape(Int<2>{}, 4), make_stride(Int<12>{}, Int<1>{}));
Layout s2xd4_col = make_layout(make_shape(Int<2>{},4),
LayoutLeft{});
Layout s2xd4_row = make_layout(make_shape(Int<2>{},4),
LayoutRight{});
Layout s2xh4 = make_layout(make_shape (2,make_shape (2,2)),
make_stride(4,make_stride(2,1)));
Layout s2xh4_col = make_layout(shape(s2xh4),
LayoutLeft{});
cute 通常使用 make_*
函数,主要利用CTAD(ctor template arg deduction) 避免 重复 static 或 dynamic int 类型。
当省略Stride
时,默认使用LayoutLeft
的stride,可以认为是通用 col-major 的 stride 生成。
可以对每个layout调用print方法得到下面的结果:
s8 : _8:_1
d8 : 8:_1
s2xs4 : (_2,_4):(_1,_2)
s2xd4 : (_2,4):(_1,_2)
s2xd4_a : (_2,4):(_12,_1)
s2xd4_col : (_2,4):(_1,_2)
s2xd4_row : (_2,4):(4,_1)
s2xh4 : (2,(2,2)):(4,(2,1))
s2xh4_col : (2,(2,2)):(_1,(2,4))
Shape:Stride
是 Layout 通常的记号;需要注意Shape
和Stride
是 congruent的, static_assert(congruent(my_shape, my_stride));
使用layout
layout的基本使用是将由shape定义的坐标空间映射到由stride定义的索引空间。
// mode-2
template <class Shape, class Stride>
void print2D(Layout<Shape, Stride> const& layout) {
for (int m{0}; m < size<0>(layout); ++m) {
for (int n{0}; n < size<1>(layout); ++n) {
printf("%3d ", layout(m,n));
}
printf("\n");
}
}
// any-mode
// template
void print1D(Layout<Shape, Stride> const& layout) {
for (int i{0}; i < size(layout); ++i){
printf("%3d ", layout(i));
}
}
cute 也提供更多可以可视化layout的工具, print_layout
函数可以产生一个 layout 映射的2维表格。
vector layouts
我们定义任何 rank==1
的layout为vector,例如 8:1
可以解释为8个元素的vector,每个索引是连续的。((4, 2): (2,1))
也是一个vector。
Layout: ((4,2): (2,1))
Coord: 0 1 2 3 4 5 6 7
Index: 0 2 4 6 1 3 5 7
矩阵的例子
那么我们可以定义任何rank==2
的layout为matrix。例如,
Shape : (4,2)
Stride : (1,4)
0 4
1 5
2 6
3 7
Layout 相关概念
Layout 兼容性
我们说layout A 和 B兼容是指A的shape和B的shape是兼容的,当:
- A的size等于B的size
- 所有A的坐标在B中也是有效的
例如:
- Shape (4,6) < ((2,2),6)
- ((2,3),4) !< ((2,2),(3,2))
- (24) !< 24
这意味着 兼容 是一个 weak 偏序,满足 反身性,反对称性,传递性
Layouts coord
有了兼容的概念,我们可以强调每个Layout
可以接受多种坐标;cute通过 colexicographical order 提供了这些坐标之间的映射关系。
因此,所有layout提供了两个基本的映射:
- 输入坐标 -> 自然坐标: 通过
Shape
- 自然坐标 -> 索引: 通过
Stride
坐标映射
从输入坐标到自然坐标的映射是通过从右到左的顺序进行操作。
以(3, (2,2))
的shape为例,存在3种坐标: 1-D, 2-D, 以及 自然坐标 h-D
1-D | 2-D | Natural | 1-D | 2-D | Natural | |
---|---|---|---|---|---|---|
0 |
(0,0) |
(0,(0,0)) |
9 |
(0,3) |
(0,(1,1)) |
|
1 |
(1,0) |
(1,(0,0)) |
10 |
(1,3) |
(1,(1,1)) |
|
2 |
(2,0) |
(2,(0,0)) |
11 |
(2,3) |
(2,(1,1)) |
|
3 |
(0,1) |
(0,(1,0)) |
12 |
(0,4) |
(0,(0,2)) |
|
4 |
(1,1) |
(1,(1,0)) |
13 |
(1,4) |
(1,(0,2)) |
|
5 |
(2,1) |
(2,(1,0)) |
14 |
(2,4) |
(2,(0,2)) |
|
6 |
(0,2) |
(0,(0,1)) |
15 |
(0,5) |
(0,(1,2)) |
|
7 |
(1,2) |
(1,(0,1)) |
16 |
(1,5) |
(1,(1,2)) |
|
8 |
(2,2) |
(2,(0,1)) |
17 |
(2,5) |
(2,(1,2)) |
cute::idx2crd(idx, shape)
负责坐标映射,主要返回这个shape对应的自然坐标。
索引映射
自然坐标到索引的映射可以通过自然坐标与Stride
的内积获得。cute::crd2idx(c, shape, stride)
负责索引映射。
Layout 操作
sublayouts
// layout<I...>
Layout a = Layout<Shape<_4,Shape<_3,_6>>>{}; // (4,(3,6)):(1,(4,12))
Layout a0 = layout<0>(a); // 4:1
Layout a1 = layout<1>(a); // (3,6):(4,12)
Layout a10 = layout<1,0>(a); // 3:4
Layout a11 = layout<1,1>(a); // 6:12
// select<I...>
Layout a = Layout<Shape<_2,_3,_5,_7>>{}; // (2,3,5,7):(1,2,6,30)
Layout a13 = select<1,3>(a); // (3,7):(2,30)
Layout a01 = select<0,1,3>(a); // (2,3,7):(1,2,30)
Layout a2 = select<2>(a); // (5):(6)
// take<ModeBegin, ModeEnd>
Layout a = Layout<Shape<_2,_3,_5,_7>>{}; // (2,3,5,7):(1,2,6,30)
Layout a13 = take<1,3>(a); // (3,5):(2,6)
Layout a14 = take<1,4>(a); // (3,5,7):(2,6,30)
// take<1,1> not allowed. Empty layouts not allowed.
Concatenation
Layout a = Layout<_3,_1>{}; // 3:1
Layout b = Layout<_4,_3>{}; // 4:3
Layout row = make_layout(a, b); // (3,4):(1,3)
Layout col = make_layout(b, a); // (4,3):(3,1)
Layout q = make_layout(row, col); // ((3,4),(4,3)):((1,3),(3,1))
Layout aa = make_layout(a); // (3):(1)
Layout aaa = make_layout(aa); // ((3)):((1))
Layout d = make_layout(a, make_layout(a), a); // (3,(3),3):(1,(1),1)
//
Layout a = Layout<_3,_1>{}; // 3:1
Layout b = Layout<_4,_3>{}; // 4:3
Layout ab = append(a, b); // (3,4):(1,3)
Layout ba = prepend(a, b); // (4,3):(3,1)
Layout c = append(ab, ab); // (3,4,(3,4)):(1,3,(1,3))
Layout d = replace<2>(c, b); // (3,4,4):(1,3,3)
聚合和展平
Layout a = Layout<Shape<_2,_3,_5,_7>>{}; // (_2,_3,_5,_7):(_1,_2,_6,_30)
Layout b = group<0,2>(a); // ((_2,_3),_5,_7):((_1,_2),_6,_30)
Layout c = group<1,3>(b); // ((_2,_3),(_5,_7)):((_1,_2),(_6,_30))
Layout f = flatten(b); // (_2,_3,_5,_7):(_1,_2,_6,_30)
Layout e = flatten(c); // (_2,_3,_5,_7):(_1,_2,_6,_30)
slicing
通常更适合在Tensor
s上使用。
总结
Shape
定义坐标空间Stride
定义索引空间