// ic_certification/rb_tree/mod.rs

1use crate::{
2    empty,
3    hash_tree::{fork, fork_hash, labeled_hash, leaf_hash, Hash},
4    labeled, leaf, pruned, HashTree, HashTreeNode,
5};
6use std::cmp::Ordering::{self, Equal, Greater, Less};
7use std::{borrow::Cow, fmt::Debug};
8
/// Color of a node in the red-black tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Color {
    Red,
    Black,
}

impl Color {
    /// Replaces `self` with the opposite color, in place.
    fn flip_assign(&mut self) {
        let flipped = self.flip();
        *self = flipped;
    }

    /// Returns the opposite color (`Red` <-> `Black`).
    fn flip(self) -> Self {
        if self == Color::Red {
            Color::Black
        } else {
            Color::Red
        }
    }
}
27
28/// Types that can be converted into a [`HashTree`].
29pub trait AsHashTree {
30    /// Returns the root hash of the tree without constructing it.
31    /// Must be equivalent to `as_hash_tree().reconstruct()`.
32    fn root_hash(&self) -> Hash;
33
34    /// Constructs a hash tree corresponding to the data.
35    fn as_hash_tree(&self) -> HashTree;
36}
37
38impl AsHashTree for Vec<u8> {
39    fn root_hash(&self) -> Hash {
40        leaf_hash(&self[..])
41    }
42
43    fn as_hash_tree(&self) -> HashTree {
44        leaf(Cow::from(&self[..]))
45    }
46}
47
48impl AsHashTree for Hash {
49    fn root_hash(&self) -> Hash {
50        leaf_hash(&self[..])
51    }
52
53    fn as_hash_tree(&self) -> HashTree {
54        leaf(Cow::from(&self[..]))
55    }
56}
57
58impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> AsHashTree for RbTree<K, V> {
59    fn root_hash(&self) -> Hash {
60        match self.root.as_ref() {
61            None => empty().digest(),
62            Some(n) => n.subtree_hash,
63        }
64    }
65
66    fn as_hash_tree(&self) -> HashTree {
67        Node::full_witness_tree(&self.root, Node::data_tree)
68    }
69}
70
/// One end of a key range, as located in the tree by the bound searches.
///
/// `Exact` holds a stored key equal to the requested one. `Neighbor` holds
/// the closest stored key on the relevant side when the requested key is
/// absent.
#[derive(PartialEq, Debug, Clone, Copy)]
enum KeyBound<'a> {
    Exact(&'a [u8]),
    Neighbor(&'a [u8]),
}

impl<'a> AsRef<[u8]> for KeyBound<'a> {
    /// Returns the key bytes carried by either variant.
    fn as_ref(&self) -> &'a [u8] {
        match *self {
            KeyBound::Exact(bytes) | KeyBound::Neighbor(bytes) => bytes,
        }
    }
}
85
/// Owning, optional link to a child node. `None` stands for a leaf, which
/// `is_red` treats as black.
type NodeRef<K, V> = Option<Box<Node<K, V>>>;

// Red-black invariants (checked by `is_balanced` after every
// `insert`/`delete` in test builds):
// 1. All leaves are black.
// 2. Children of a red node are black.
// 3. Every path from a node down to a leaf goes through the same number
//    of black nodes.
// In addition, `balance` rotates a red right link to the left whenever the
// left link is not also red, which gives the tree its left-leaning shape.
#[derive(Clone, Debug)]
struct Node<K, V> {
    /// Entry key; its bytes order the tree and label the entry in the hash tree.
    key: K,
    /// Entry value; its `root_hash` feeds this node's data hash.
    value: V,
    /// Subtree of entries whose keys sort before `key`.
    left: NodeRef<K, V>,
    /// Subtree of entries whose keys sort after `key`.
    right: NodeRef<K, V>,
    /// Red-black color used only for rebalancing; it is not part of any hash.
    color: Color,

