作死向:修改TVM的底层数据结构
前言
最近给某个新设备写TVM后端时,出于一些原因需要修改TVM底层的DLDataType
的lanes
的类型,将其从uint16_t
改为int32_t
(或者uint32_t
),以支持较大长度的vectorization schedule。刚开始以为很简单,但改着改着发现很多地方都以一种非引用的方式牵涉到DLDataType
,没改全的话运行时会出现莫名其妙的错误,结果这样不断调试不断修改,花了一天多时间才搞定。为纪念吾辈之努力(全部木大じゃない),谨以此文记录之。
作死过程
简单的修改
首先把3rdparty/dlpack/include/dlpack/dlpack.h
的DLDataType::lanes
改为int32_t
(顺便把code和bits也都改成uint16_t
,这样可以构成一个64-bits的packed-struct,方便跨域的传输)。
1 |
|
然后编译一下,有几个涉及类型转换地方的会报错,修改之:
include/tvm/runtime/data_type.h:75
:1
2
3
4
5
6
7
8DataType(int code, int bits, int lanes) {
data_.code = static_cast<uint16_t>(code); // MARK
data_.bits = static_cast<uint16_t>(bits); // MARK
data_.lanes = static_cast<int32_t>(lanes); // MARK
if (code == kBFloat) {
CHECK_EQ(bits, 16);
}
}include/tvm/runtime/serializer.h:41,46
:1
2
3
4
5
6
7
8
9
10
11inline static void Write(Stream* strm, const DLDataType& dtype) {
Handler<uint16_t>::Write(strm, dtype.code); // MARK
Handler<uint16_t>::Write(strm, dtype.bits); // MARK
Handler<int32_t>::Write(strm, dtype.lanes); // MARK
}
inline static bool Read(Stream* strm, DLDataType* dtype) {
if (!Handler<uint16_t>::Read(strm, &(dtype->code))) return false; // MARK
if (!Handler<uint16_t>::Read(strm, &(dtype->bits))) return false; // MARK
if (!Handler<int32_t>::Read(strm, &(dtype->lanes))) return false; // MARK
return true;
}
Cython相关
之后编译就没问题了,但是运行时会出问题,原因在于python和C++的交互部分(主要是Cython代码)。在python部分用grep
搜索一下DLDataType
和DataType
,发现以下几处需要修改的部分:
python/tvm/_ffi/_cython/base.pxi:47
:1
2
3
4ctypedef struct DLDataType:
uint16_t code # MARK
uint16_t bits # MARK
int32_t lanes # MARKpython/tvm/_ffi/runtime_ctypes.py:64
:1
2
3_fields_ = [("type_code", ctypes.c_uint16), # MARK
("bits", ctypes.c_uint16), # MARK
("lanes", ctypes.c_int32)] # MARK
LLVM生成的Host端代码
之后可以正常地进行代码生成了(tvm.build
),但一旦部署运行就会出问题,比如运行下面的代码:
1 |
|
会报错:
1 |
|
尴尬的是从traceback上来看该错误是一个Foreign-Function-Call造成的,仅靠调试很难进行溯源。
用grep
搜索dtype is expected to be
,发现了错误的可能来源:src/tir/transforms/arg_binder.cc:170
,修改之:
1 |
|
但是这样修改一下,甚至连tvm.build
都运行不了了,在LLVM/CPU代码生成(Host端代码生成)时会出问题。这里困惑了我很久,之后我把arg_binder.cc
的修改取消后,查看了一下生成的LLVM IR才发现问题:
1 |
|
原来TVM生成Host端代码时还会自己再构造一遍DLTensor
、DLContext
、DLDataType
等底层结构(对应代码的%0
、%1
、%2
)。
这样一来就很好找了,直接用grep
搜索StructType::create
,很快就发现了错误来源:src/target/llvm/codegen_cpu.c:47
,修改之:
1 |
|
改完之后就可以正常进行代码生成以及运行了。
RPC数据传输
过了一段时间,用autotvm的时候又出了问题,根据报错信息溯源,问题似乎出在tvm的RPC模块,想必是跟DLDataType
的数据读写有关。在RPC部分(src/runtime/rpc/
)搜索DLDataType
和kTVMDataType
,很快就发现了问题源头,在src/runtime/rpc/rpc_protocol:225,344
:
1 |
|
1 |
|
这里tvm传输TVMValue
需要凑齐8字节,而它“知道”DLDataType
的大小为4字节(又是一处擅自假设DLDataType
的结构与大小的代码),所以再额外读/写了4字节的padding。这里我们直接吧padding部分删掉就行了(因为之前把DLDataType
改成了一个8字节的packed-struct)。
总结
修改DLDataType
的lanes
的类型至少需要修改以下几处:
3rdparty/dlpack/include/dlpack/dlpack.h:106
include/tvm/runtime/data_type.h:75
include/tvm/runtime/serializer.h:41,46
python/tvm/_ffi/_cython/base.pxi:47
python/tvm/_ffi/runtime_ctypes.py:64
src/tir/transforms/arg_binder.cc:170
- 以下几处也可以酌情修改:
src/runtime/stackvm/stackvm.cc:532
src/tir/transforms/lower_tvm_builtin.cc:194
- 以下几处也可以酌情修改:
src/target/llvm/codegen_cpu.c:47
src/runtime/rpc/rpc_protocol:225,344
修改其他的底层类型所涉及到的地方也与之类似。
不得不说这应该算TVM设计不合理的地方(或者说设计时根本就没考虑到修改底层类型这种事)。个人认为,除了DLDataType
定义的地方(以及python和cython部分的代码),其他地方应该直接引用DLDataType
进行编码(比如利用decltype
来获取field的类型,据此进行派发)。