前言 去年国庆时想写一下 CUTLASS 3.x 相关的博客(当时已经写了一部分 layout algebra 的内容),不过因为一些原因搁置了。在那之后知乎也出现了一些不错的博文来介绍 CUTLASS 3.x(比如 @reed 的系列博文),官方文档在半年前发布 3.4 版本的时候也更新了不少内容(感觉)。最近组里的一些活和这个有关系,就重新复习(预习)了一下官方的文档,过程中发现文档上关于 layout composition 满足左分配律的条件似乎有些纰漏,仔细研究了一下,发现确实只是个无伤大雅的纰漏,不过研究的过程还是有点意思的,就记录一下。
CuTe Layout 先简单介绍一下 CuTe Layout。一个 layout 可以看作两个 int-tuple 的组合 ,其中 和 分别称为 的 shape 和 stride(注:在 CuTe Layout 的实际定义中, 和 也可以是嵌套的 tuple,但在大部分时候可以将它们展平来看待,因此本文仅考虑非嵌套的 tuple)。 也可以表示为 ,其中每个 被称为 的一个 sub-layout。
一个 layout 可以看作一个限制定义域的整数线性变换,将一个 维整数坐标 ( ) 映射到一维坐标:
每个 layout 都有一个 natural layout ,直观上理解就是一个 shape 和其相同的 contiguous column-major layout:
也可以接收一个一维坐标 作为参数,将其用 进行逆映射,得到一个 维坐标后,再映射到一个新的一维坐标,可以递归定义为(其中 为取模符号):
Layout Composition 由此可以定义 layout 的 composition 操作 :两个 layout 和 的组合 满足 ,也就是 将多维坐标 映射为一维坐标 后再由 映射为一维坐标 。 假设 , ,如果要让 仍然是个 layout(能表示成 shape:stride 的形式),则有如下要求(其中 表示 整除 ):
\tag not allowed in aligned environment \begin{aligned} &\exists 1\le j\le k\le n : (\pi_ {j-1}|w|\pi_j)\land (\pi_ {k-1}|wx|\pi_ {k})\tag{1} \ \text{where }& \pi_k=\prod_ {i=0}^k s_i\quad (\text{let }s_0=1) \end{aligned}
简单来说就是 stride (以及整个 layout 所覆盖的范围 ) 在整除这个偏序关系上要介于 的某两个相邻的 shape 前缀积 和 之间。CuTe 实现上是通过 shape_div
和 shape_mod
操作的约束来表现这个限制,我这里将其本质概括了一下。
关于左分配律满足条件的纰漏 接下来就是 CuTe 的文档和实现上存在纰漏的地方了。前文讨论了形如当 , 时, 的结果仍然是合法的 layout 的条件。而对于一般的情况,也就是 , 时,官方的文档认为当 是单射函数时,满足左分配律 。这个结论咋一看挺显然的,但是稍微推理一下就发现并不是这么回事。
对于任意 ( ) ,我们有:
要让 ,就要有:
简单起见,我们这里仅讨论 的情况。设 ,设 , ,有:
因为 ,所以有:
Misplaced & 0=d_1(r-r%s_1)-d_2\lfloor r/s_1\rfloor=\begin{cases} 0 & 0\le r\le s_1-1\ d_1s_1-d_2 & s_1\le r\le 2s_1-2 \end{cases}
也就是说以下条件至少要成立一个:
\tag not allowed in aligned environment \begin{aligned} &\forall 0\le c_i < x_i:\sum_ {i=1}^2(w_ic_i)%s_1 < s_1\tag{2}\ &d_1s_1=d_2\tag{3} \end{aligned}
条件 可以转化为:
考虑到约束 的要求,有:
不难看出条件 有可能都不满足,即使按照官方文档的要求 是单射函数。我们可以构造一个简单的例子,令 , ,显然 是个单射函数,但是条件 和 都不满足。可以编写代码验证一下(这里为方便起见用到了 CuTe 的 python 接口 pycute
,可以参考官方文档 进行安装):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import pycute L = pycute.Layout((36 , 18 ), (1 , 72 )) P = pycute.Layout((9 , 4 ), (4 , 9 )) Q = pycute.composition(L, P)print (f"L = {L} " )print (f"P = {P} " )print (f"Q = {Q} " ) vis = set ()for c in range (P.size()): p = P(c) if p in vis: print ("P is not injective" ) vis.add(p) expected, actual = L(p), Q(c) if expected != actual: print (f"(L o P)({c} ) = L(P({c} )) = {expected} != {actual} = Q({c} )" )
输出结果如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 L = : P = : Q = : = L) = 73 != 37 = Q = L) = 77 != 41 = Q = L) = 74 != 38 = Q = L) = 78 != 42 = Q = L) = 82 != 46 = Q = L) = 86 != 50 = Q = L) = 75 != 39 = Q = L) = 79 != 43 = Q = L) = 83 != 47 = Q = L) = 87 != 51 = Q = L) = 91 != 55 = Q = L) = 95 != 59 = Q
既然官方的实现存在这样的纰漏,那些使用了 CUTLASS 3.x 的算子岂不是可能有 BUG?这点应该不用担心,目前来说大概是没什么问题的,据我观察官方使用到 composition 的场合, 中的 总是满足 (不失一般性,假设 的 sub-layout 按 stride 升序排序,也就是 ),在这个条件下,layout composition 总是满足左分配律的。事实上,日常中使用的 layout 一般也确实满足这个条件,而且 layout complement 操作的合法性检查似乎也会保证这个性质(composition 经常和 complement 结合着使用)。至于为什么这个条件是 layout composition 满足左分配律的一个充分条件,以后有空再写写证明吧……