C++模板元编程中的代数数据类型(ADT)及模式匹配等操作:以编译期AVL树为例

First Post:

Blog Link:

前言

众所周知C++是可以在编译期进行用户可控的图灵完全计算的,其方法主要是利用模板元编程(TMP),以及在C++11出现并在C++14和C++17中不断增强的constexpr相关的语法(关于这个,clang的源码中有个很有意思的测试样例:a href=”https://github.com/llvm-mirror/clang/blob/master/test/SemaCXX/constexpr-turing.cpp“>用constexpr实现编译期图灵机)。其中constexpr在使用上会更加简洁,而TMP写起来则繁杂些。不过TMP在形式上支持一些比较先进的PL原语,比如模式匹配,这是单纯用constexpr做不到的。

前几天学习OCaml的时候用OCaml实现了一个函数式AVL树,事后想了想好像用C++的TMP也可以实现个类似的编译期AVL树,试着做了一下,发现将OCaml中的ADT以及基于模式匹配的对于ADT的操作转化到TMP竟然出乎意料地顺滑。于是便写此文章,以编译期AVL树为例,展现一下用TMP实现ADT的简易性。

关于AVL树的定义与操作,如有需要请自行搜索,本文不提供详细说明。


AVL树定义

首先是定义AVL树的ADT。

1
2
3
4
5
6
7
8
9
10
11
type avl_tree =
| AvlNull
| AvlNode of int * int * avl_tree * avl_tree

let avl_height t =
match t with
| AvlNull -> 0
| AvlNode (h, _, _, _) -> h
;;

let avl_node v l r = AvlNode (max (avl_height l) (avl_height r) + 1, v, l, r)

其中avl_node是AvlNode的构造器,avl_heightavl_tree的树高属性的萃取器。

转化为C++ TMP代码:

1
2
3
4
5
6
7
8
9
struct avl_tree {};
struct avl_null : public avl_tree {
static constexpr int height = 0;
};
template<int V, typename L, typename R>
struct avl_node : public avl_tree {
static_assert(std::is_base_of_v<avl_tree, L> && std::is_base_of_v<avl_tree, R>);
static constexpr int height = std::max(L::height, R::height) + 1;
};

其中用继承来表现variant。

这里和OCaml的实现有点不同,将树高作为类的静态常量保存,不过无伤大雅。

avl_node的构造中会对模板参数LR进行检查,保证其为avl_tree


AVL树操作

元素查询

OCaml的代码非常直接:

1
2
3
4
5
6
let rec avl_query i t =
match t with
| AvlNode (_, v, l, r) ->
if i < v then avl_query i l else if i > v then avl_query i r else true
| _ -> false
;;

可以用模板偏特化来实现模式匹配。其中:

  • 类模板的基础定义可以表示模式匹配中的缺省情况(如下面的static constexpr bool value = false;)。
  • 对模板的不同形式的偏特化可以表示模式匹配的不同模式。
1
2
3
4
5
6
7
8
9
10
template<int I, typename T>
struct avl_query : public avl_operation<T> {
static constexpr bool value = false;
};
template<int I, typename T>
constexpr bool avl_query_v = avl_query<I, T>::value;
template<int V, typename L, typename R, int I>
struct avl_query<I, avl_node<V, L, R>> {
static constexpr bool value = I < V ? avl_query_v<I, L> : I > V ? avl_query_v<I, R> : true;
};

其中avl_operation主要用于检查模板参数是否为avl_tree

1
2
3
4
template<typename T>
struct avl_operation {
static_assert(std::is_base_of_v<avl_tree, T>);
};

旋转调整

针对AVL树四种需要旋转的模式进行调整:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
let rec avl_adjust t =
match t with
| AvlNode (_, v1, AvlNode (h2, v2, a, b), c)
when h2 > avl_height c + 1 && avl_height a >= avl_height b ->
avl_node v2 a (avl_node v1 b c)
| AvlNode (_, v1, a, AvlNode (h2, v2, b, c))
when h2 > avl_height a + 1 && avl_height c >= avl_height b ->
avl_node v2 (avl_node v1 a b) c
| AvlNode (_, v1, AvlNode (h2, v2, a, AvlNode (h3, v3, b, c)), d)
when h2 > avl_height d + 1 && h3 > avl_height a ->
avl_node v3 (avl_node v2 a b) (avl_node v1 c d)
| AvlNode (_, v1, a, AvlNode (h2, v2, AvlNode (h3, v3, b, c), d))
when h2 > avl_height a + 1 && h3 > avl_height d ->
avl_node v3 (avl_node v1 a b) (avl_node v2 c d)
| _ -> t
;;

OCaml的代码中用到了模式匹配的when子句,这个可用std::enable_if来模拟:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
template<typename T, typename = void>
struct avl_adjust : public avl_operation<T> {
using type = T;
};
template<typename T>
using avl_adjust_t = typename avl_adjust<T>::type;
template<int V1, int V2, typename A, typename B, typename C>
struct avl_adjust<avl_node<V1, avl_node<V2, A, B>, C>,
std::enable_if_t<(avl_node<V2, A, B>::height > C::height + 1 &&
A::height >= B::height)>> {
using type = avl_node<V2, A, avl_node<V1, B, C>>;
};
template<int V1, int V2, typename A, typename B, typename C>
struct avl_adjust<avl_node<V1, A, avl_node<V2, B, C>>,
std::enable_if_t<(avl_node<V2, B, C>::height > A::height + 1 &&
C::height >= B::height)>> {
using type = avl_node<V2, avl_node<V1, A, B>, C>;
};
template<int V1, int V2, int V3, typename A, typename B, typename C, typename D>
struct avl_adjust<avl_node<V1, avl_node<V2, A, avl_node<V3, B, C>>, D>,
std::enable_if_t<(avl_node<V2, A, avl_node<V3, B, C>>::height > D::height + 1 &&
avl_node<V3, B, C>::height > A::height)>> {
using type = avl_node<V3, avl_node<V2, A, B>, avl_node<V1, C, D>>;
};
template<int V1, int V2, int V3, typename A, typename B, typename C, typename D>
struct avl_adjust<avl_node<V1, A, avl_node<V2, avl_node<V3, B, C>, D>>,
std::enable_if_t<(avl_node<V2, avl_node<V3, B, C>, D>::height > A::height + 1 &&
avl_node<V3, B, C>::height > D::height)>> {
using type = avl_node<V3, avl_node<V1, A, B>, avl_node<V2, C, D>>;
};

该段代码可以很好地展现C++ TMP对模式匹配的支持。

插入元素

OCaml代码:

1
2
3
4
5
6
7
8
9
10
let rec avl_insert i t =
match t with
| AvlNode (_, v, l, r) ->
if i < v
then avl_adjust @@ avl_node v (avl_insert i l) r
else if i > v
then avl_adjust @@ avl_node v l (avl_insert i r)
else t
| _ -> avl_node i AvlNull AvlNull
;;

注意多段的条件判断单纯使用C++ TMP不是很好实现,因为std::conditional会对两个分支都求值,而直接用模板匹配来实现这么多个分支又太过繁琐。因此借助了一些constexpr相关语法的帮助,让实现更加简洁:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
template<int I, typename T>
struct avl_insert : public avl_operation<T> {
using type = avl_node<I, avl_null, avl_null>;
};
template<int I, typename T>
using avl_insert_t = typename avl_insert<I, T>::type;
template<int V, typename L, typename R, int I>
struct avl_insert<I, avl_node<V, L, R>> {
private:
static auto infer() {
if constexpr (I < V) {
return avl_adjust_t<avl_node<V, avl_insert_t<I, L>, R>>{};
} else if constexpr (I > V) {
return avl_adjust_t<avl_node<V, L, avl_insert_t<I, R>>>{};
} else {
return avl_node<V, L, R>{};
}
}
public:
using type = decltype(infer());
};

最大值/最小值

实现非常简单,主要是为删除元素做铺垫。

OCaml实现:

1
2
3
4
5
6
7
8
9
10
11
let rec avl_max t =
match t with
| AvlNode (_, v, l, r) -> max v @@ avl_max r
| _ -> min_int
;;

let rec avl_min t =
match t with
| AvlNode (_, v, l, r) -> min v @@ avl_min l
| _ -> max_int
;;

C++实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
template<typename T>
struct avl_extreme : public avl_operation<T> {
static constexpr int max_value = std::numeric_limits<int>::min();
static constexpr int min_value = std::numeric_limits<int>::max();
};
template<typename T>
constexpr int avl_max_v = avl_extreme<T>::max_value;
template<typename T>
constexpr int avl_min_v = avl_extreme<T>::min_value;
template<int V, typename L, typename R>
struct avl_extreme<avl_node<V, L, R>> {
static constexpr int max_value = std::max(V, avl_max_v<R>);
static constexpr int min_value = std::min(V, avl_min_v<L>);
};

删除元素

直接上代码吧……

OCaml代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
let rec avl_remove i t =
match t with
| AvlNode (_, v, l, r) ->
if i < v
then avl_adjust @@ avl_node v (avl_remove i l) r
else if i > v
then avl_adjust @@ avl_node v l (avl_remove i r)
else if l == AvlNull
then r
else if r == AvlNull
then l
else (
let nv = avl_min r in
avl_adjust @@ avl_node nv l (avl_remove nv r))
| _ -> t
;;

C++代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
template<int I, typename T>
struct avl_remove : public avl_operation<T> {
using type = T;
};
template<int I, typename T>
using avl_remove_t = typename avl_remove<I, T>::type;
template<int V, typename L, typename R, int I>
struct avl_remove<I, avl_node<V, L, R>> {
private:
static auto infer() {
if constexpr (I < V) {
return avl_adjust_t<avl_node<V, avl_remove_t<I, L>, R>>{};
} else if constexpr (I > V) {
return avl_adjust_t<avl_node<V, L, avl_remove_t<I, R>>>{};
} else if constexpr (std::is_same_v<L, avl_null>) {
return R{};
} else if constexpr (std::is_same_v<R, avl_null>) {
return L{};
} else {
constexpr int NV = avl_min_v<R>;
return avl_adjust_t<avl_node<NV, L, avl_remove_t<NV, R>>>{};
}
}
public:
using type = decltype(infer());
};

其中用利用constexpr变量可以实现let ... in ...这样的局部绑定。


简单测试

OCaml部分

先定义打印AVL树的函数:

1
2
3
4
5
6
7
8
let rec avl_print ?(depth = 0) t =
match t with
| AvlNode (_, v, l, r) ->
print_string @@ String.make (depth * 2) '-' ^ string_of_int v;
print_newline ();
avl_print l ~depth:(depth + 1);
avl_print r ~depth:(depth + 1)
| _ -> ()

之后构造简单的测试样例:先按顺序插入1~10,之后删除1,3,5。

1
2
3
4
5
6
let t =
List.init 10 (fun x -> x + 1) |> List.fold_left (fun t x -> avl_insert x t) AvlNull
in
avl_print t;
print_newline ();
avl_remove 1 t |> avl_remove 3 |> avl_remove 5 |> avl_print;

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
4
--2
----1
----3
--8
----6
------5
------7
----9
------10

8
--4
----2
----6
------7
--9
----10

C++部分

定义打印函数,以及用于构造连续插入1~n的AVL树的元函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
template<typename T>
struct avl_printer : public avl_operation<T> {
static void print(int depth = 0) {}
};
template<int V, typename L, typename R>
struct avl_printer<avl_node<V, L, R>> {
static void print(int depth = 0) {
for (int i = 0; i < depth; ++i)
std::cout << "--";
std::cout << V << std::endl;
avl_printer<L>::print(depth + 1);
avl_printer<R>::print(depth + 1);
}
};

template<int I>
struct avl_build_range;
template<int I>
using avl_build_range_t = typename avl_build_range<I>::type;
template<>
struct avl_build_range<0> {
using type = avl_null;
};
template<int I>
struct avl_build_range {
using type = avl_insert_t<I, avl_build_range_t<I - 1>>;
};

之后构造简单测试:插入1~10,之后删除1,3,5:

1
2
3
4
5
6
int main() {
using t = avl_build_range_t<10>;
avl_printer<t>::print();
std::cout << std::endl;
avl_printer<avl_remove_t<5, avl_remove_t<3, avl_remove_t<1, t>>>>::print();
}

打印结果和OCaml的一样(废话……):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
4
--2
----1
----3
--8
----6
------5
------7
----9
------10

8
--4
----2
----6
------7
--9
----10

后记

本文旨在结合ADT、模式匹配等比较先进的PL概念,提供一些定义C++编译期数据结构的思路和较为简洁的写法。至于编译期AVL树则纯属玩具(毕竟对编译时间的影响还是挺大的,而且一般编译器的模板运算递归深度顶多一两千),看着图一乐就行了。