CUTE 布局代数的理解

本文人工精心排版,无AI生成。

Layout

LayoutShapeStride组成,Shape决定了Coord的范围,CoordStride的内积决定了元素的线性地址Index
所以,Layout本质是一个整数到整数的映射函数R(c)。 其中,c可以是整数、二维坐标、自然坐标(h-D),映射结果为元素的线性地址。

二维与自然坐标一般为逆字典序(a,b,c,...),表现为a先动,b次之,c最后动。可以类比cuda中的threadIdxblockIdx等。

ldaPitch是类似的概念,lda常见于早期的矩阵乘法实现,表示矩阵行(列)之间的跨度,单位是元素个数;Pitch类似,不过单位是字节。

lda中的ld表示leading dimension

应用示例

    auto g_ = make_layout(Layout<Shape<_8>>{}, Layout<Shape<_9>>{});
    print(g_);printf("\n"); // ((_8),(_9)):((_1),(_1))
    std::cout << size(g_) <<std::endl; // _72
    std::cout << cosize(g_) <<std::endl; // _16

cosize是映射范围的大小,物理大小,size是定义域的大小,逻辑上的大小。

合并 Coalesce

考虑一个仅包含两个整数模式的布局,进行合并coalesce((s0,s1):(d0,d1)),此布局的合并结果为s0:d0 ++ s1:d1,有四种情况:

  • s0:d0 ++ _1:d1 => s0:d0 Shape中为1的地方,坐标只能为0,对于最终index的贡献为0*d=0
  • _1:d0 ++ s1:d1 => s1:d1,同上
  • s0:d0 ++ s1:s0*d0 => s0*s1:d0,即(s0,s1):(d0,s0*d0)这种模式可以合并,结果为(s0*s1,d0)
  • 其余情况无法合并

第三种情况的推导:考虑一维坐标x,二维坐标i,j,有x = i+j*s0(注意,是逆字典序),对应的index为i*d0+j*(s0*d0) = (i+j*s0)*d0 = x*d0,
所以原来的映射可以简化为x*d0,原来的定义域为s0*s1,所以合并结果为(s0*s1,d0)

通过两两合并,来处理秩>2的情况

按模式合并

coalesce中,trg_profile的数值不重要,只有模式重要,整数对应的模式需要coalesce

复合 Composition

R = A o B, R(c) = A(B(c))

对于单模复合

(a:b) o (s:d)
对于坐标 \(i\)
\(B(i) = i*d, A(B(i)) = i*d*b\)
坐标的定义域取决于\(B\),所以复合的结果是 \((s:b*d)\)

对于多模复合

当 B 是单射(injective)时,复合对拼接(concatenation)满足左分配律:

$ A ∘ B = A ∘ (B₀, B₁, ...) = (A ∘ B₀, A ∘ B₁, ...) $

单射指坐标不同时,映射到的线性地址也不同。 在某些KV-Cache中,多个不同坐标可能映射到同一地址,这时就不是单射了。

现在问题转换为\(A = (M₀, M₁, ..., Mₙ):(d₀, d₁, ..., dₙ),Bᵢ = s:d\),计算$ A o Bᵢ $:

  • 第一步,计算“步进布局”(Strided Layout):
对 shape (M₀, M₁, ...) 除以 d:
- 若 M₀ >= d:shape 变为 (M₀/d, M₁, ...),对应 stride 变为 (d·d₀, d₁, ...),结束
- 若 M₀ < d 且 M₀ | d: M₀ = 1,对应 stride 变为 (d·d₀, d₁, ...), d = d/M₀ 处理下一维

注意,M₀ | d表示M₀能整除d,这叫做步长整除条件,在Cute中可能会被静态检查

  • 第二步,对Shape取模:
对 shape (M₀, M₁, ...) 取模 s:
Mᵢ = min(Mᵢ, s), s = ceil(s / Mᵢ)
- 若 s <= 1:后续M变为1,结束
- 若 s > 1:处理下一维

注意,复合后定义域有s个元素

  • 第三步,合并维度:
    注意将Shape中为1的维度去掉

