前言
前阵子看了下TVM的functor相关的源码和@Ubp.a的一篇有关非侵入式访问者接口的文章,有了点想法,试着写了一个较为泛化的非侵入式访问者接口。
(对相关概念不熟的建议先看一下Ubp.a的文章)
基本思路
思路和Ubp.a的文章的类似:显式地保存一个虚表,在访问者类构造时将目标的子类和对应的派发函数登记到表上,在访问时根据传入的基类指针选择派发的函数:
注:该代码中存在一些问题,std::type_info
只实现了相等性的比较,不能直接作为std::map
的索引类型,需要再套一个std::type_index
。
泛化的虚表
通常来说在RTTI的支持下可以利用std::type_index
和typeid
来将一个静态类型或一个对象指针的动态类型转化为索引值,再利用std::map
或std::unordered_map
来检索。
不过一方面RTTI本身带有不能忽视的开销,另一方面std::map
等关联容器的检索开销相对于数组而言还是稍微有点大,因此有时用户可能会自行实现一套运行时类型接口,比如TVM就是如此。
因此有必要对虚表进行泛化。不难看出,虚表主要需要支持两种操作:
Set(Type t, Function f) -> Void
: 将t
的派发函数设置为f
Get(Object o) -> Function
: 返回对象o
对应的派发函数。
比如使用std::map
和std::type_index
检索的虚表可以抽象为:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| template<typename Base, typename Func> class default_vtable { public: template<typename T> inline void Set(Func f) { data_[std::type_index(typeid(T))] = f; } inline Func Get(Base *base) { return data_[std::type_index(typeid(*base))]; }
private: std::map<std::type_index, Func> data_; };
|
而TVM中使用的虚表(tvm::NodeFunctor
)可以抽象为:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| template <typename Base, typename Func> class node_vtable { public: template <typename TNode> inline void Set(Func f) { uint32_t tindex = TNode::RuntimeTypeIndex(); if (func_.size() <= tindex) { func_.resize(tindex + 1, nullptr); } CHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set"; func_[tindex] = f; } inline Func Get(Base *base) { return func_[base->type_index()]; } private: std::vector<Func> func_; }
|
泛化的访问者接口
侵入式的访问者接口(每个被访问类需要实现一个accept
函数)除了侵入式这一个缺点以外,还有一个不太好的地方,就是因为accept
函数需要实现为虚函数,因此不能对函数本身直接泛化(通常需要对class进行泛化),这就导致访问者函数不容易泛化,也就是说针对同一种被访问对象,不容易实现visit函数类型不同的访问者。
非侵入式的接口则可以直接在接口类对visit函数类型进行泛化。
我写的接口如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| template<typename, typename, typename, typename, template<typename, typename> typename = default_vtable> class GeneralVisitor;
template<typename Visitor, typename Base, typename ...Deriveds, template<typename, typename> typename Vtable, typename R, typename ...Args> class GeneralVisitor<Visitor, Base, std::tuple<Deriveds...>, R(Args...), Vtable> { using VtableType = Vtable<Base, R(*)(Visitor *, Base *, Args...)>;
public: R Visit(Base *base, Args ...args) { static VtableType vtable = BuildVtable(); return vtable.Get(base)(static_cast<Visitor *>(this), base, std::forward<Args>(args)...); }
private: template<typename Derived, typename ...Rest> static void Register(VtableType &vtable) { vtable.template Set<Derived>( [](Visitor *visitor, Base *base, Args ...args) -> R { return visitor->ImplVisit(static_cast<Derived *>(base), std::forward<Args>(args)...); } ); if constexpr (sizeof...(Rest) > 0) { Register<Rest...>(vtable); } }
static VtableType BuildVtable() { VtableType vtable; Register<Deriveds...>(vtable); return vtable; } };
|
其中各个模板参数含义分别为:
Visitor
:访问者类
Base
:被访问的基类
Deriveds
:被访问的派生类
R(Args...)
:visit函数的类型
Vtable
:虚表类型,默认为前文提到的用std::type_index
和std::map
实现的default_vtable
。
这边出于效率考虑,学习了TVM的一些做法:虚表保存的为函数指针而不是std::function
容器(这样的话保存的函数需要多一个参数来传入Visitor*
);将虚表声明为Visit
函数内的静态变量,这样一方面可以在Visit
第一次调用时才进行虚表的初始化,另一方面对于每个继承同一个模板类的类,只需要保存一个虚表的实例。
使用示例
简单的计算器
参见https://github.com/Light-of-Hers/GeneralVisitor