cute 教程 01

本节主要探讨cute layouts,本质上一个Layout是从坐标(coord)空间到索引(index)空间的映射。

Layouts 提出了一个针对多维数组访问的通用接口,它隐藏了数据元素是如何存储在内存上的细节。例如,一个row-major的 MxN的layout和一个col-major的 MxN layout可以在软件层面上等同。

除此之外,cute也提供 Layout 代数,可以被组合或是操作来构建更复杂的layouts;这可以帮助用户做类似于将数据的layouts划分到线程的layouts上。

基础类型与概念

Integers

cute 主要利用 dynamic (known only at run-time) 和 static (known only at compile-time) 的integers

  • dynamic: 就是c++中的普通数据类型,例如 int size_t 等, 任何可以被std::is_integral<T> 接受的类型都是dynamic integer
  • static: 是 std::integral_const<Value> 的实例化,这些类型将value编码成 static constexpr 的成员,同时也支持将其转为dynamic类型;cute也定义了它自己的与CUDA兼容的static 类型 cute::C<value> 并重载了数学运算。cute提供了一些aliase Int<1>, Int<2>, Int<3>, _1, _2...

cute 尝试以同样方式处理static和dynamic integers。在下面的例子中,所有的static和dynamic都可以相互替代。

cute 提供了一些对于integers的traits:

  • cute::is_integral<T>: 检测 T 是否是 static 或 dynamic integer
  • cute::is_std_integral<T>: 检测 T 是否为 dynamic
  • cute::is_static<T>: 检查 T 是否为 空类型; 没有任何动态信息,例如其中没有数据成员...
  • cute::is_constant<N, T>: 检查 T 是否为 static 并且 值为N

Tuple

tuple是一个有限(0+)的顺序表。 cute::tuple 类似于 std::tuple 的行为,但是可以在host或device上工作,对模板参数有一定限制。

IntTuple

定义:要么是一个integer,要么是一个IntTuple的tuple。

IntTuple的例子:

  • int{2}, dynamic int 2
  • Int<3>{}, static int 3
  • make_tuple(int{2}, Int<3>{}): dynamic-2 和 static-3的tuple

cute在很多地方重用了 IntTuple的概念,包括 Shape Stride Step Coord.

操作:

  • rank(IntTuple): 元素的个数,注意这里的元素可能是IntTuple
  • get<I>(IntTuple): IntTuple中的第 I 个元素
  • depth(IntTuple): 层级的个数;单个int的depth为0;int的tuple的depth为1...
  • size(IntTuple): 所有元素的

我们可以通过括号表示层级,例如 (3, (6, 2), 8)

shapes 和 strides

就是 IntTuple

Layout

一个 Layout(Shape, Stride)的tuple,实现了从坐标空间到索引空间的映射。

Tensor

Layout可以与数据(例如 指针或者数据)组合 来创建一个 Tensor。通过 Layout产生的索引可以用于迭代器的下标来检索对应的数据

layout的创建和使用

LayoutIntTuples 的 pair:

  • Shape: 定义 layout 的抽象形状
  • Stride: 负责将shape内的坐标映射到索引空间

操作:

  • rank(Layout): Layout 有多少 modes,等价于 Layout.shape 的 tuple size
  • get<I>(Layout): Layout 中 的第 I 个 sub-Layout, 其中 I < rank.
  • depth(Layout): depth(Layout.shape)
  • shape(Layout):
  • stride(Layout)
  • size(Layout): size(shape(Layout))
  • cosize(Layout): A(size(A)-1) + 1, 这个表达式是指找到A中最后一个坐标空间的元素,将其进行映射得到索引空间的元素,加1是因为一般是从0可以存储

层次化的访问函数

IntTuples 和 Layouts 可以任意嵌套;为了方便起见,我们对于上面的操作将定义一些使用integers的序列的版本,而不仅仅是一个integer。例如,我们允许 get<I...>(x), 这里的 I...是一个 "C++ 的参数pack", 用于指定 0+ 个模板参数。

  • get<I0, I1, ..., IN>(x) := get<IN(...(get<I0>(x))...)>
  • rank<I...>(x) := rank(get<I...>(x))
  • depth<I...>(x) := depth(get<I...>(x))
  • shape<I...>(x) := shape(get<I...>(x))
  • size<I...>(x) := size(get<I...>(x))

