CUTLASS 3.x CuTe Layout Composition 的一处纰漏

First Post:

Blog Link:

前言

去年国庆时想写一下 CUTLASS 3.x 相关的博客(当时已经写了一部分 layout algebra 的内容),不过因为一些原因搁置了。在那之后知乎也出现了一些不错的博文来介绍 CUTLASS 3.x(比如 @reed 的系列博文),官方文档在半年前发布 3.4 版本的时候也更新了不少内容(感觉)。最近组里的一些活和这个有关系,就重新复习(预习)了一下官方的文档,过程中发现文档上关于 layout composition 满足左分配律的条件似乎有些纰漏,仔细研究了一下,发现确实只是个无伤大雅的纰漏,不过研究的过程还是有点意思的,就记录一下。

CuTe Layout

先简单介绍一下 CuTe Layout。一个 layout 可以看作两个 int-tuple 的组合 ,其中 分别称为 的 shape 和 stride(注:在 CuTe Layout 的实际定义中, 也可以是嵌套的 tuple,但在大部分时候可以将它们展平来看待,因此本文仅考虑非嵌套的 tuple)。 也可以表示为 ,其中每个 被称为 的一个 sub-layout。

一个 layout 可以看作一个限制定义域的整数线性变换,将一个 维整数坐标 () 映射到一维坐标:

每个 layout 都有一个 natural layout ,直观上理解就是一个 shape 和其相同的 contiguous column-major layout:

也可以接收一个一维坐标 作为参数,将其用 进行逆映射,得到一个 维坐标后,再映射到一个新的一维坐标,可以递归定义为(其中为取模符号):

Layout Composition

由此可以定义 layout 的 composition 操作 :两个 layout 的组合 满足 ,也就是 将多维坐标 映射为一维坐标 后再由 映射为一维坐标
假设 ,如果要让 仍然是个 layout(能表示成 shape:stride 的形式),则有如下要求(其中 表示 整除 ):

\tag not allowed in aligned environment \begin{aligned} &\exists 1\le j\le k\le n : (\pi_ {j-1}|w|\pi_j)\land (\pi_ {k-1}|wx|\pi_ {k})\tag{1} \ \text{where }& \pi_k=\prod_ {i=0}^k s_i\quad (\text{let }s_0=1) \end{aligned}

简单来说就是 stride (以及整个 layout 所覆盖的范围 ) 在整除这个偏序关系上要介于 的某两个相邻的 shape 前缀积 之间。CuTe 实现上是通过 shape_divshape_mod 操作的约束来表现这个限制,我这里将其本质概括了一下。

关于左分配律满足条件的纰漏

接下来就是 CuTe 的文档和实现上存在纰漏的地方了。前文讨论了形如当 时, 的结果仍然是合法的 layout 的条件。而对于一般的情况,也就是 时,官方的文档认为当 是单射函数时,满足左分配律 。这个结论咋一看挺显然的,但是稍微推理一下就发现并不是这么回事。

对于任意 () ,我们有:

要让 ,就要有:

简单起见,我们这里仅讨论 的情况。设 ,设 ,有:

因为 ,所以有:

Misplaced & 0=d_1(r-r%s_1)-d_2\lfloor r/s_1\rfloor=\begin{cases} 0 & 0\le r\le s_1-1\ d_1s_1-d_2 & s_1\le r\le 2s_1-2 \end{cases}

也就是说以下条件至少要成立一个:

\tag not allowed in aligned environment \begin{aligned} &\forall 0\le c_i < x_i:\sum_ {i=1}^2(w_ic_i)%s_1 < s_1\tag{2}\ &d_1s_1=d_2\tag{3} \end{aligned}

条件可以转化为:

考虑到约束的要求,有:

不难看出条件有可能都不满足,即使按照官方文档的要求 是单射函数。我们可以构造一个简单的例子,令 ,显然 是个单射函数,但是条件都不满足。可以编写代码验证一下(这里为方便起见用到了 CuTe 的 python 接口 pycute,可以参考官方文档进行安装):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import pycute
L = pycute.Layout((36, 18), (1, 72))
P = pycute.Layout((9, 4), (4, 9))
Q = pycute.composition(L, P)
print(f"L = {L}")
print(f"P = {P}")
print(f"Q = {Q}")
vis = set()
for c in range(P.size()):
p = P(c)
if p in vis:
print("P is not injective")
vis.add(p)
expected, actual = L(p), Q(c)
if expected != actual:
print(f"(L o P)({c}) = L(P({c})) = {expected} != {actual} = Q({c})")

输出结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
L = (36, 18):(1, 72)
P = (9, 4):(4, 9)
Q = (9, 4):(4, 9)
(L o P)(16) = L(P(16)) = 73 != 37 = Q(16)
(L o P)(17) = L(P(17)) = 77 != 41 = Q(17)
(L o P)(23) = L(P(23)) = 74 != 38 = Q(23)
(L o P)(24) = L(P(24)) = 78 != 42 = Q(24)
(L o P)(25) = L(P(25)) = 82 != 46 = Q(25)
(L o P)(26) = L(P(26)) = 86 != 50 = Q(26)
(L o P)(30) = L(P(30)) = 75 != 39 = Q(30)
(L o P)(31) = L(P(31)) = 79 != 43 = Q(31)
(L o P)(32) = L(P(32)) = 83 != 47 = Q(32)
(L o P)(33) = L(P(33)) = 87 != 51 = Q(33)
(L o P)(34) = L(P(34)) = 91 != 55 = Q(34)
(L o P)(35) = L(P(35)) = 95 != 59 = Q(35)

既然官方的实现存在这样的纰漏,那些使用了 CUTLASS 3.x 的算子岂不是可能有 BUG?这点应该不用担心,目前来说大概是没什么问题的,据我观察官方使用到 composition 的场合, 中的 总是满足 (不失一般性,假设 的 sub-layout 按 stride 升序排序,也就是 ),在这个条件下,layout composition 总是满足左分配律的。事实上,日常中使用的 layout 一般也确实满足这个条件,而且 layout complement 操作的合法性检查似乎也会保证这个性质(composition 经常和 complement 结合着使用)。至于为什么这个条件是 layout composition 满足左分配律的一个充分条件,以后有空再写写证明吧……