CUTE 布局代数的理解
本文人工精心排版,无AI生成。
Layout
Layout由Shape和Stride组成,Shape决定了Coord的范围,Coord与Stride的内积决定了元素的线性地址Index。
所以,Layout本质是一个整数到整数的映射函数R(c)。 其中,c可以是整数、二维坐标、自然坐标(h-D),映射结果为元素的线性地址。
二维与自然坐标一般为逆字典序(a,b,c,...),表现为a先动,b次之,c最后动。可以类比cuda中的threadIdx,blockIdx等。
lda与Pitch是类似的概念,lda常见于早期的矩阵乘法实现,表示矩阵行(列)之间的跨度,单位是元素个数;Pitch类似,不过单位是字节。
lda中的ld表示leading dimension。
应用示例
auto g_ = make_layout(Layout<Shape<_8>>{}, Layout<Shape<_9>>{});
print(g_);printf("\n"); // ((_8),(_9)):((_1),(_1))
std::cout << size(g_) <<std::endl; // _72
std::cout << cosize(g_) <<std::endl; // _16
cosize是映射范围的大小,物理大小,size是定义域的大小,逻辑上的大小。
合并 Coalesce
考虑一个仅包含两个整数模式的布局,进行合并coalesce((s0,s1):(d0,d1)),此布局的合并结果为s0:d0 ++ s1:d1,有四种情况:
s0:d0 ++ _1:d1 => s0:d0Shape中为1的地方,坐标只能为0,对于最终index的贡献为0*d=0_1:d0 ++ s1:d1 => s1:d1,同上s0:d0 ++ s1:s0*d0 => s0*s1:d0,即(s0,s1):(d0,s0*d0)这种模式可以合并,结果为(s0*s1,d0)- 其余情况无法合并
第三种情况的推导:考虑一维坐标x,二维坐标i,j,有x = i+j*s0(注意,是逆字典序),对应的index为i*d0+j*(s0*d0) = (i+j*s0)*d0 = x*d0,
所以原来的映射可以简化为x*d0,原来的定义域为s0*s1,所以合并结果为(s0*s1,d0)
通过两两合并,来处理秩>2的情况
按模式合并
在coalesce中,trg_profile的数值不重要,只有模式重要,整数对应的模式需要coalesce。
复合 Composition
R = A o B, R(c) = A(B(c)),
对于单模复合
(a:b) o (s:d)
对于坐标 \(i\),
\(B(i) = i*d, A(B(i)) = i*d*b\),
坐标的定义域取决于\(B\),所以复合的结果是 \((s:b*d)\)
对于多模复合
当 B 是单射(injective)时,复合对拼接(concatenation)满足左分配律:
$ A ∘ B = A ∘ (B₀, B₁, ...) = (A ∘ B₀, A ∘ B₁, ...) $
单射指坐标不同时,映射到的线性地址也不同。 在某些KV-Cache中,多个不同坐标可能映射到同一地址,这时就不是单射了。
现在问题转换为\(A = (M₀, M₁, ..., Mₙ):(d₀, d₁, ..., dₙ),Bᵢ = s:d\),计算$ A o Bᵢ $:
- 第一步,计算“步进布局”(Strided Layout):
对 shape (M₀, M₁, ...) 除以 d:
- 若 M₀ >= d:shape 变为 (M₀/d, M₁, ...),对应 stride 变为 (d·d₀, d₁, ...),结束
- 若 M₀ < d 且 M₀ | d: M₀ = 1,对应 stride 变为 (d·d₀, d₁, ...), d = d/M₀ 处理下一维
注意,M₀ | d表示M₀能整除d,这叫做步长整除条件,在Cute中可能会被静态检查
- 第二步,对Shape取模:
对 shape (M₀, M₁, ...) 取模 s:
Mᵢ = min(Mᵢ, s), s = ceil(s / Mᵢ)
- 若 s <= 1:后续M变为1,结束
- 若 s > 1:处理下一维
注意,复合后定义域有s个元素
- 第三步,合并维度:
注意将Shape中为1的维度去掉
应用示例
20:2 o (5,4):(4,1),将布局解释为按行主序排列的 5x4 矩阵:
拆分: 20:2 o 5:4, 20:2 o 4:1
/4 : 5:8, /1: 20:2
%5 : 5:8, %4: 4:2
合并:(5,4):(8,2)
(10,2):(16,4) o (5,4):(1,5),将布局解释为列主序的 5x4 矩阵:
拆分: (10,2):(16,4) o (5:1), (10,2):(16,4) o (4:5)
/1: (10,2):(16,4), /5: (2, 2):(80,4)
%5: (5,1):(16,4), %4: (2,2):(80,4)
合并: 5:16, (2,2):(80,4) -> (5,(2,2)):(16,(80,4))
按模式复合
我们通过Tiler来实现对应维度的复合,Tiler是Layout组成的元组。Shape可以视为Stride=1的若干Layout组成的元组
Shape<_64, _32>{}
// 语义上等价于(示意)
make_tuple(
Layout<Shape<_64>, Stride<_1>>{},
Layout<Shape<_32>, Stride<_1>>{}
)
// (12,(4,8)):(59,(13,1))
auto a = make_layout(make_shape (12,make_shape ( 4,8)),
make_stride(59,make_stride(13,1)));
// (3, 8)
auto tiler = make_shape(Int<3>{}, Int<8>{});
// Equivalent to <3:1, 8:1>
// auto tiler = make_tile(Layout<_3,_1>{}, // Apply 3:1 to mode-0
// Layout<_8,_1>{}); // Apply 8:1 to mode-1
// (_3,(4,2)):(59,(13,1))
auto result = composition(a, tiler);
// Identical to
auto same_r = make_layout(composition(layout<0>(a), get<0>(tiler)),
composition(layout<1>(a), get<1>(tiler)));
补 Complement
如果我们把Composition看作从Layout A中挑选特定的坐标,即Tile里有什么,那么Complement就是告诉我们那些没有被选中的坐标怎么描述?
补具有如下的后置条件:
// @post cosize(make_layout(@a layout_a, @a result))) >= size(@a cotarget)
// @post cosize(@a result) >= round_up(size(@a cotarget), cosize(@a layout_a))
// @post for all i, 1 <= i < size(@a result),
// @a result(i-1) < @a result(i)
// @post for all i, 1 <= i < size(@a result),
// for all j, 0 <= j < size(@a layout_a),
// @a result(i) != @a layout_a(j)
Layout complement(LayoutA const& layout_a, Shape const& cotarget)
其中,cosize指目标空间的大小,即映射到的范围,亦即线性内存地址空间。size指shape的大小,即函数的定义域。cotarget可以理解为全集。
这几个后置条件本质就是在说Complement的性质:
- 大小有界:对应前两个后置条件
- 有序性,步长是正的,且递增
- A和R具有不相交的值域
应用示例
A = (4,2,3):(2,1,8)
B = 4:2
B^* = (2,3):(1,8)
除法
\(A \oslash B := A \circ (B,B^*)\)
实现为
template <class LShape, class LStride,
class TShape, class TStride>
auto logical_divide(Layout<LShape,LStride> const& layout,
Layout<TShape,TStride> const& tiler)
{
return composition(layout, make_layout(tiler, complement(tiler, size(layout))));
}
从公式中可以看到结果分为两部分,\(A \circ B\),这可以看作一个Tile;$ A \circ B^* $,这是B在A中的补,它描述了Tile在A中的布局。
// A: shape is (9,32)
auto layout_a = make_layout(make_shape (Int< 9>{}, make_shape (Int< 4>{}, Int<8>{})),
make_stride(Int<59>{}, make_stride(Int<13>{}, Int<1>{})));
// B: shape is (3,8)
auto tiler = make_tile(Layout<_3,_3>{}, // Apply 3:3 to mode-0
Layout<Shape <_2,_4>, // Apply (2,4):(1,8) to mode-1
Stride<_1,_8>>{});
auto ld = logical_divide(layout_a, tiler);
print(ld);printf("\n"); // ((_3,_3),((_2,_4),(_2,_2))):((_177,_59),((_13,_2),(_26,_1)))
auto zd = zipped_divide(layout_a, tiler);
print(zd);printf("\n"); // ((_3,(_2,_4)),(_3,(_2,_2))):((_177,(_13,_2)),(_59,(_26,_1)))

