C++泛化的非侵入式访问者接口

Post Date:

Blog Link:

前言

前阵子看了下TVM的functor相关的源码和@Ubp.a的一篇有关非侵入式访问者接口的文章,有了点想法,试着写了一个较为泛化的非侵入式访问者接口。

(对相关概念不熟的建议先看一下Ubp.a的文章

基本思路

思路和Ubp.a的文章的类似:显式地保存一个虚表,在访问者类构造时将目标的子类和对应的派发函数登记到表上,在访问时根据传入的基类指针选择派发的函数:

注:该代码中存在一些问题,std::type_info只实现了相等性的比较,不能直接作为std::map的索引类型,需要再套一个std::type_index

泛化的虚表

通常来说在RTTI的支持下可以利用std::type_indextypeid来将一个静态类型或一个对象指针的动态类型转化为索引值,再利用std::mapstd::unordered_map来检索。

不过一方面RTTI本身带有不能忽视的开销,另一方面std::map等关联容器的检索开销相对于数组而言还是稍微有点大,因此有时用户可能会自行实现一套运行时类型接口,比如TVM就是如此。

因此有必要对虚表进行泛化。不难看出,虚表主要需要支持两种操作:

  • Set(Type t, Function f) -> Void: 将t的派发函数设置为f
  • Get(Object o) -> Function: 返回对象o对应的派发函数。

比如使用std::mapstd::type_index检索的虚表可以抽象为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
template<typename Base, typename Func>
class default_vtable {
public:
template<typename T>
inline void Set(Func f) {
data_[std::type_index(typeid(T))] = f;
}
inline Func Get(Base *base) {
return data_[std::type_index(typeid(*base))];
}

private:
std::map<std::type_index, Func> data_;
};

而TVM中使用的虚表(tvm::NodeFunctor)可以抽象为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
template <typename Base, typename Func>
class node_vtable {
public:
template <typename TNode>
inline void Set(Func f) {
uint32_t tindex = TNode::RuntimeTypeIndex();
if (func_.size() <= tindex) {
func_.resize(tindex + 1, nullptr);
}
CHECK(func_[tindex] == nullptr)
<< "Dispatch for " << TNode::_type_key << " is already set";
func_[tindex] = f;
}
inline Func Get(Base *base) {
return func_[base->type_index()];
}
private:
std::vector<Func> func_;
}

泛化的访问者接口

侵入式的访问者接口(每个被访问类需要实现一个accept函数)除了侵入式这一个缺点以外,还有一个不太好的地方,就是因为accept函数需要实现为虚函数,因此不能对函数本身直接泛化(通常需要对class进行泛化),这就导致访问者函数不容易泛化,也就是说针对同一种被访问对象,不容易实现visit函数类型不同的访问者。

非侵入式的接口则可以直接在接口类对visit函数类型进行泛化。

我写的接口如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
template<typename, typename, typename, typename, 
template<typename, typename> typename = default_vtable>
class GeneralVisitor;

template<typename Visitor, typename Base, typename ...Deriveds,
template<typename, typename> typename Vtable, typename R, typename ...Args>
class GeneralVisitor<Visitor, Base, std::tuple<Deriveds...>, R(Args...), Vtable> {
using VtableType = Vtable<Base, R(*)(Visitor *, Base *, Args...)>;

public:
R Visit(Base *base, Args ...args) {
static VtableType vtable = BuildVtable();
return vtable.Get(base)(static_cast<Visitor *>(this), base, std::forward<Args>(args)...);
}

private:
template<typename Derived, typename ...Rest>
static void Register(VtableType &vtable) {
vtable.template Set<Derived>(
[](Visitor *visitor, Base *base, Args ...args) -> R {
return visitor->ImplVisit(static_cast<Derived *>(base),
std::forward<Args>(args)...);
}
);
if constexpr (sizeof...(Rest) > 0) {
Register<Rest...>(vtable);
}
}

static VtableType BuildVtable() {
VtableType vtable;
Register<Deriveds...>(vtable);
return vtable;
}
};

其中各个模板参数含义分别为:

  • Visitor:访问者类
  • Base:被访问的基类
  • Deriveds:被访问的派生类
  • R(Args...):visit函数的类型
  • Vtable:虚表类型,默认为前文提到的用std::type_indexstd::map实现的default_vtable

这边出于效率考虑,学习了TVM的一些做法:虚表保存的为函数指针而不是std::function容器(这样的话保存的函数需要多一个参数来传入Visitor*);将虚表声明为Visit函数内的静态变量,这样一方面可以在Visit第一次调用时才进行虚表的初始化,另一方面对于每个继承同一个模板类的类,只需要保存一个虚表的实例。

使用示例

简单的计算器

参见https://github.com/Light-of-Hers/GeneralVisitor