构造一个layout

可以包含任何static 或 dynamic int的组合

Layout s8 = make_layout(Int<8>{});
Layout d8 = make_layout(8);

Layout s2xs4 = make_layout(make_shape(Int<2>{}, Int<4>{}));

Layout s2xd4_a = make_layout(make_shape(Int<2>{}, 4), make_stride(Int<12>{}, Int<1>{}));

Layout s2xd4_col = make_layout(make_shape(Int<2>{},4),
                               LayoutLeft{});
Layout s2xd4_row = make_layout(make_shape(Int<2>{},4),
                               LayoutRight{});

Layout s2xh4 = make_layout(make_shape (2,make_shape (2,2)),
                           make_stride(4,make_stride(2,1)));
Layout s2xh4_col = make_layout(shape(s2xh4),
                               LayoutLeft{});

cute 通常使用 make_* 函数,主要利用CTAD(ctor template arg deduction) 避免 重复 static 或 dynamic int 类型。

当省略Stride时,默认使用LayoutLeft的stride,可以认为是通用 col-major 的 stride 生成。

可以对每个layout调用print方法得到下面的结果:

s8        :  _8:_1
d8        :  8:_1
s2xs4     :  (_2,_4):(_1,_2)
s2xd4     :  (_2,4):(_1,_2)
s2xd4_a   :  (_2,4):(_12,_1)
s2xd4_col :  (_2,4):(_1,_2)
s2xd4_row :  (_2,4):(4,_1)
s2xh4     :  (2,(2,2)):(4,(2,1))
s2xh4_col :  (2,(2,2)):(_1,(2,4))

Shape:Stride 是 Layout 通常的记号;需要注意ShapeStridecongruent的, static_assert(congruent(my_shape, my_stride));

使用layout

layout的基本使用是将由shape定义的坐标空间映射到由stride定义的索引空间。

// mode-2
template <class Shape, class Stride>
void print2D(Layout<Shape, Stride> const& layout) {
    for (int m{0}; m < size<0>(layout); ++m) {
        for (int n{0}; n < size<1>(layout); ++n) {
            printf("%3d ", layout(m,n));
        }
        printf("\n");
    }
}

// any-mode
// template
void print1D(Layout<Shape, Stride> const& layout) {
    for (int i{0}; i < size(layout); ++i){
        printf("%3d ", layout(i));
    }
}

cute 也提供更多可以可视化layout的工具, print_layout 函数可以产生一个 layout 映射的2维表格。

vector layouts

我们定义任何 rank==1的layout为vector,例如 8:1 可以解释为8个元素的vector,每个索引是连续的。((4, 2): (2,1)) 也是一个vector。

Layout: ((4,2): (2,1))
Coord: 0 1 2 3 4 5 6 7
Index: 0 2 4 6 1 3 5 7

矩阵的例子

那么我们可以定义任何rank==2的layout为matrix。例如,

Shape : (4,2)
Stride : (1,4)
0 4
1 5
2 6
3 7

Layout 相关概念

Layout 兼容性

我们说layout A 和 B兼容是指A的shape和B的shape是兼容的,当:

  • A的size等于B的size
  • 所有A的坐标在B中也是有效的

例如:

  • Shape (4,6) < ((2,2),6)
  • ((2,3),4) !< ((2,2),(3,2))
  • (24) !< 24

这意味着 兼容 是一个 weak 偏序,满足 反身性,反对称性,传递性

Layouts coord

有了兼容的概念,我们可以强调每个Layout可以接受多种坐标;cute通过 colexicographical order 提供了这些坐标之间的映射关系。

因此,所有layout提供了两个基本的映射:

  • 输入坐标 -> 自然坐标: 通过Shape
  • 自然坐标 -> 索引: 通过Stride

坐标映射

从输入坐标到自然坐标的映射是通过从右到左的顺序进行操作。

(3, (2,2))的shape为例,存在3种坐标: 1-D, 2-D, 以及 自然坐标 h-D