注意,未被tile的维度L保持不动。
四种 divide 的典型使用场景
- logical_divide — 最通用的形式,当需要明确区分"在哪个 tile 里"和"tile 内部位置",且两个方向需要独立操作时使用。
- zipped_divide — kernel 中最常见。get<0>(layout) 拿到 tile 的 layout,get<1>(layout) 拿到遍历所有 tile 的坐标。例如把一个矩阵切成很多 (TileM, TileN) 的小块,用 for 循环遍历第二个维度。
- tiled_divide — 当需要对 RestM 和 RestN 分别循环(双重 for 循环),但 tile 内坐标需要作为整体传给 MMA/copy atom 时用。
- flat_divide — 最低级别,与直接操作所有维度等价,适合需要完全展开坐标的底层操作。
注意,logical_divide的结果保留了模式A的语义,其Shape((_3,_3),((_2,_4),(_2,_2))),M-mode对应9,N-mode对应32。
上面layout_a布局如下:

在使用zipped_divide后,布局如下:

乘法
\(A \otimes B := (A, A^* \circ B)\)
实现为
template <class LShape, class LStride,
class TShape, class TStride>
auto logical_product(Layout<LShape,LStride> const& layout,
Layout<TShape,TStride> const& tiler)
{
return make_layout(layout, composition(complement(layout, size(layout)*cosize(tiler)), tiler));
}
在值域空间内复制,需要容纳 tiler 份 layout。
和 Concatenation 的区别
- 拼接只是单纯地改变了整数到整数的映射,不能保证物理上不重叠。
- 乘法通过
Complement保证物理上不重叠。Layout B里的Stride可能会被合理放大。
应用示例
- 对于
logical/zipped/tiled/flat需要区分乘以Layout,和乘以Tile(按维度乘);