    /// Hash of the full hash tree built from this node and its
    /// children. It needs to be recomputed after every rotation.
    subtree_hash: Hash,
}
104
105impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> Node<K, V> {
106    fn new(key: K, value: V) -> Box<Node<K, V>> {
107        let value_hash = value.root_hash();
108        let data_hash = labeled_hash(key.as_ref(), &value_hash);
109        Box::new(Self {
110            key,
111            value,
112            left: None,
113            right: None,
114            color: Color::Red,
115            subtree_hash: data_hash,
116        })
117    }
118
119    fn data_hash(&self) -> Hash {
120        labeled_hash(self.key.as_ref(), &self.value.root_hash())
121    }
122
123    fn left_hash_tree(&self) -> HashTree {
124        match self.left.as_ref() {
125            None => empty(),
126            Some(l) => pruned(l.subtree_hash),
127        }
128    }
129
130    fn right_hash_tree(&self) -> HashTree {
131        match self.right.as_ref() {
132            None => empty(),
133            Some(r) => pruned(r.subtree_hash),
134        }
135    }
136
137    fn visit<'a, F>(n: &'a NodeRef<K, V>, f: &mut F)
138    where
139        F: 'a + FnMut(&'a [u8], &'a V),
140    {
141        if let Some(n) = n {
142            Self::visit(&n.left, f);
143            (*f)(n.key.as_ref(), &n.value);
144            Self::visit(&n.right, f)
145        }
146    }
147
148    fn data_tree(&self) -> HashTree {
149        labeled(self.key.as_ref(), self.value.as_hash_tree())
150    }
151
152    fn subtree_with<'a>(&'a self, f: impl FnOnce(&'a V) -> HashTree) -> HashTree {
153        labeled(self.key.as_ref(), f(&self.value))
154    }
155
156    fn witness_tree(&self) -> HashTree {
157        labeled(self.key.as_ref(), pruned(self.value.root_hash()))
158    }
159
160    fn full_witness_tree<'a>(n: &'a NodeRef<K, V>, f: fn(&'a Node<K, V>) -> HashTree) -> HashTree {
161        match n {
162            None => empty(),
163            Some(n) => three_way_fork(
164                Self::full_witness_tree(&n.left, f),
165                f(n),
166                Self::full_witness_tree(&n.right, f),
167            ),
168        }
169    }
170
171    fn update_subtree_hash(&mut self) {
172        self.subtree_hash = self.compute_subtree_hash();
173    }
174
175    fn compute_subtree_hash(&self) -> Hash {
176        let h = self.data_hash();
177
178        match (self.left.as_ref(), self.right.as_ref()) {
179            (None, None) => h,
180            (Some(l), None) => fork_hash(&l.subtree_hash, &h),
181            (None, Some(r)) => fork_hash(&h, &r.subtree_hash),
182            (Some(l), Some(r)) => fork_hash(&l.subtree_hash, &fork_hash(&h, &r.subtree_hash)),
183        }
184    }
185}
186
/// Traversal state of [`Iter`] relative to the node on top of its stack.
#[derive(PartialEq, Debug)]
enum Visit {
    /// The node was just reached; its left subtree has not been entered.
    Pre,
    /// The node's left subtree is done; the node itself is yielded now.
    In,
    /// The node and everything below it have been visited.
    Post,
}
193
/// Iterator over a RbTree.
///
/// Yields `(&key, &value)` pairs in ascending order of key bytes. It performs
/// an in-order traversal driven by an explicit stack, because nodes have no
/// parent links.
#[derive(Debug)]
pub struct Iter<'a, K, V> {
    /// Invariants:
    /// 1. visit == Pre: none of the nodes in parents were visited yet.
    /// 2. visit == In:  the last node in parents and all its left children are visited.
    /// 3. visit == Post: all the nodes reachable from the last node in parents are visited.
    visit: Visit,
    /// Path from the root to the node currently being processed (last element).
    parents: Vec<&'a Node<K, V>>,
}
204
205impl<'a, K, V> Iter<'a, K, V> {
206    /// This function is an adaptation of the traverse_step procedure described in
207    /// section 7.2 "Bidirectional Bifurcate Coordinates" of
208    /// "Elements of Programming" by A. Stepanov and P. McJones, p. 118.
209    /// http://elementsofprogramming.com/eop.pdf
210    ///
211    /// The main difference is that our nodes don't have parent links for two reasons:
212    /// 1. They don't play well with safe Rust ownership model.
213    /// 2. Iterating a tree shouldn't be an operation common enough to complicate the code.
214    fn step(&mut self) -> bool {
215        match self.parents.last() {
216            Some(tip) => {
217                match self.visit {
218                    Visit::Pre => {
219                        if let Some(l) = &tip.left {
220                            self.parents.push(l);
221                        } else {
222                            self.visit = Visit::In;
223                        }
224                    }
225                    Visit::In => {
226                        if let Some(r) = &tip.right {
227                            self.parents.push(r);
228                            self.visit = Visit::Pre;
229                        } else {
230                            self.visit = Visit::Post;
231                        }
232                    }
233                    Visit::Post => {
234                        let tip = self.parents.pop().unwrap();
235                        if let Some(parent) = self.parents.last() {
236                            if parent
237                                .left
238                                .as_ref()
239                                .map(|l| l.as_ref() as *const Node<K, V>)
240                                == Some(tip as *const Node<K, V>)
241                            {
242                                self.visit = Visit::In;
243                            }
244                        }
245                    }
246                }
247                true
248            }
249            None => false,
250        }
251    }
252}
253
254impl<'a, K, V> std::iter::Iterator for Iter<'a, K, V> {
255    type Item = (&'a K, &'a V);
256
257    fn next(&mut self) -> Option<Self::Item> {
258        while self.step() {
259            if self.visit == Visit::In {
260                return self.parents.last().map(|n| (&n.key, &n.value));
261            }
262        }
263        None
264    }
265}
266
/// Implements mutable left-leaning red-black trees as defined in
/// <https://www.cs.princeton.edu/~rs/talks/LLRB/LLRB.pdf>
///
/// Keys are ordered by their byte representation (`AsRef<[u8]>`). Every node
/// caches the hash of its subtree, so the map's root hash is available without
/// rehashing the entries.
#[derive(Default, Clone)]
pub struct RbTree<K, V> {
    /// Root node; `None` for an empty map.
    root: NodeRef<K, V>,
}
273
274impl<K, V> PartialEq for RbTree<K, V>
275where
276    K: 'static + AsRef<[u8]> + PartialEq,
277    V: 'static + AsHashTree + PartialEq,
278{
279    fn eq(&self, other: &Self) -> bool {
280        self.iter().eq(other.iter())
281    }
282}
283
284impl<K, V> Eq for RbTree<K, V>
285where
286    K: 'static + AsRef<[u8]> + Eq,
287    V: 'static + AsHashTree + Eq,
288{
289}
290
291impl<K, V> PartialOrd for RbTree<K, V>
292where
293    K: 'static + AsRef<[u8]> + PartialOrd,
294    V: 'static + AsHashTree + PartialOrd,
295{
296    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
297        self.iter().partial_cmp(other.iter())
298    }
299}
300
301impl<K, V> Ord for RbTree<K, V>
302where
303    K: 'static + AsRef<[u8]> + Ord,
304    V: 'static + AsHashTree + Ord,
305{
306    fn cmp(&self, other: &Self) -> Ordering {
307        self.iter().cmp(other.iter())
308    }
309}
310
311impl<K, V> std::iter::FromIterator<(K, V)> for RbTree<K, V>
312where
313    K: 'static + AsRef<[u8]>,
314    V: 'static + AsHashTree,
315{
316    fn from_iter<T>(iter: T) -> Self
317    where
318        T: IntoIterator<Item = (K, V)>,
319    {
320        let mut t = RbTree::<K, V>::new();
321        for (k, v) in iter.into_iter() {
322            t.insert(k, v);
323        }
324        t
325    }
326}
327
328impl<K, V> std::fmt::Debug for RbTree<K, V>
329where
330    K: 'static + AsRef<[u8]> + std::fmt::Debug,
331    V: 'static + AsHashTree + std::fmt::Debug,
332{
333    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334        let mut list = f.debug_list();
335        for (k, v) in self.iter() {
336            list.entry(&format_args!(
337                "({}, {:#?})",
338                String::from_utf8_lossy(k.as_ref()),
339                v.as_hash_tree()
340            ));
341        }
342        list.finish()
343    }
344}
345
346impl<K, V> RbTree<K, V> {
347    /// Constructs a new empty tree.
348    pub const fn new() -> Self {
349        Self { root: None }
350    }
351
352    /// Returns true if the map is empty.
353    pub const fn is_empty(&self) -> bool {
354        self.root.is_none()
355    }
356}
357
358impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
359    /// Looks up the key in the map and returns the associated value, if there is one.
360    pub fn get(&self, key: &[u8]) -> Option<&V> {
361        let mut root = self.root.as_ref();
362        while let Some(n) = root {
363            match key.cmp(n.key.as_ref()) {
364                Equal => return Some(&n.value),
365                Less => root = n.left.as_ref(),
366                Greater => root = n.right.as_ref(),
367            }
368        }
369        None
370    }
371
372    /// Updates the value corresponding to the specified key.
373    pub fn modify(&mut self, key: &[u8], f: impl FnOnce(&mut V)) {
374        fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
375            h: &mut NodeRef<K, V>,
376            k: &[u8],
377            f: impl FnOnce(&mut V),
378        ) {
379            if let Some(h) = h {
380                match k.as_ref().cmp(h.key.as_ref()) {
381                    Equal => {
382                        f(&mut h.value);
383                        h.update_subtree_hash();
384                    }
385                    Less => {
386                        go(&mut h.left, k, f);
387                        h.update_subtree_hash();
388                    }
389                    Greater => {
390                        go(&mut h.right, k, f);
391                        h.update_subtree_hash();
392                    }
393                }
394            }
395        }
396        go(&mut self.root, key, f)
397    }
398
399    fn range_witness<'a>(
400        &'a self,
401        left: Option<KeyBound<'a>>,
402        right: Option<KeyBound<'a>>,
403        f: fn(&'a Node<K, V>) -> HashTree,
404    ) -> HashTree {
405        match (left, right) {
406            (None, None) => Node::full_witness_tree(&self.root, f),
407            (Some(l), None) => self.witness_range_above(l, f),
408            (None, Some(r)) => self.witness_range_below(r, f),
409            (Some(l), Some(r)) => self.witness_range_between(l, r, f),
410        }
411    }
412
413    /// Constructs a hash tree that acts as a proof that there is a
414    /// entry with the specified key in this map.  The proof also
415    /// contains the value in question.
416    ///
417    /// If the key is not in the map, returns a proof of absence.
418    pub fn witness(&self, key: &[u8]) -> HashTree {
419        self.nested_witness(key, |v| v.as_hash_tree())
420    }
421
422    /// Like `witness`, but gives the caller more control over the
423    /// construction of the value witness.  This method is useful for
424    /// constructing witnesses for nested certified maps.
425    pub fn nested_witness<'a>(&'a self, key: &[u8], f: impl FnOnce(&'a V) -> HashTree) -> HashTree {
426        if let Some(t) = self.lookup_and_build_witness(key, f) {
427            return t;
428        }
429        self.range_witness(
430            self.lower_bound(key),
431            self.upper_bound(key),
432            Node::witness_tree,
433        )
434    }
435
436    /// Returns a witness enumerating all the keys in this map.  The
437    /// resulting tree doesn't include values, they are replaced with
438    /// "Pruned" nodes.
439    pub fn keys(&self) -> HashTree {
440        Node::full_witness_tree(&self.root, Node::witness_tree)
441    }
442
443    /// Returns a witness for the keys in the specified range.  The
444    /// resulting tree doesn't include values, they are replaced with
445    /// "Pruned" nodes.
446    pub fn key_range(&self, first: &[u8], last: &[u8]) -> HashTree {
447        self.range_witness(
448            self.lower_bound(first),
449            self.upper_bound(last),
450            Node::witness_tree,
451        )
452    }
453
454    /// Returns a witness for the key-value pairs in the specified range.
455    /// The resulting tree contains both keys and values.
456    pub fn value_range(&self, first: &[u8], last: &[u8]) -> HashTree {
457        self.range_witness(
458            self.lower_bound(first),
459            self.upper_bound(last),
460            Node::data_tree,
461        )
462    }
463
464    /// Returns a witness that enumerates all the keys starting with
465    /// the specified prefix.
466    pub fn keys_with_prefix(&self, prefix: &[u8]) -> HashTree {
467        self.range_witness(
468            self.lower_bound(prefix),
469            self.right_prefix_neighbor(prefix),
470            Node::witness_tree,
471        )
472    }
473
474    /// Creates an iterator over the map's keys and values.
475    pub fn iter(&self) -> Iter<'_, K, V> {
476        match &self.root {
477            None => Iter {
478                visit: Visit::Pre,
479                parents: vec![],
480            },
481            Some(n) => Iter {
482                visit: Visit::Pre,
483                parents: vec![n],
484            },
485        }
486    }
487
488    /// Enumerates all the key-value pairs in the tree.
489    pub fn for_each<'a, F>(&'a self, mut f: F)
490    where
491        F: 'a + FnMut(&'a [u8], &'a V),
492    {
493        Node::visit(&self.root, &mut f)
494    }
495
496    fn witness_range_above<'a>(
497        &'a self,
498        lo: KeyBound<'a>,
499        f: fn(&'a Node<K, V>) -> HashTree,
500    ) -> HashTree {
501        fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
502            n: &'a NodeRef<K, V>,
503            lo: KeyBound<'a>,
504            f: fn(&'a Node<K, V>) -> HashTree,
505        ) -> HashTree {
506            match n {
507                None => empty(),
508                Some(n) => match n.key.as_ref().cmp(lo.as_ref()) {
509                    Equal => three_way_fork(
510                        n.left_hash_tree(),
511                        match lo {
512                            KeyBound::Exact(_) => f(n),
513                            KeyBound::Neighbor(_) => n.witness_tree(),
514                        },
515                        Node::full_witness_tree(&n.right, f),
516                    ),
517                    Less => three_way_fork(
518                        n.left_hash_tree(),
519                        pruned(n.data_hash()),
520                        go(&n.right, lo, f),
521                    ),
522                    Greater => three_way_fork(
523                        go(&n.left, lo, f),
524                        f(n),
525                        Node::full_witness_tree(&n.right, f),
526                    ),
527                },
528            }
529        }
530        go(&self.root, lo, f)
531    }
532
533    fn witness_range_below<'a>(
534        &'a self,
535        hi: KeyBound<'a>,
536        f: fn(&'a Node<K, V>) -> HashTree,
537    ) -> HashTree {
538        fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
539            n: &'a NodeRef<K, V>,
540            hi: KeyBound<'a>,
541            f: fn(&'a Node<K, V>) -> HashTree,
542        ) -> HashTree {
543            match n {
544                None => empty(),
545                Some(n) => match n.key.as_ref().cmp(hi.as_ref()) {
546                    Equal => three_way_fork(
547                        Node::full_witness_tree(&n.left, f),
548                        match hi {
549                            KeyBound::Exact(_) => f(n),
550                            KeyBound::Neighbor(_) => n.witness_tree(),
551                        },
552                        n.right_hash_tree(),
553                    ),
554                    Greater => three_way_fork(
555                        go(&n.left, hi, f),
556                        pruned(n.data_hash()),
557                        n.right_hash_tree(),
558                    ),
559                    Less => three_way_fork(
560                        Node::full_witness_tree(&n.left, f),
561                        f(n),
562                        go(&n.right, hi, f),
563                    ),
564                },
565            }
566        }
567        go(&self.root, hi, f)
568    }
569
570    fn witness_range_between<'a>(
571        &'a self,
572        lo: KeyBound<'a>,
573        hi: KeyBound<'a>,
574        f: fn(&'a Node<K, V>) -> HashTree,
575    ) -> HashTree {
576        debug_assert!(
577            lo.as_ref() <= hi.as_ref(),
578            "lo = {:?} > hi = {:?}",
579            lo.as_ref(),
580            hi.as_ref()
581        );
582        fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
583            n: &'a NodeRef<K, V>,
584            lo: KeyBound<'a>,
585            hi: KeyBound<'a>,
586            f: fn(&'a Node<K, V>) -> HashTree,
587        ) -> HashTree {
588            match n {
589                None => empty(),
590                Some(n) => {
591                    let k = n.key.as_ref();
592                    match (lo.as_ref().cmp(k), k.cmp(hi.as_ref())) {
593                        (Less, Less) => {
594                            three_way_fork(go(&n.left, lo, hi, f), f(n), go(&n.right, lo, hi, f))
595                        }
596                        (Equal, Equal) => three_way_fork(
597                            n.left_hash_tree(),
598                            match (lo, hi) {
599                                (KeyBound::Exact(_), _) => f(n),
600                                (_, KeyBound::Exact(_)) => f(n),
601                                _ => n.witness_tree(),
602                            },
603                            n.right_hash_tree(),
604                        ),
605                        (_, Equal) => three_way_fork(
606                            go(&n.left, lo, hi, f),
607                            match hi {
608                                KeyBound::Exact(_) => f(n),
609                                KeyBound::Neighbor(_) => n.witness_tree(),
610                            },
611                            n.right_hash_tree(),
612                        ),
613                        (Equal, _) => three_way_fork(
614                            n.left_hash_tree(),
615                            match lo {
616                                KeyBound::Exact(_) => f(n),
617                                KeyBound::Neighbor(_) => n.witness_tree(),
618                            },
619                            go(&n.right, lo, hi, f),
620                        ),
621                        (Less, Greater) => three_way_fork(
622                            go(&n.left, lo, hi, f),
623                            pruned(n.data_hash()),
624                            n.right_hash_tree(),
625                        ),
626                        (Greater, Less) => three_way_fork(
627                            n.left_hash_tree(),
628                            pruned(n.data_hash()),
629                            go(&n.right, lo, hi, f),
630                        ),
631                        _ => pruned(n.subtree_hash),
632                    }
633                }
634            }
635        }
636        go(&self.root, lo, hi, f)
637    }
638
639    fn lower_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
640        fn go<'a, K: 'static + AsRef<[u8]>, V>(
641            n: &'a NodeRef<K, V>,
642            key: &[u8],
643        ) -> Option<KeyBound<'a>> {
644            n.as_ref().and_then(|n| {
645                let node_key = n.key.as_ref();
646                match node_key.cmp(key) {
647                    Less => go(&n.right, key).or(Some(KeyBound::Neighbor(node_key))),
648                    Equal => Some(KeyBound::Exact(node_key)),
649                    Greater => go(&n.left, key),
650                }
651            })
652        }
653        go(&self.root, key)
654    }
655
656    fn upper_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
657        fn go<'a, K: 'static + AsRef<[u8]>, V>(
658            n: &'a NodeRef<K, V>,
659            key: &[u8],
660        ) -> Option<KeyBound<'a>> {
661            n.as_ref().and_then(|n| {
662                let node_key = n.key.as_ref();
663                match node_key.cmp(key) {
664                    Less => go(&n.right, key),
665                    Equal => Some(KeyBound::Exact(node_key)),
666                    Greater => go(&n.left, key).or(Some(KeyBound::Neighbor(node_key))),
667                }
668            })
669        }
670        go(&self.root, key)
671    }
672
673    fn right_prefix_neighbor(&self, prefix: &[u8]) -> Option<KeyBound<'_>> {
674        fn is_prefix_of(p: &[u8], x: &[u8]) -> bool {
675            if p.len() > x.len() {
676                return false;
677            }
678            &x[0..p.len()] == p
679        }
680        fn go<'a, K: 'static + AsRef<[u8]>, V>(
681            n: &'a NodeRef<K, V>,
682            prefix: &[u8],
683        ) -> Option<KeyBound<'a>> {
684            n.as_ref().and_then(|n| {
685                let node_key = n.key.as_ref();
686                match node_key.cmp(prefix) {
687                    Greater if is_prefix_of(prefix, node_key) => go(&n.right, prefix),
688                    Greater => go(&n.left, prefix).or(Some(KeyBound::Neighbor(node_key))),
689                    Less | Equal => go(&n.right, prefix),
690                }
691            })
692        }
693        go(&self.root, prefix)
694    }
695
696    fn lookup_and_build_witness<'a>(
697        &'a self,
698        key: &[u8],
699        f: impl FnOnce(&'a V) -> HashTree,
700    ) -> Option<HashTree> {
701        fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
702            n: &'a NodeRef<K, V>,
703            key: &[u8],
704            f: impl FnOnce(&'a V) -> HashTree,
705        ) -> Option<HashTree> {
706            n.as_ref().and_then(|n| match key.cmp(n.key.as_ref()) {
707                Equal => Some(three_way_fork(
708                    n.left_hash_tree(),
709                    n.subtree_with(f),
710                    n.right_hash_tree(),
711                )),
712                Less => {
713                    let subtree = go(&n.left, key, f)?;
714                    Some(three_way_fork(
715                        subtree,
716                        pruned(n.data_hash()),
717                        n.right_hash_tree(),
718                    ))
719                }
720                Greater => {
721                    let subtree = go(&n.right, key, f)?;
722                    Some(three_way_fork(
723                        n.left_hash_tree(),
724                        pruned(n.data_hash()),
725                        subtree,
726                    ))
727                }
728            })
729        }
730        go(&self.root, key, f)
731    }
732
733    /// Inserts a key-value entry into the map.
734    pub fn insert(&mut self, key: K, value: V) {
735        fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
736            h: NodeRef<K, V>,
737            k: K,
738            v: V,
739        ) -> Box<Node<K, V>> {
740            match h {
741                None => Node::new(k, v),
742                Some(mut h) => {
743                    match k.as_ref().cmp(h.key.as_ref()) {
744                        Equal => {
745                            h.value = v;
746                        }
747                        Less => {
748                            h.left = Some(go(h.left, k, v));
749                        }
750                        Greater => {
751                            h.right = Some(go(h.right, k, v));
752                        }
753                    }
754                    h.update_subtree_hash();
755                    balance(h)
756                }
757            }
758        }
759        let mut root = go(self.root.take(), key, value);
760        root.color = Color::Black;
761        self.root = Some(root);
762
763        #[cfg(test)]
764        debug_assert!(
765            is_balanced(&self.root),
766            "the tree is not balanced:\n{:?}",
767            DebugView(&self.root)
768        );
769    }
770
771    /// Removes the specified key from the map.
772    pub fn delete(&mut self, key: &[u8]) {
773        fn move_red_left<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
774            mut h: Box<Node<K, V>>,
775        ) -> Box<Node<K, V>> {
776            flip_colors(&mut h);
777            if is_red(&h.right.as_ref().unwrap().left) {
778                h.right = Some(rotate_right(h.right.take().unwrap()));
779                h = rotate_left(h);
780                flip_colors(&mut h);
781            }
782            h
783        }
784
785        fn move_red_right<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
786            mut h: Box<Node<K, V>>,
787        ) -> Box<Node<K, V>> {
788            flip_colors(&mut h);
789            if is_red(&h.left.as_ref().unwrap().left) {
790                h = rotate_right(h);
791                flip_colors(&mut h);
792            }
793            h
794        }
795
796        #[inline]
797        fn min<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
798            mut h: &mut Box<Node<K, V>>,
799        ) -> &mut Box<Node<K, V>> {
800            while h.left.is_some() {
801                h = h.left.as_mut().unwrap();
802            }
803            h
804        }
805
806        fn delete_min<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
807            mut h: Box<Node<K, V>>,
808        ) -> NodeRef<K, V> {
809            if h.left.is_none() {
810                debug_assert!(h.right.is_none());
811                drop(h);
812                return None;
813            }
814            if !is_red(&h.left) && !is_red(&h.left.as_ref().unwrap().left) {
815                h = move_red_left(h);
816            }
817            h.left = delete_min(h.left.unwrap());
818            h.update_subtree_hash();
819            Some(balance(h))
820        }
821
822        fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
823            mut h: Box<Node<K, V>>,
824            key: &[u8],
825        ) -> NodeRef<K, V> {
826            if key < h.key.as_ref() {
827                debug_assert!(h.left.is_some(), "the key must be present in the tree");
828                if !is_red(&h.left) && !is_red(&h.left.as_ref().unwrap().left) {
829                    h = move_red_left(h);
830                }
831                h.left = go(h.left.take().unwrap(), key);
832            } else {
833                if is_red(&h.left) {
834                    h = rotate_right(h);
835                }
836                if key == h.key.as_ref() && h.right.is_none() {
837                    debug_assert!(h.left.is_none());
838                    drop(h);
839                    return None;
840                }
841
842                if !is_red(&h.right) && !is_red(&h.right.as_ref().unwrap().left) {
843                    h = move_red_right(h);
844                }
845
846                if key == h.key.as_ref() {
847                    let m = min(h.right.as_mut().unwrap());
848                    std::mem::swap(&mut h.key, &mut m.key);
849                    std::mem::swap(&mut h.value, &mut m.value);
850                    h.right = delete_min(h.right.take().unwrap());
851                } else {
852                    h.right = go(h.right.take().unwrap(), key);
853                }
854            }
855            h.update_subtree_hash();
856            Some(balance(h))
857        }
858
859        if self.get(key).is_none() {
860            return;
861        }
862
863        if !is_red(&self.root.as_ref().unwrap().left) && !is_red(&self.root.as_ref().unwrap().right)
864        {
865            self.root.as_mut().unwrap().color = Color::Red;
866        }
867        self.root = go(self.root.take().unwrap(), key);
868        if let Some(n) = self.root.as_mut() {
869            n.color = Color::Black;
870        }
871
872        #[cfg(test)]
873        debug_assert!(
874            is_balanced(&self.root),
875            "unbalanced map: {:?}",
876            DebugView(&self.root)
877        );
878
879        debug_assert!(self.get(key).is_none());
880    }
881}
882
/// Combines the trees for a node's left subtree (`l`), its own entry (`m`)
/// and its right subtree (`r`) into one tree shaped like
/// `fork(l, fork(m, r))`. This is the same shape that
/// `Node::compute_subtree_hash` hashes.
///
/// Empty sides are dropped instead of forked. Fully pruned parts are
/// collapsed into a single pruned hash to keep witnesses small.
///
/// Arm order matters: the `Empty` arms must come before the `Pruned`
/// collapses. For example, `(Empty, Pruned, Pruned)` must become
/// `fork(m, r)`, not `fork(empty, pruned(..))`.
fn three_way_fork(l: HashTree, m: HashTree, r: HashTree) -> HashTree {
    match (l.root, m.root, r.root) {
        // No children: the entry stands alone.
        (HashTreeNode::Empty(), m, HashTreeNode::Empty()) => HashTree { root: m },
        // Left child only.
        (l, m, HashTreeNode::Empty()) => fork(HashTree { root: l }, HashTree { root: m }),
        // Right child only.
        (HashTreeNode::Empty(), m, r) => fork(HashTree { root: m }, HashTree { root: r }),
        // Everything pruned: fold into one pruned hash of the whole subtree.
        (HashTreeNode::Pruned(lhash), HashTreeNode::Pruned(mhash), HashTreeNode::Pruned(rhash)) => {
            pruned(fork_hash(&lhash, &fork_hash(&mhash, &rhash)))
        }
        // Entry and right side pruned: fold just the inner fork.
        (l, HashTreeNode::Pruned(mhash), HashTreeNode::Pruned(rhash)) => {
            fork(HashTree { root: l }, pruned(fork_hash(&mhash, &rhash)))
        }
        (l, m, r) => fork(
            HashTree { root: l },
            fork(HashTree { root: m }, HashTree { root: r }),
        ),
    }
}
900
901// helper functions
902fn is_red<K, V>(x: &NodeRef<K, V>) -> bool {
903    x.as_ref().map(|h| h.color == Color::Red).unwrap_or(false)
904}
905
906fn balance<K: AsRef<[u8]> + 'static, V: AsHashTree + 'static>(
907    mut h: Box<Node<K, V>>,
908) -> Box<Node<K, V>> {
909    if is_red(&h.right) && !is_red(&h.left) {
910        h = rotate_left(h);
911    }
912    if is_red(&h.left) && is_red(&h.left.as_ref().unwrap().left) {
913        h = rotate_right(h);
914    }
915    if is_red(&h.left) && is_red(&h.right) {
916        flip_colors(&mut h)
917    }
918    h
919}
920
921/// Make a left-leaning link lean to the right.
922fn rotate_right<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
923    mut h: Box<Node<K, V>>,
924) -> Box<Node<K, V>> {
925    debug_assert!(is_red(&h.left));
926
927    let mut x = h.left.take().unwrap();
928    h.left = x.right.take();
929    h.update_subtree_hash();
930
931    x.right = Some(h);
932    x.color = x.right.as_ref().unwrap().color;
933    x.right.as_mut().unwrap().color = Color::Red;
934    x.update_subtree_hash();
935
936    x
937}
938
939fn rotate_left<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
940    mut h: Box<Node<K, V>>,
941) -> Box<Node<K, V>> {
942    debug_assert!(is_red(&h.right));
943
944    let mut x = h.right.take().unwrap();
945    h.right = x.left.take();
946    h.update_subtree_hash();
947
948    x.left = Some(h);
949    x.color = x.left.as_ref().unwrap().color;
950    x.left.as_mut().unwrap().color = Color::Red;
951    x.update_subtree_hash();
952
953    x
954}
955
956fn flip_colors<K, V>(h: &mut Box<Node<K, V>>) {
957    h.color.flip_assign();
958    h.left.as_mut().unwrap().color.flip_assign();
959    h.right.as_mut().unwrap().color.flip_assign();
960}
961
#[cfg(test)]
/// Test-only check of the red-black invariants.
///
/// Returns `false` if some root-to-leaf path has a different black height
/// than the leftmost path. Panics (via `assert!`) on a red node with a red
/// child.
fn is_balanced<K, V>(root: &NodeRef<K, V>) -> bool {
    fn check<K, V>(link: &NodeRef<K, V>, blacks_left: usize) -> bool {
        let node = match link {
            None => return blacks_left == 0,
            Some(node) => node,
        };
        let remaining = if is_red(link) {
            assert!(!is_red(&node.left));
            assert!(!is_red(&node.right));
            blacks_left
        } else {
            debug_assert!(blacks_left > 0);
            blacks_left - 1
        };
        check(&node.left, remaining) && check(&node.right, remaining)
    }

    // Black height measured along the leftmost path is the reference value.
    let mut expected = 0;
    let mut cursor = root;
    while let Some(node) = cursor {
        if !is_red(cursor) {
            expected += 1;
        }
        cursor = &node.left;
    }
    check(root, expected)
}
990
#[cfg(test)]
/// Test-only wrapper that pretty-prints the shape and colors of a (sub)tree.
struct DebugView<'a, K, V>(&'a NodeRef<K, V>);

#[cfg(test)]
impl<'a, K: AsRef<[u8]>, V> std::fmt::Debug for DebugView<'a, K, V> {
    /// Writes one line per node in pre-order, indented two spaces per level.
    /// Each line shows the color (`R`/`B`) and key bytes; missing children
    /// print as `[B] <null>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn render<K: AsRef<[u8]>, V>(
            out: &mut std::fmt::Formatter<'_>,
            link: &NodeRef<K, V>,
            indent: usize,
        ) -> std::fmt::Result {
            let node = match link {
                Some(node) => node,
                None => return writeln!(out, "{:width$}[B] <null>", "", width = indent),
            };
            let color = if is_red(link) { "R" } else { "B" };
            writeln!(
                out,
                "{:width$}[{}] {:?}",
                "",
                color,
                node.key.as_ref(),
                width = indent
            )?;
            render(out, &node.left, indent + 2)?;
            render(out, &node.right, indent + 2)
        }
        render(f, self.0, 0)
    }
}
1021
1022#[cfg(test)]
1023mod tests;