1-D 2-D Natural 1-D 2-D Natural
0 (0,0) (0,(0,0)) 9 (0,3) (0,(1,1))
1 (1,0) (1,(0,0)) 10 (1,3) (1,(1,1))
2 (2,0) (2,(0,0)) 11 (2,3) (2,(1,1))
3 (0,1) (0,(1,0)) 12 (0,4) (0,(0,2))
4 (1,1) (1,(1,0)) 13 (1,4) (1,(0,2))
5 (2,1) (2,(1,0)) 14 (2,4) (2,(0,2))
6 (0,2) (0,(0,1)) 15 (0,5) (0,(1,2))
7 (1,2) (1,(0,1)) 16 (1,5) (1,(1,2))
8 (2,2) (2,(0,1)) 17 (2,5) (2,(1,2))

cute::idx2crd(idx, shape)负责坐标映射,主要返回这个shape对应的自然坐标。

索引映射

自然坐标到索引的映射可以通过自然坐标与Stride的内积获得。cute::crd2idx(c, shape, stride)负责索引映射。

Layout 操作

sublayouts

// layout<I...>
Layout a   = Layout<Shape<_4,Shape<_3,_6>>>{}; // (4,(3,6)):(1,(4,12))
Layout a0  = layout<0>(a);                     // 4:1
Layout a1  = layout<1>(a);                     // (3,6):(4,12)
Layout a10 = layout<1,0>(a);                   // 3:4
Layout a11 = layout<1,1>(a);                   // 6:12

// select<I...>
Layout a   = Layout<Shape<_2,_3,_5,_7>>{};     // (2,3,5,7):(1,2,6,30)
Layout a13 = select<1,3>(a);                   // (3,7):(2,30)
Layout a01 = select<0,1,3>(a);                 // (2,3,7):(1,2,30)
Layout a2  = select<2>(a);                     // (5):(6)

// take<ModeBegin, ModeEnd>
Layout a   = Layout<Shape<_2,_3,_5,_7>>{};     // (2,3,5,7):(1,2,6,30)
Layout a13 = take<1,3>(a);                     // (3,5):(2,6)
Layout a14 = take<1,4>(a);                     // (3,5,7):(2,6,30)
// take<1,1> not allowed. Empty layouts not allowed.

Concatenation

Layout a = Layout<_3,_1>{};                     // 3:1
Layout b = Layout<_4,_3>{};                     // 4:3
Layout row = make_layout(a, b);                 // (3,4):(1,3)
Layout col = make_layout(b, a);                 // (4,3):(3,1)
Layout q   = make_layout(row, col);             // ((3,4),(4,3)):((1,3),(3,1))
Layout aa  = make_layout(a);                    // (3):(1)
Layout aaa = make_layout(aa);                   // ((3)):((1))
Layout d   = make_layout(a, make_layout(a), a); // (3,(3),3):(1,(1),1)

// 
Layout a = Layout<_3,_1>{};                     // 3:1
Layout b = Layout<_4,_3>{};                     // 4:3
Layout ab = append(a, b);                       // (3,4):(1,3)
Layout ba = prepend(a, b);                      // (4,3):(3,1)
Layout c  = append(ab, ab);                     // (3,4,(3,4)):(1,3,(1,3))
Layout d  = replace<2>(c, b);                   // (3,4,4):(1,3,3)

聚合和展平

Layout a = Layout<Shape<_2,_3,_5,_7>>{};  // (_2,_3,_5,_7):(_1,_2,_6,_30)
Layout b = group<0,2>(a);                 // ((_2,_3),_5,_7):((_1,_2),_6,_30)
Layout c = group<1,3>(b);                 // ((_2,_3),(_5,_7)):((_1,_2),(_6,_30))
Layout f = flatten(b);                    // (_2,_3,_5,_7):(_1,_2,_6,_30)
Layout e = flatten(c);                    // (_2,_3,_5,_7):(_1,_2,_6,_30)

slicing

通常更适合在Tensors上使用。

总结

  • Shape定义坐标空间
  • Stride定义索引空间
posted @ 2025-03-20 15:30  xwher  阅读(83)  评论(0)    收藏  举报