ic_certification/
nested_rb_tree.rs

1use crate::{empty, fork, labeled, leaf, pruned, AsHashTree, Hash, HashTree, HashTreeNode, RbTree};
2use std::fmt::{Debug, Formatter};
3
4pub trait NestedTreeKeyRequirements: Debug + Clone + AsRef<[u8]> + 'static {}
5pub trait NestedTreeValueRequirements: Debug + Clone + AsHashTree + 'static {}
6impl<T> NestedTreeKeyRequirements for T where T: Debug + Clone + AsRef<[u8]> + 'static {}
7impl<T> NestedTreeValueRequirements for T where T: Debug + Clone + AsHashTree + 'static {}
8
9#[derive(Clone)]
10pub enum NestedTree<K: NestedTreeKeyRequirements, V: NestedTreeValueRequirements> {
11    Leaf(V),
12    Nested(RbTree<K, NestedTree<K, V>>),
13}
14
15impl<K: NestedTreeKeyRequirements, V: NestedTreeValueRequirements> Default for NestedTree<K, V> {
16    fn default() -> Self {
17        NestedTree::Nested(RbTree::<K, NestedTree<K, V>>::new())
18    }
19}
20
21impl<K: NestedTreeKeyRequirements, V: NestedTreeValueRequirements> Debug for NestedTree<K, V> {
22    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23        let s = match &self {
24            NestedTree::Leaf(leaf) => {
25                format!("NestedTree::Leaf({})", hex::encode(leaf.root_hash()))
26            }
27            NestedTree::Nested(rb_tree) => format!("NestedTree({:#?})", rb_tree),
28        };
29        write!(f, "{}", s)
30    }
31}
32
33impl<K: NestedTreeKeyRequirements, V: NestedTreeValueRequirements> AsHashTree for NestedTree<K, V> {
34    fn root_hash(&self) -> Hash {
35        match self {
36            NestedTree::Leaf(a) => a.root_hash(),
37            NestedTree::Nested(tree) => tree.root_hash(),
38        }
39    }
40
41    fn as_hash_tree(&self) -> HashTree {
42        match self {
43            NestedTree::Leaf(a) => a.as_hash_tree(),
44            NestedTree::Nested(tree) => tree.as_hash_tree(),
45        }
46    }
47}
48
49impl<K: NestedTreeKeyRequirements, V: NestedTreeValueRequirements> NestedTree<K, V> {
50    pub fn get(&self, path: &[K]) -> Option<&V> {
51        if let Some(key) = path.first() {
52            match self {
53                NestedTree::Leaf(_) => None,
54                NestedTree::Nested(tree) => tree
55                    .get(key.as_ref())
56                    .and_then(|child| child.get(&path[1..])),
57            }
58        } else {
59            match self {
60                NestedTree::Leaf(value) => Some(value),
61                NestedTree::Nested(_) => None,
62            }
63        }
64    }
65
66    /// Returns true if there is a leaf at the specified path
67    pub fn contains_leaf(&self, path: &[K]) -> bool {
68        if let Some(key) = path.first() {
69            match self {
70                NestedTree::Leaf(_) => false,
71                NestedTree::Nested(tree) => tree
72                    .get(key.as_ref())
73                    .map(|child| child.contains_leaf(&path[1..]))
74                    .unwrap_or(false),
75            }
76        } else {
77            matches!(self, NestedTree::Leaf(_))
78        }
79    }
80
81    /// Returns true if there is a leaf or a subtree at the specified path
82    pub fn contains_path(&self, path: &[K]) -> bool {
83        if let Some(key) = path.first() {
84            match self {
85                NestedTree::Leaf(_) => false,
86                NestedTree::Nested(tree) => tree
87                    .get(key.as_ref())
88                    .map(|child| child.contains_path(&path[1..]))
89                    .unwrap_or(false),
90            }
91        } else {
92            true
93        }
94    }
95
96    pub fn insert(&mut self, path: &[K], value: V) {
97        if let Some(key) = path.first() {
98            match self {
99                NestedTree::Leaf(_) => {
100                    self.clear();
101                    self.insert(path, value);
102                }
103                NestedTree::Nested(tree) => {
104                    if tree.get(key.as_ref()).is_some() {
105                        tree.modify(key.as_ref(), |child| child.insert(&path[1..], value));
106                    } else {
107                        tree.insert(key.clone(), NestedTree::default());
108                        self.insert(path, value);
109                    }
110                }
111            }
112        } else {
113            *self = NestedTree::Leaf(value);
114        }
115    }
116
117    pub fn delete(&mut self, path: &[K]) {
118        if let Some(key) = path.first() {
119            match self {
120                NestedTree::Leaf(_) => {}
121                NestedTree::Nested(tree) => {
122                    tree.modify(key.as_ref(), |child| child.delete(&path[1..]));
123
124                    // after deleting the subtree located at `path[1..]`,
125                    // check if the subtree located at `path[0]` is empty,
126                    // if it is, remove it
127                    if let Some(root) = tree.get(key.as_ref()) {
128                        match root {
129                            NestedTree::Leaf(_) => {}
130                            NestedTree::Nested(nested_tree) => {
131                                if nested_tree.is_empty() {
132                                    tree.delete(key.as_ref());
133                                }
134                            }
135                        }
136                    }
137                }
138            }
139        } else {
140            self.clear();
141        }
142    }
143
144    pub fn clear(&mut self) {
145        *self = NestedTree::default();
146    }
147
148    pub fn witness(&self, path: &[K]) -> HashTree {
149        if let Some(key) = path.first() {
150            match self {
151                NestedTree::Leaf(value) => value.as_hash_tree(),
152                NestedTree::Nested(tree) => {
153                    tree.nested_witness(key.as_ref(), |tree| tree.witness(&path[1..]))
154                }
155            }
156        } else {
157            self.as_hash_tree()
158        }
159    }
160}
161
162pub fn merge_hash_trees(lhs: HashTree, rhs: HashTree) -> HashTree {
163    match (lhs.root, rhs.root) {
164        (HashTreeNode::Pruned(l), HashTreeNode::Pruned(r)) => {
165            if l != r {
166                panic!("merge_hash_trees: inconsistent hashes");
167            }
168            pruned(l)
169        }
170        (HashTreeNode::Pruned(_), r) => HashTree { root: r },
171        (l, HashTreeNode::Pruned(_)) => HashTree { root: l },
172        (HashTreeNode::Fork(l), HashTreeNode::Fork(r)) => fork(
173            merge_hash_trees(HashTree { root: l.0 }, HashTree { root: r.0 }),
174            merge_hash_trees(HashTree { root: l.1 }, HashTree { root: r.1 }),
175        ),
176        (HashTreeNode::Labeled(l_label, l), HashTreeNode::Labeled(r_label, r)) => {
177            if l_label != r_label {
178                panic!("merge_hash_trees: inconsistent hash tree labels");
179            }
180            labeled(
181                l_label,
182                merge_hash_trees(HashTree { root: *l }, HashTree { root: *r }),
183            )
184        }
185        (HashTreeNode::Empty(), HashTreeNode::Empty()) => empty(),
186        (HashTreeNode::Empty(), r) => HashTree { root: r },
187        (l, HashTreeNode::Empty()) => HashTree { root: l },
188        (HashTreeNode::Leaf(l), HashTreeNode::Leaf(r)) => {
189            if l != r {
190                panic!("merge_hash_trees: inconsistent leaves");
191            }
192            leaf(l)
193        }
194        (_l, _r) => {
195            panic!("merge_hash_trees: inconsistent tree structure");
196        }
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LookupResult;
    use rstest::*;

    #[rstest]
    fn nested_tree_operation() {
        let mut tree: NestedTree<&str, Vec<u8>> = NestedTree::default();
        // insertion
        tree.insert(&["one", "two"], vec![2]);
        tree.insert(&["one", "three"], vec![3]);

        assert_eq!(tree.get(&["one", "two"]), Some(&vec![2]));
        assert_eq!(tree.get(&["one", "three"]), Some(&vec![3]));
        assert_eq!(tree.get(&["one", "two", "three"]), None);
        assert_eq!(tree.get(&["one", "three", "two"]), None);
        assert_eq!(tree.get(&["one"]), None);

        assert!(tree.contains_leaf(&["one", "two"]));
        assert!(tree.contains_leaf(&["one", "three"]));
        assert!(!tree.contains_leaf(&["one", "two", "three"]));
        assert!(!tree.contains_leaf(&["one", "three", "two"]));
        assert!(!tree.contains_leaf(&["one"]));

        assert!(tree.contains_path(&["one", "two"]));
        assert!(tree.contains_path(&["one", "three"]));
        assert!(!tree.contains_path(&["one", "two", "three"]));
        assert!(!tree.contains_path(&["one", "three", "two"]));
        assert!(tree.contains_path(&["one"]));

        // deleting non-existent key doesn't do anything
        tree.delete(&["one", "two", "three"]);

        assert_eq!(tree.get(&["one", "two"]), Some(&vec![2]));
        assert_eq!(tree.get(&["one", "three"]), Some(&vec![3]));
        assert_eq!(tree.get(&["one", "two", "three"]), None);
        assert_eq!(tree.get(&["one", "three", "two"]), None);
        assert_eq!(tree.get(&["one"]), None);

        assert!(tree.contains_leaf(&["one", "two"]));
        assert!(tree.contains_leaf(&["one", "three"]));
        assert!(!tree.contains_leaf(&["one", "two", "three"]));
        assert!(!tree.contains_leaf(&["one", "three", "two"]));
        assert!(!tree.contains_leaf(&["one"]));

        assert!(tree.contains_path(&["one", "two"]));
        assert!(tree.contains_path(&["one", "three"]));
        assert!(!tree.contains_path(&["one", "two", "three"]));
        assert!(!tree.contains_path(&["one", "three", "two"]));
        assert!(tree.contains_path(&["one"]));

        // deleting existing key works
        tree.delete(&["one", "three"]);

        assert_eq!(tree.get(&["one", "two"]), Some(&vec![2]));
        assert_eq!(tree.get(&["one", "three"]), None);
        assert_eq!(tree.get(&["one", "two", "three"]), None);
        assert_eq!(tree.get(&["one", "three", "two"]), None);
        assert_eq!(tree.get(&["one"]), None);

        assert!(tree.contains_leaf(&["one", "two"]));
        assert!(!tree.contains_leaf(&["one", "three"]));
        assert!(!tree.contains_leaf(&["one", "two", "three"]));
        assert!(!tree.contains_leaf(&["one", "three", "two"]));
        assert!(!tree.contains_leaf(&["one"]));

        assert!(tree.contains_path(&["one", "two"]));
        assert!(!tree.contains_path(&["one", "three"]));
        assert!(!tree.contains_path(&["one", "two", "three"]));
        assert!(!tree.contains_path(&["one", "three", "two"]));
        assert!(tree.contains_path(&["one"]));

        // deleting subtree works
        tree.delete(&["one"]);

        assert_eq!(tree.get(&["one", "two"]), None);
        assert_eq!(tree.get(&["one", "three"]), None);
        assert_eq!(tree.get(&["one", "two", "three"]), None);
        assert_eq!(tree.get(&["one", "three", "two"]), None);
        assert_eq!(tree.get(&["one"]), None);

        assert!(!tree.contains_leaf(&["one", "two"]));
        assert!(!tree.contains_leaf(&["one", "three"]));
        assert!(!tree.contains_leaf(&["one", "two", "three"]));
        assert!(!tree.contains_leaf(&["one", "three", "two"]));
        assert!(!tree.contains_leaf(&["one"]));

        assert!(!tree.contains_path(&["one", "two"]));
        assert!(!tree.contains_path(&["one", "three"]));
        assert!(!tree.contains_path(&["one", "two", "three"]));
        assert!(!tree.contains_path(&["one", "three", "two"]));
        assert!(!tree.contains_path(&["one"]));
    }

    #[rstest]
    fn insert_replaces_leaves_and_subtrees() {
        let mut tree: NestedTree<&str, Vec<u8>> = NestedTree::default();

        // inserting below an existing leaf replaces that leaf with a subtree
        tree.insert(&["one"], vec![1]);
        tree.insert(&["one", "two"], vec![2]);
        assert_eq!(tree.get(&["one"]), None);
        assert_eq!(tree.get(&["one", "two"]), Some(&vec![2]));
        assert!(tree.contains_path(&["one"]));

        // inserting at a subtree's path replaces the whole subtree with a leaf
        tree.insert(&["one"], vec![3]);
        assert_eq!(tree.get(&["one"]), Some(&vec![3]));
        assert!(!tree.contains_path(&["one", "two"]));

        // inserting at an existing leaf overwrites its value
        tree.insert(&["one"], vec![4]);
        assert_eq!(tree.get(&["one"]), Some(&vec![4]));
    }

    #[rstest]
    fn delete_removes_empty_subpaths() {
        let mut tree: NestedTree<&str, Vec<u8>> = NestedTree::default();

        tree.insert(&["one", "two", "three", "four"], vec![4]);
        tree.insert(&["one", "two", "three", "five"], vec![5]);
        tree.insert(&["one", "two", "six"], vec![6]);
        tree.insert(&["one", "seven"], vec![7]);

        assert!(tree.contains_leaf(&["one", "two", "three", "four"]));
        assert!(tree.contains_leaf(&["one", "two", "three", "five"]));
        assert!(tree.contains_leaf(&["one", "two", "six"]));
        assert!(tree.contains_leaf(&["one", "seven"]));

        assert!(tree.contains_path(&["one", "two", "three", "four"]));
        assert!(tree.contains_path(&["one", "two", "three", "five"]));
        assert!(tree.contains_path(&["one", "two", "three"]));
        assert!(tree.contains_path(&["one", "two", "six"]));
        assert!(tree.contains_path(&["one", "two"]));
        assert!(tree.contains_path(&["one", "seven"]));
        assert!(tree.contains_path(&["one"]));

        tree.delete(&["one", "two", "three", "four"]);

        assert!(!tree.contains_leaf(&["one", "two", "three", "four"]));
        assert!(tree.contains_leaf(&["one", "two", "three", "five"]));
        assert!(tree.contains_leaf(&["one", "two", "six"]));
        assert!(tree.contains_leaf(&["one", "seven"]));

        assert!(!tree.contains_path(&["one", "two", "three", "four"]));
        assert!(tree.contains_path(&["one", "two", "three", "five"]));
        assert!(tree.contains_path(&["one", "two", "three"]));
        assert!(tree.contains_path(&["one", "two", "six"]));
        assert!(tree.contains_path(&["one", "two"]));
        assert!(tree.contains_path(&["one", "seven"]));
        assert!(tree.contains_path(&["one"]));

        tree.delete(&["one", "two", "three", "five"]);

        assert!(!tree.contains_leaf(&["one", "two", "three", "four"]));
        assert!(!tree.contains_leaf(&["one", "two", "three", "five"]));
        assert!(tree.contains_leaf(&["one", "two", "six"]));
        assert!(tree.contains_leaf(&["one", "seven"]));

        assert!(!tree.contains_path(&["one", "two", "three", "four"]));
        assert!(!tree.contains_path(&["one", "two", "three", "five"]));
        assert!(!tree.contains_path(&["one", "two", "three"]));
        assert!(tree.contains_path(&["one", "two", "six"]));
        assert!(tree.contains_path(&["one", "two"]));
        assert!(tree.contains_path(&["one", "seven"]));
        assert!(tree.contains_path(&["one"]));

        tree.delete(&["one", "two", "six"]);

        assert!(!tree.contains_leaf(&["one", "two", "three", "four"]));
        assert!(!tree.contains_leaf(&["one", "two", "three", "five"]));
        assert!(!tree.contains_leaf(&["one", "two", "six"]));
        assert!(tree.contains_leaf(&["one", "seven"]));

        assert!(!tree.contains_path(&["one", "two", "three", "four"]));
        assert!(!tree.contains_path(&["one", "two", "three", "five"]));
        assert!(!tree.contains_path(&["one", "two", "three"]));
        assert!(!tree.contains_path(&["one", "two", "six"]));
        assert!(!tree.contains_path(&["one", "two"]));
        assert!(tree.contains_path(&["one", "seven"]));
        assert!(tree.contains_path(&["one"]));

        tree.delete(&["one", "seven"]);

        assert!(!tree.contains_leaf(&["one", "two", "three", "four"]));
        assert!(!tree.contains_leaf(&["one", "two", "three", "five"]));
        assert!(!tree.contains_leaf(&["one", "two", "six"]));
        assert!(!tree.contains_leaf(&["one", "seven"]));

        assert!(!tree.contains_path(&["one", "two", "three", "four"]));
        assert!(!tree.contains_path(&["one", "two", "three", "five"]));
        assert!(!tree.contains_path(&["one", "two", "three"]));
        assert!(!tree.contains_path(&["one", "two", "six"]));
        assert!(!tree.contains_path(&["one", "two"]));
        assert!(!tree.contains_path(&["one", "seven"]));
        assert!(!tree.contains_path(&["one"]));
    }

    #[rstest]
    fn merge_hash_trees_merge_witness() {
        let mut tree: NestedTree<&str, Vec<u8>> = NestedTree::default();
        tree.insert(&["one", "two"], vec![1]);
        tree.insert(&["one", "three"], vec![2]);
        tree.insert(&["two", "two"], vec![3]);
        tree.insert(&["two", "three"], vec![4]);

        let witness_one_two = tree.witness(&["one", "two"]);
        let witness_two_three = tree.witness(&["two", "three"]);
        let witness_merged = merge_hash_trees(witness_one_two, witness_two_three);

        assert!(matches!(
            witness_merged.lookup_path(&["one", "two"]),
            LookupResult::Found(val) if val == vec![1]
        ));
        assert!(matches!(
            witness_merged.lookup_path(&["one", "three"]),
            LookupResult::Unknown
        ));
        assert!(matches!(
            witness_merged.lookup_path(&["two", "three"]),
            LookupResult::Found(val) if val == vec![4]
        ));
        assert!(matches!(
            witness_merged.lookup_path(&["two", "two"]),
            LookupResult::Unknown
        ));

        let witness_merged_left_empty = merge_hash_trees(empty(), witness_merged.clone());
        assert_eq!(witness_merged_left_empty, witness_merged);

        let witness_merged_right_empty = merge_hash_trees(witness_merged.clone(), empty());
        assert_eq!(witness_merged_right_empty, witness_merged);
    }

    #[rstest]
    // empty
    #[case::empty_labeled(empty(), labeled_a(), labeled_a())]
    #[case::labeled_empty(labeled_a(), empty(), labeled_a())]
    #[case::empty_leaf(empty(), leaf_a(), leaf_a())]
    #[case::leaf_empty(leaf_a(), empty(), leaf_a())]
    // pruned
    #[case::pruned_pruned(pruned_a(), pruned_a(), pruned_a())]
    #[case::pruned_labeled(pruned_a(), labeled_a(), labeled_a())]
    #[case::labeled_pruned(labeled_a(), pruned_a(), labeled_a())]
    #[case::pruned_leaf(pruned_a(), leaf_a(), leaf_a())]
    #[case::leaf_pruned(leaf_a(), pruned_a(), leaf_a())]
    #[case::pruned_fork(pruned_a(), fork_a(), fork_a())]
    #[case::fork_pruned(fork_a(), pruned_a(), fork_a())]
    #[case::empty_pruned(empty(), pruned_a(), empty())]
    #[case::pruned_empty(pruned_a(), empty(), empty())]
    // matching
    #[case::empty_empty(empty(), empty(), empty())]
    #[case::fork_fork(fork_a(), fork_a(), fork_a())]
    #[case::labeled_labeled(labeled_a(), labeled_a(), labeled_a())]
    #[case::leaf_leaf(leaf_a(), leaf_a(), leaf_a())]
    // partially pruned
    #[case::fork_partially_pruned(
        fork(pruned_a(), leaf_b()),
        fork(leaf_a(), pruned_b()),
        fork(leaf_a(), leaf_b())
    )]
    // mismatched inputs panic; see `merge_hash_trees_inconsistent_structure`
    fn merge_hash_trees_operation(
        #[case] lhs: HashTree,
        #[case] rhs: HashTree,
        #[case] merged: HashTree,
    ) {
        assert_eq!(merge_hash_trees(lhs, rhs), merged);
    }

    #[rstest]
    #[should_panic]
    #[case::mismatched_pruned(pruned_a(), pruned_b())]
    #[should_panic]
    #[case::mismatched_labeled(labeled_a(), labeled_b())]
    #[should_panic]
    #[case::mismatched_leaves(leaf_a(), leaf_b())]
    #[should_panic]
    #[case::mismatched_leaf_and_fork(leaf_a(), fork_a())]
    #[should_panic]
    #[case::mismatched_fork_and_leaf(fork_a(), leaf_a())]
    #[should_panic]
    #[case::mismatched_label_and_fork(labeled_a(), fork_a())]
    #[should_panic]
    #[case::mismatched_fork_and_label(fork_a(), labeled_a())]
    #[should_panic]
    #[case::mismatched_leaf_and_label(leaf_a(), labeled_a())]
    #[should_panic]
    #[case::mismatched_label_and_leaf(labeled_a(), leaf_a())]
    #[should_panic]
    #[case::mismatched_fork_children(fork(leaf_a(), leaf_a()), fork_a())]
    fn merge_hash_trees_inconsistent_structure(#[case] lhs: HashTree, #[case] rhs: HashTree) {
        merge_hash_trees(lhs, rhs);
    }

    #[test]
    fn should_display_labels_and_hex_hashes() {
        let label_1 = "label 1";
        let label_2 = "label 2";

        let value_1 = [1, 2, 3, 4, 5];
        let value_2 = [7, 8, 9, 10];

        let mut tree: NestedTree<&str, Vec<u8>> = NestedTree::default();
        tree.insert(&[label_1, label_2], value_1.to_vec());
        tree.insert(&[label_2, label_1], value_2.to_vec());

        let s = format!("{:?}", tree);
        assert!(s.contains(label_1));
        assert!(s.contains(label_2));
        assert!(s.contains(&format!("0x{}", hex::encode(value_1))));
        assert!(s.contains(&format!("0x{}", hex::encode(value_2))));
    }

    #[fixture]
    fn pruned_a() -> HashTree {
        pruned(Hash::from([0u8; 32]))
    }

    #[fixture]
    fn pruned_b() -> HashTree {
        pruned(Hash::from([1u8; 32]))
    }

    #[fixture]
    fn labeled_a() -> HashTree {
        labeled("foo", pruned_a())
    }

    #[fixture]
    fn labeled_b() -> HashTree {
        labeled("bar", pruned_a())
    }

    #[fixture]
    fn leaf_a() -> HashTree {
        leaf(Hash::from([0u8; 32]))
    }

    #[fixture]
    fn leaf_b() -> HashTree {
        leaf(Hash::from([1u8; 32]))
    }

    #[fixture]
    fn fork_a() -> HashTree {
        fork(leaf_a(), leaf_b())
    }
}