乘以Layout的效果是,将A视作一个点,然后按照B的布局重复

注意,B是Tile,所以采用方括号表示,它表示按对应维度进行乘法,效果是对应维度的大小翻倍。
这种写法不推荐,因为B不够直观。
Layout<Shape<_2,_5>,Stride<_5,_1>> a{};
Layout<Shape<_3,_4>,Stride<_1,_3>> b{}; // auto b = make_layout(make_shape(_3,_4),make_stride(_1,_3))
auto b1 = make_layout(Layout<_3,_1>{},Layout<_4,_3>{}); // concatenation
Layout<Layout<_3,_1>,Layout<_4,_3>> b2{}; // 这是嵌套的布局 _3:_1:_4:_3
auto tiler = make_tile(Layout<_3,_5>{},Layout<_4,_6>{});
// (_3:_5,_4:_6) tiler是Layout的元组
// 下面两种方式等价
auto r = blocked_product(a,b); // 更推荐这种方式得到上图的效果
auto e = logical_product(a,tiler);
// r和e结果相同,((_2,_3),(_5,_4)):((_5,_10),(_1,_30)),区别在于:
// logical_product需要使用Tile类,其stride要考虑到a
// blocked_tiler可以使用Layout来描述tiler,更直观
// 下面与zipped tiled除法有点像
auto f = zipped_product(a,tiler);
print(f);printf("\n"); // ((_2,_5),(_3,_4)):((_5,_1),(_10,_30)) 第1维是tiler的形状
auto g = tiled_product(a,tiler); // ((_2,_5),_3,_4):((_5,_1),_10,_30) tiler的形状被展开
print(g);printf("\n");
- 对于
blocked/raked总是乘以Layout
下图展示了raked_product的效果,元素被交错
![image6]()


浙公网安备 33010602011771号