树-代码模板

  1. 线段树
    1.1. 普通写法
    1.2. ZKW线段树
  2. 字典树(前缀树)

线段树

普通写法

递归实现:建树、单点修改与区间求和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/// Segment tree over a fixed-length integer array.
///
/// Supports assigning a single element and querying the sum of an inclusive
/// index range; each operation walks one root-to-leaf path (or two, for a
/// split query), halving the covered interval at every level.
///
/// Nodes live in a flat array with the root at index 0; the children of node
/// `i` are `2*i + 1` and `2*i + 2`.
class NumArray {
private:
    // Declared before segmentTree so declaration order and member-initializer
    // order agree.
    int n;
    std::vector<int> segmentTree;

    /// Fills the subtree rooted at `node`, which covers nums[s..e] (inclusive).
    void build(int node, int s, int e, const std::vector<int>& nums) {
        if (s == e) {
            segmentTree[node] = nums[s];
            return;
        }
        const int m = s + (e - s) / 2;
        build(node * 2 + 1, s, m, nums);
        build(node * 2 + 2, m + 1, e, nums);
        segmentTree[node] = segmentTree[node * 2 + 1] + segmentTree[node * 2 + 2];
    }

    /// Assigns `val` to position `index` inside the subtree `node` covering
    /// [s, e], then recomputes the sums on the way back up.
    void change(int index, int val, int node, int s, int e) {
        if (s == e) {
            segmentTree[node] = val;
            return;
        }
        const int m = s + (e - s) / 2;
        if (index <= m) {
            change(index, val, node * 2 + 1, s, m);
        } else {
            change(index, val, node * 2 + 2, m + 1, e);
        }
        segmentTree[node] = segmentTree[node * 2 + 1] + segmentTree[node * 2 + 2];
    }

    /// Sum of [left, right] within the subtree `node` covering [s, e].
    /// Requires s <= left <= right <= e.
    int range(int left, int right, int node, int s, int e) {
        if (left == s && right == e) {
            return segmentTree[node];
        }
        const int m = s + (e - s) / 2;
        if (right <= m) {
            return range(left, right, node * 2 + 1, s, m);
        }
        if (left > m) {
            return range(left, right, node * 2 + 2, m + 1, e);
        }
        // The query straddles the midpoint: split it at m.
        return range(left, m, node * 2 + 1, s, m) +
               range(m + 1, right, node * 2 + 2, m + 1, e);
    }

public:
    /// Builds the tree from `nums`. An empty `nums` yields an empty tree on
    /// which update() is a no-op and sumRange() returns 0.
    NumArray(const std::vector<int>& nums)
        : n(static_cast<int>(nums.size())), segmentTree(nums.size() * 4) {
        // build() assumes a non-empty interval; with n == 0 it would index
        // past the end of the (empty) node array.
        if (n > 0) {
            build(0, 0, n - 1, nums);
        }
    }

    /// Sets element `index` to `val`. Indices outside [0, n) are ignored.
    void update(int index, int val) {
        if (index < 0 || index >= n) {
            return;
        }
        change(index, val, 0, 0, n - 1);
    }

    /// Returns the sum of elements in the inclusive range [left, right].
    /// The range is clamped to [0, n - 1]; an empty intersection sums to 0.
    int sumRange(int left, int right) {
        if (left < 0) {
            left = 0;
        }
        if (right > n - 1) {
            right = n - 1;
        }
        if (left > right) {
            return 0;
        }
        return range(left, right, 0, 0, n - 1);
    }
};

ZKW线段树

利用位运算自底向上地完成单点更新与区间求和,无需递归。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/// Non-recursive ("ZKW") segment tree over int values: point updates and
/// inclusive range sums, walking from the leaves towards the root with shifts.
///
/// NOTE(review): despite its name this class is a segment tree, not a prefix
/// tree; the name is kept so existing callers continue to compile.
///
/// Layout: the root is node 1, the children of node i are 2*i and 2*i + 1,
/// and the leaf for array position p is node p + pre_len. Node 0 is unused.
/// Call init(n) before any other operation; indices passed to the other
/// methods must lie in [0, n) and are not checked.
class Trie {
private:
    // Node sums. Sized by init(), so capacity follows the requested length
    // instead of a fixed compile-time bound.
    std::vector<int> tr = std::vector<int>(2, 0);
    // Number of leaves: the smallest power of two >= the length given to init().
    int pre_len = 1;

public:
    /// Adds `diff` to the element at `index` and to every ancestor sum.
    void update(int index, int diff) {
        for (int i = index + pre_len; i > 0; i >>= 1) {
            tr[i] += diff;
        }
    }

    /// Assigns `value` to the element at `index`.
    void set(int index, int value) {
        const int diff = value - tr[index + pre_len];
        update(index, diff);
    }

    /// Returns the sum over the closed interval [left, right].
    int sumRange(int left, int right) {
        int res = 0;
        for (int i = left + pre_len, j = right + pre_len; i <= j; i >>= 1, j >>= 1) {
            // A left boundary that is a right child is not covered by its
            // parent's remaining range: take it and step inwards.
            if (i % 2 == 1) res += tr[i++];
            // Symmetrically for a right boundary that is a left child.
            if (j % 2 == 0) res += tr[j--];
        }
        return res;
    }

    /// Resets the tree to `n` elements, all zero.
    void init(int n) {
        pre_len = 1;
        while (pre_len < n) pre_len <<= 1;
        // Leaves occupy [pre_len, 2 * pre_len), so 2 * pre_len nodes suffice.
        tr.assign(static_cast<std::vector<int>::size_type>(pre_len) * 2, 0);
    }
};

字典树(前缀树)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/// Prefix tree (trie) over lowercase ASCII words that records, for every
/// inserted word, the minimum value it was inserted with.
///
/// Each node owns its children: they are allocated in insert() and released
/// by the destructor, so destroying the root frees the whole tree.
struct Trie {
    /// One slot per letter 'a'..'z'; nullptr where no child exists.
    std::vector<Trie*> chs;
    /// Smallest value inserted for the word ending at this node, or -1 if no
    /// word ends here.
    int cost;

    Trie() : chs(26), cost(-1) {}

    /// Frees every child subtree (recursively, via the children's destructors).
    ~Trie() {
        for (Trie* child : chs) {
            delete child;
        }
    }

    // A node owns its children through raw pointers, so a member-wise copy
    // would make two nodes delete the same subtree.
    Trie(const Trie&) = delete;
    Trie& operator=(const Trie&) = delete;

    /// Inserts `s` with value `val`, keeping the smaller value if `s` was
    /// already present.
    ///
    /// Every character of `s` must be in 'a'..'z'; other characters would
    /// index outside `chs`. This is not checked.
    void insert(const std::string& s, int val) {
        Trie* node = this;
        for (const char ch : s) {
            Trie*& slot = node->chs[ch - 'a'];
            if (slot == nullptr) {
                slot = new Trie();
            }
            node = slot;
        }
        // -1 marks "no word ends here yet", so the first value always wins.
        if (node->cost == -1 || val < node->cost) {
            node->cost = val;
        }
    }
};