0%

重新封装 Segment Tree 类

读者可能读过我的博客的第一篇文章,其中提到了一种线段树的理解方式,而这篇博客将封装一个 Segment Tree 类,以便于更灵活地使用这种数据结构。


笔者注

本文可能涉及到一些更高级的 C++ 语法和特性,读者可自行查阅:

  • C++ 类的继承
  • 类模板实参推导 (CTAD)
  • 模板类
  • 型别推导和转换
  • 虚函数

以下为我认为不错的文章或文档,在此引用:

腾讯云:[C++] 深入理解面向对象编程特性 : 继承

C语言中文网:C++继承和派生简明教程

API Reference Document:类模板实参推导 (CTAD)(C++17 起)


前段时间我发布的一篇博文重新理解线段树及线段树的应用中提出了一种思维框架,用于更高效地建构线段树算法,在实际情况中,我们可能会遇到更加复杂的多棵树甚至多种树之间嵌套的问题,此时我们对于线段树真正灵活的使用则有了更高的要求。我曾经让 ChatGPT 帮我在原先的代码基础上重新构建一个真正的广泛线段树类,然而由于某些原因,效果难以达到我们的预期。本文中的代码已经解决了这个棘手的问题,并且维持了代码的高可维护性和优美程度,希望读者喜欢。

C++ Code

这里以最简单的线段树应用为例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <bits/stdc++.h>
using namespace std;

namespace SegmentTree {
#define mid ((data.l+data.r)>>1)
template <typename T, typename L, typename SubClass>
class tree_data {
public:
T seg; L tag; int l, r;
protected:
tree_data() : l(), r(), seg(), tag() {}
tree_data(const int &_l, const int &_r) : seg(), tag(), l(_l), r(_r) {}
static T seg_merge(const T &a, const T &b);
static L tag_merge(const L &a, const L &b);
virtual void tag_apply(const L &x) = 0;
};

template <typename TD, typename T, typename L>
class Node {
private:
void release() {
if(ls) ls->recieve(data.tag);
if(rs) rs->recieve(data.tag);
data.tag = L();
}
void recieve(const L &x) {
data.tag_apply(x);
data.tag = TD::tag_merge(data.tag, x);
}
public:
TD data;
Node *ls, *rs;
Node(const int &s, const int &e, const T *arr) :
data(s, e), ls(NULL), rs(NULL) {
if(data.l == data.r) { data.seg = arr[data.l]; return; }
ls = new Node(data.l, mid, arr);
rs = new Node(mid+1, data.r, arr);
data.seg = TD::seg_merge(ls->data.seg, rs->data.seg);
}
Node(const int &s, const int &e) :
data(s, e), ls(NULL), rs(NULL) {
if(data.l == data.r) return;
ls = new Node(data.l, mid);
rs = new Node(mid+1, data.r);
data.seg = TD::seg_merge(ls->data.seg, rs->data.seg);
}
Node(const Node *base) : data(base->data), ls(base->ls), rs(base->rs) {}
T query(const int &s, const int &e) {
if(s <= data.l && data.r <= e) return data.seg;
release();
T res = T();
if(s <= mid) res = TD::seg_merge(res, ls->query(s, e));
if(e > mid) res = TD::seg_merge(res, rs->query(s, e));
return res;
}
void modify(const int &s, const int &e, const L &x) {
if(s <= data.l && data.r <= e) { recieve(x); return; }
release();
if(s <= mid) ls->modify(s, e, x);
if(e > mid) rs->modify(s, e, x);
data.seg = TD::seg_merge(ls->data.seg,rs->data.seg);
}
//==============================START_DEBUG=================================//
void Print(string base) {
cout << base << "---[" << data.l << "," << data.r << "] ";
data.Print(); cout << "\n";
if(ls) ls->Print(base+" |");
if(rs) rs->Print(base+" ");
}
//===============================END_DEBUG==================================//
};

template <typename TD>
class SegTree {
Node<TD, typename TD::T, typename TD::L> *root;
public:
SegTree(const int &s, const int &e, const typename TD::T *arr) :
root(new Node<TD, typename TD::T, typename TD::L>(s, e, arr)) {}
SegTree(const int &s, const int &e) :
root(new Node<TD, typename TD::T, typename TD::L>(s, e)) {}
typename TD::T query(const int &s, const int &e) { return root->query(s,e); }
void modify(const int &s, const int &e, const typename TD::L &x) { root->modify(s,e,x); }
//==============================START_DEBUG=================================//
void Print() { root->Print(""); }
//===============================END_DEBUG==================================//
};
#undef mid
};

class ADD_SUM_TREE : public SegmentTree::tree_data<int, int, ADD_SUM_TREE> {
public:
using T = int; using L = int;
ADD_SUM_TREE(int _l, int _r) : tree_data(_l, _r) {}
static int seg_merge(const int &a, const int &b) { return a + b; }
static int tag_merge(const int &a, const int &b) { return a + b; }
void tag_apply(const int &x) { seg += x * (r - l + 1); }
//==============================START_DEBUG=================================//
void Print() { cout << seg << " " << tag; }
//===============================END_DEBUG==================================//
};

int a[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int main() {
SegmentTree::SegTree<ADD_SUM_TREE> tr(1, 9);
tr.modify(4,9,1);
tr.Print();
tr.modify(5,9,2);
tr.Print();
cout << tr.query(3,9) << "\n";
for(int i = 1; i <= 9; ++i) {
cout << tr.query(i,i) << " ";
}
cout << "\n";
cout << tr.query(3,9) << "\n";
tr.Print();
return 0;
}