应用示例

  • 20:2 o (5,4):(4,1),将布局解释为按行主序排列的 5x4 矩阵:
拆分: 20:2 o 5:4, 20:2 o 4:1
/4 : 5:8, /1: 20:2
%5 : 5:8, %4: 4:2
合并:(5,4):(8,2)
  • (10,2):(16,4) o (5,4):(1,5),将布局解释为列主序的 5x4 矩阵:
拆分: (10,2):(16,4) o (5:1), (10,2):(16,4) o (4:5)
/1: (10,2):(16,4), /5: (2, 2):(80,4)
%5: (5,1):(16,4), %4: (2,2):(80,4)
合并: 5:16, (2,2):(80,4) -> (5,(2,2)):(16,(80,4))

按模式复合

我们通过Tiler来实现对应维度的复合,TilerLayout组成的元组。Shape可以视为Stride=1的若干Layout组成的元组

Shape<_64, _32>{}

// 语义上等价于(示意)
make_tuple(
  Layout<Shape<_64>, Stride<_1>>{},
  Layout<Shape<_32>, Stride<_1>>{}
)
// (12,(4,8)):(59,(13,1))
auto a = make_layout(make_shape (12,make_shape ( 4,8)),
                     make_stride(59,make_stride(13,1)));
// (3, 8)
auto tiler = make_shape(Int<3>{}, Int<8>{});
// Equivalent to <3:1, 8:1>
// auto tiler = make_tile(Layout<_3,_1>{},  // Apply 3:1 to mode-0
//                        Layout<_8,_1>{}); // Apply 8:1 to mode-1

// (_3,(4,2)):(59,(13,1))
auto result = composition(a, tiler);
// Identical to
auto same_r = make_layout(composition(layout<0>(a), get<0>(tiler)),
                          composition(layout<1>(a), get<1>(tiler)));

补 Complement

如果我们把Composition看作从Layout A中挑选特定的坐标,即Tile里有什么,那么Complement就是告诉我们那些没有被选中的坐标怎么描述?

补具有如下的后置条件:

// @post cosize(make_layout(@a layout_a, @a result))) >= size(@a cotarget)
// @post cosize(@a result) >= round_up(size(@a cotarget), cosize(@a layout_a))
// @post for all i, 1 <= i < size(@a result),
//         @a result(i-1) < @a result(i)
// @post for all i, 1 <= i < size(@a result),
//         for all j, 0 <= j < size(@a layout_a),
//           @a result(i) != @a layout_a(j)
Layout complement(LayoutA const& layout_a, Shape const& cotarget)

其中,cosize指目标空间的大小,即映射到的范围,亦即线性内存地址空间。size指shape的大小,即函数的定义域。cotarget可以理解为全集。
这几个后置条件本质就是在说Complement的性质:

  • 大小有界:对应前两个后置条件
  • 有序性,步长是正的,且递增
  • A和R具有不相交的值域

应用示例

A = (4,2,3):(2,1,8) 
B = 4:2
B^* = (2,3):(1,8)

除法

\(A \oslash B := A \circ (B,B^*)\)

实现为

template <class LShape, class LStride,
          class TShape, class TStride>
auto logical_divide(Layout<LShape,LStride> const& layout,
                    Layout<TShape,TStride> const& tiler)
{
  return composition(layout, make_layout(tiler, complement(tiler, size(layout))));
}

从公式中可以看到结果分为两部分,\(A \circ B\),这可以看作一个Tile;$ A \circ B^* $,这是B在A中的补,它描述了Tile在A中的布局。

    // A: shape is (9,32)
    auto layout_a = make_layout(make_shape (Int< 9>{}, make_shape (Int< 4>{}, Int<8>{})),
                                make_stride(Int<59>{}, make_stride(Int<13>{}, Int<1>{})));
    // B: shape is (3,8)
    auto tiler = make_tile(Layout<_3,_3>{},   // Apply     3:3     to mode-0
                           Layout<Shape <_2,_4>,      // Apply (2,4):(1,8) to mode-1
                                  Stride<_1,_8>>{});

    auto ld = logical_divide(layout_a, tiler);
    print(ld);printf("\n"); // ((_3,_3),((_2,_4),(_2,_2))):((_177,_59),((_13,_2),(_26,_1)))
    auto zd = zipped_divide(layout_a, tiler);
    print(zd);printf("\n"); // ((_3,(_2,_4)),(_3,(_2,_2))):((_177,(_13,_2)),(_59,(_26,_1)))

image1

注意,未被tile的维度L保持不动。

四种 divide 的典型使用场景

  • logical_divide — 最通用的形式,当需要明确区分"在哪个 tile 里"和"tile 内部位置",且两个方向需要独立操作时使用。
  • zipped_divide — kernel 中最常见。get<0>(layout) 拿到 tile 的 layout,get<1>(layout) 拿到遍历所有 tile 的坐标。例如把一个矩阵切成很多 (TileM, TileN) 的小块,用 for 循环遍历第二个维度。
  • tiled_divide — 当需要对 RestM 和 RestN 分别循环(双重 for 循环),但 tile 内坐标需要作为整体传给 MMA/copy atom 时用。
  • flat_divide — 最低级别,与直接操作所有维度等价,适合需要完全展开坐标的底层操作。

注意,logical_divide的结果保留了模式A的语义,其Shape((_3,_3),((_2,_4),(_2,_2)))M-mode对应9,N-mode对应32。

上面layout_a布局如下:

image2

在使用zipped_divide后,布局如下:

image3

乘法

\(A \otimes B := (A, A^* \circ B)\)
实现为

template <class LShape, class LStride,
          class TShape, class TStride>
auto logical_product(Layout<LShape,LStride> const& layout,
                     Layout<TShape,TStride> const& tiler)
{
  return make_layout(layout, composition(complement(layout, size(layout)*cosize(tiler)), tiler));
}

在值域空间内复制,需要容纳 tiler 份 layout。

和 Concatenation 的区别

  • 拼接只是单纯地改变了整数到整数的映射,不能保证物理上不重叠。
  • 乘法通过 Complement 保证物理上不重叠。Layout B里的Stride可能会被合理放大。

应用示例

  • 对于logical/zipped/tiled/flat需要区分乘以Layout,和乘以Tile(按维度乘);

image5

乘以Layout的效果是,将A视作一个点,然后按照B的布局重复

image4

注意,BTile,所以采用方括号表示,它表示按对应维度进行乘法,效果是对应维度的大小翻倍。
这种写法不推荐,因为B不够直观。

    Layout<Shape<_2,_5>,Stride<_5,_1>> a{};
    Layout<Shape<_3,_4>,Stride<_1,_3>> b{}; // auto b = make_layout(make_shape(_3,_4),make_stride(_1,_3))
    
    auto b1 = make_layout(Layout<_3,_1>{},Layout<_4,_3>{}); // concatenation
    Layout<Layout<_3,_1>,Layout<_4,_3>> b2{}; // 这是嵌套的布局 _3:_1:_4:_3
    
    auto tiler = make_tile(Layout<_3,_5>{},Layout<_4,_6>{}); 
    // (_3:_5,_4:_6) tiler是Layout的元组
    
    // 下面两种方式等价
    auto r = blocked_product(a,b); // 更推荐这种方式得到上图的效果
    auto e = logical_product(a,tiler);
    // r和e结果相同,((_2,_3),(_5,_4)):((_5,_10),(_1,_30)),区别在于:
    // logical_product需要使用Tile类,其stride要考虑到a
    // blocked_tiler可以使用Layout来描述tiler,更直观
    
    // 下面与zipped tiled除法有点像
    auto f = zipped_product(a,tiler);
    print(f);printf("\n"); // ((_2,_5),(_3,_4)):((_5,_1),(_10,_30)) 第1维是tiler的形状
    auto g = tiled_product(a,tiler); // ((_2,_5),_3,_4):((_5,_1),_10,_30) tiler的形状被展开
    print(g);printf("\n");
  • 对于blocked/raked总是乘以Layout
    下图展示了raked_product的效果,元素被交错
    image6
posted @ 2026-04-07 10:03  许培风  阅读(11)  评论(0)    收藏  举报