1use crate::{
2 empty,
3 hash_tree::{fork, fork_hash, labeled_hash, leaf_hash, Hash},
4 labeled, leaf, pruned, HashTree, HashTreeNode,
5};
6use std::cmp::Ordering::{self, Equal, Greater, Less};
7use std::{borrow::Cow, fmt::Debug};
8
/// Color of a red-black tree node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Color {
    Red,
    Black,
}

impl Color {
    /// Replaces `self` with the opposite color.
    fn flip_assign(&mut self) {
        let flipped = self.flip();
        *self = flipped;
    }

    /// Returns the opposite color: red becomes black, black becomes red.
    fn flip(self) -> Self {
        if self == Self::Red {
            Self::Black
        } else {
            Self::Red
        }
    }
}
27
28pub trait AsHashTree {
30 fn root_hash(&self) -> Hash;
33
34 fn as_hash_tree(&self) -> HashTree;
36}
37
38impl AsHashTree for Vec<u8> {
39 fn root_hash(&self) -> Hash {
40 leaf_hash(&self[..])
41 }
42
43 fn as_hash_tree(&self) -> HashTree {
44 leaf(Cow::from(&self[..]))
45 }
46}
47
48impl AsHashTree for Hash {
49 fn root_hash(&self) -> Hash {
50 leaf_hash(&self[..])
51 }
52
53 fn as_hash_tree(&self) -> HashTree {
54 leaf(Cow::from(&self[..]))
55 }
56}
57
58impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> AsHashTree for RbTree<K, V> {
59 fn root_hash(&self) -> Hash {
60 match self.root.as_ref() {
61 None => empty().digest(),
62 Some(n) => n.subtree_hash,
63 }
64 }
65
66 fn as_hash_tree(&self) -> HashTree {
67 Node::full_witness_tree(&self.root, Node::data_tree)
68 }
69}
70
/// One end of a key range used when building range witnesses.
///
/// `Exact` carries a key that is itself present in the tree; `Neighbor`
/// carries the closest existing key on that side when the requested key is
/// absent (see `lower_bound` / `upper_bound`).
#[derive(PartialEq, Debug, Clone, Copy)]
enum KeyBound<'a> {
    Exact(&'a [u8]),
    Neighbor(&'a [u8]),
}

impl<'a> AsRef<[u8]> for KeyBound<'a> {
    /// Returns the bound's key bytes, regardless of the variant.
    fn as_ref(&self) -> &'a [u8] {
        match *self {
            KeyBound::Exact(key) | KeyBound::Neighbor(key) => key,
        }
    }
}
85
86type NodeRef<K, V> = Option<Box<Node<K, V>>>;
87
88#[derive(Clone, Debug)]
93struct Node<K, V> {
94 key: K,
95 value: V,
96 left: NodeRef<K, V>,
97 right: NodeRef<K, V>,
98 color: Color,
99
100 subtree_hash: Hash,
103}
104
105impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> Node<K, V> {
106 fn new(key: K, value: V) -> Box<Node<K, V>> {
107 let value_hash = value.root_hash();
108 let data_hash = labeled_hash(key.as_ref(), &value_hash);
109 Box::new(Self {
110 key,
111 value,
112 left: None,
113 right: None,
114 color: Color::Red,
115 subtree_hash: data_hash,
116 })
117 }
118
119 fn data_hash(&self) -> Hash {
120 labeled_hash(self.key.as_ref(), &self.value.root_hash())
121 }
122
123 fn left_hash_tree(&self) -> HashTree {
124 match self.left.as_ref() {
125 None => empty(),
126 Some(l) => pruned(l.subtree_hash),
127 }
128 }
129
130 fn right_hash_tree(&self) -> HashTree {
131 match self.right.as_ref() {
132 None => empty(),
133 Some(r) => pruned(r.subtree_hash),
134 }
135 }
136
137 fn visit<'a, F>(n: &'a NodeRef<K, V>, f: &mut F)
138 where
139 F: 'a + FnMut(&'a [u8], &'a V),
140 {
141 if let Some(n) = n {
142 Self::visit(&n.left, f);
143 (*f)(n.key.as_ref(), &n.value);
144 Self::visit(&n.right, f)
145 }
146 }
147
148 fn data_tree(&self) -> HashTree {
149 labeled(self.key.as_ref(), self.value.as_hash_tree())
150 }
151
152 fn subtree_with<'a>(&'a self, f: impl FnOnce(&'a V) -> HashTree) -> HashTree {
153 labeled(self.key.as_ref(), f(&self.value))
154 }
155
156 fn witness_tree(&self) -> HashTree {
157 labeled(self.key.as_ref(), pruned(self.value.root_hash()))
158 }
159
160 fn full_witness_tree<'a>(n: &'a NodeRef<K, V>, f: fn(&'a Node<K, V>) -> HashTree) -> HashTree {
161 match n {
162 None => empty(),
163 Some(n) => three_way_fork(
164 Self::full_witness_tree(&n.left, f),
165 f(n),
166 Self::full_witness_tree(&n.right, f),
167 ),
168 }
169 }
170
171 fn update_subtree_hash(&mut self) {
172 self.subtree_hash = self.compute_subtree_hash();
173 }
174
175 fn compute_subtree_hash(&self) -> Hash {
176 let h = self.data_hash();
177
178 match (self.left.as_ref(), self.right.as_ref()) {
179 (None, None) => h,
180 (Some(l), None) => fork_hash(&l.subtree_hash, &h),
181 (None, Some(r)) => fork_hash(&h, &r.subtree_hash),
182 (Some(l), Some(r)) => fork_hash(&l.subtree_hash, &fork_hash(&h, &r.subtree_hash)),
183 }
184 }
185}
186
187#[derive(PartialEq, Debug)]
188enum Visit {
189 Pre,
190 In,
191 Post,
192}
193
194#[derive(Debug)]
196pub struct Iter<'a, K, V> {
197 visit: Visit,
202 parents: Vec<&'a Node<K, V>>,
203}
204
205impl<'a, K, V> Iter<'a, K, V> {
206 fn step(&mut self) -> bool {
215 match self.parents.last() {
216 Some(tip) => {
217 match self.visit {
218 Visit::Pre => {
219 if let Some(l) = &tip.left {
220 self.parents.push(l);
221 } else {
222 self.visit = Visit::In;
223 }
224 }
225 Visit::In => {
226 if let Some(r) = &tip.right {
227 self.parents.push(r);
228 self.visit = Visit::Pre;
229 } else {
230 self.visit = Visit::Post;
231 }
232 }
233 Visit::Post => {
234 let tip = self.parents.pop().unwrap();
235 if let Some(parent) = self.parents.last() {
236 if parent
237 .left
238 .as_ref()
239 .map(|l| l.as_ref() as *const Node<K, V>)
240 == Some(tip as *const Node<K, V>)
241 {
242 self.visit = Visit::In;
243 }
244 }
245 }
246 }
247 true
248 }
249 None => false,
250 }
251 }
252}
253
254impl<'a, K, V> std::iter::Iterator for Iter<'a, K, V> {
255 type Item = (&'a K, &'a V);
256
257 fn next(&mut self) -> Option<Self::Item> {
258 while self.step() {
259 if self.visit == Visit::In {
260 return self.parents.last().map(|n| (&n.key, &n.value));
261 }
262 }
263 None
264 }
265}
266
267#[derive(Default, Clone)]
270pub struct RbTree<K, V> {
271 root: NodeRef<K, V>,
272}
273
274impl<K, V> PartialEq for RbTree<K, V>
275where
276 K: 'static + AsRef<[u8]> + PartialEq,
277 V: 'static + AsHashTree + PartialEq,
278{
279 fn eq(&self, other: &Self) -> bool {
280 self.iter().eq(other.iter())
281 }
282}
283
284impl<K, V> Eq for RbTree<K, V>
285where
286 K: 'static + AsRef<[u8]> + Eq,
287 V: 'static + AsHashTree + Eq,
288{
289}
290
291impl<K, V> PartialOrd for RbTree<K, V>
292where
293 K: 'static + AsRef<[u8]> + PartialOrd,
294 V: 'static + AsHashTree + PartialOrd,
295{
296 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
297 self.iter().partial_cmp(other.iter())
298 }
299}
300
301impl<K, V> Ord for RbTree<K, V>
302where
303 K: 'static + AsRef<[u8]> + Ord,
304 V: 'static + AsHashTree + Ord,
305{
306 fn cmp(&self, other: &Self) -> Ordering {
307 self.iter().cmp(other.iter())
308 }
309}
310
311impl<K, V> std::iter::FromIterator<(K, V)> for RbTree<K, V>
312where
313 K: 'static + AsRef<[u8]>,
314 V: 'static + AsHashTree,
315{
316 fn from_iter<T>(iter: T) -> Self
317 where
318 T: IntoIterator<Item = (K, V)>,
319 {
320 let mut t = RbTree::<K, V>::new();
321 for (k, v) in iter.into_iter() {
322 t.insert(k, v);
323 }
324 t
325 }
326}
327
328impl<K, V> std::fmt::Debug for RbTree<K, V>
329where
330 K: 'static + AsRef<[u8]> + std::fmt::Debug,
331 V: 'static + AsHashTree + std::fmt::Debug,
332{
333 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334 let mut list = f.debug_list();
335 for (k, v) in self.iter() {
336 list.entry(&format_args!(
337 "({}, {:#?})",
338 String::from_utf8_lossy(k.as_ref()),
339 v.as_hash_tree()
340 ));
341 }
342 list.finish()
343 }
344}
345
346impl<K, V> RbTree<K, V> {
347 pub const fn new() -> Self {
349 Self { root: None }
350 }
351
352 pub const fn is_empty(&self) -> bool {
354 self.root.is_none()
355 }
356}
357
358impl<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static> RbTree<K, V> {
359 pub fn get(&self, key: &[u8]) -> Option<&V> {
361 let mut root = self.root.as_ref();
362 while let Some(n) = root {
363 match key.cmp(n.key.as_ref()) {
364 Equal => return Some(&n.value),
365 Less => root = n.left.as_ref(),
366 Greater => root = n.right.as_ref(),
367 }
368 }
369 None
370 }
371
372 pub fn modify(&mut self, key: &[u8], f: impl FnOnce(&mut V)) {
374 fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
375 h: &mut NodeRef<K, V>,
376 k: &[u8],
377 f: impl FnOnce(&mut V),
378 ) {
379 if let Some(h) = h {
380 match k.as_ref().cmp(h.key.as_ref()) {
381 Equal => {
382 f(&mut h.value);
383 h.update_subtree_hash();
384 }
385 Less => {
386 go(&mut h.left, k, f);
387 h.update_subtree_hash();
388 }
389 Greater => {
390 go(&mut h.right, k, f);
391 h.update_subtree_hash();
392 }
393 }
394 }
395 }
396 go(&mut self.root, key, f)
397 }
398
399 fn range_witness<'a>(
400 &'a self,
401 left: Option<KeyBound<'a>>,
402 right: Option<KeyBound<'a>>,
403 f: fn(&'a Node<K, V>) -> HashTree,
404 ) -> HashTree {
405 match (left, right) {
406 (None, None) => Node::full_witness_tree(&self.root, f),
407 (Some(l), None) => self.witness_range_above(l, f),
408 (None, Some(r)) => self.witness_range_below(r, f),
409 (Some(l), Some(r)) => self.witness_range_between(l, r, f),
410 }
411 }
412
413 pub fn witness(&self, key: &[u8]) -> HashTree {
419 self.nested_witness(key, |v| v.as_hash_tree())
420 }
421
422 pub fn nested_witness<'a>(&'a self, key: &[u8], f: impl FnOnce(&'a V) -> HashTree) -> HashTree {
426 if let Some(t) = self.lookup_and_build_witness(key, f) {
427 return t;
428 }
429 self.range_witness(
430 self.lower_bound(key),
431 self.upper_bound(key),
432 Node::witness_tree,
433 )
434 }
435
436 pub fn keys(&self) -> HashTree {
440 Node::full_witness_tree(&self.root, Node::witness_tree)
441 }
442
443 pub fn key_range(&self, first: &[u8], last: &[u8]) -> HashTree {
447 self.range_witness(
448 self.lower_bound(first),
449 self.upper_bound(last),
450 Node::witness_tree,
451 )
452 }
453
454 pub fn value_range(&self, first: &[u8], last: &[u8]) -> HashTree {
457 self.range_witness(
458 self.lower_bound(first),
459 self.upper_bound(last),
460 Node::data_tree,
461 )
462 }
463
464 pub fn keys_with_prefix(&self, prefix: &[u8]) -> HashTree {
467 self.range_witness(
468 self.lower_bound(prefix),
469 self.right_prefix_neighbor(prefix),
470 Node::witness_tree,
471 )
472 }
473
474 pub fn iter(&self) -> Iter<'_, K, V> {
476 match &self.root {
477 None => Iter {
478 visit: Visit::Pre,
479 parents: vec![],
480 },
481 Some(n) => Iter {
482 visit: Visit::Pre,
483 parents: vec![n],
484 },
485 }
486 }
487
488 pub fn for_each<'a, F>(&'a self, mut f: F)
490 where
491 F: 'a + FnMut(&'a [u8], &'a V),
492 {
493 Node::visit(&self.root, &mut f)
494 }
495
496 fn witness_range_above<'a>(
497 &'a self,
498 lo: KeyBound<'a>,
499 f: fn(&'a Node<K, V>) -> HashTree,
500 ) -> HashTree {
501 fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
502 n: &'a NodeRef<K, V>,
503 lo: KeyBound<'a>,
504 f: fn(&'a Node<K, V>) -> HashTree,
505 ) -> HashTree {
506 match n {
507 None => empty(),
508 Some(n) => match n.key.as_ref().cmp(lo.as_ref()) {
509 Equal => three_way_fork(
510 n.left_hash_tree(),
511 match lo {
512 KeyBound::Exact(_) => f(n),
513 KeyBound::Neighbor(_) => n.witness_tree(),
514 },
515 Node::full_witness_tree(&n.right, f),
516 ),
517 Less => three_way_fork(
518 n.left_hash_tree(),
519 pruned(n.data_hash()),
520 go(&n.right, lo, f),
521 ),
522 Greater => three_way_fork(
523 go(&n.left, lo, f),
524 f(n),
525 Node::full_witness_tree(&n.right, f),
526 ),
527 },
528 }
529 }
530 go(&self.root, lo, f)
531 }
532
533 fn witness_range_below<'a>(
534 &'a self,
535 hi: KeyBound<'a>,
536 f: fn(&'a Node<K, V>) -> HashTree,
537 ) -> HashTree {
538 fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
539 n: &'a NodeRef<K, V>,
540 hi: KeyBound<'a>,
541 f: fn(&'a Node<K, V>) -> HashTree,
542 ) -> HashTree {
543 match n {
544 None => empty(),
545 Some(n) => match n.key.as_ref().cmp(hi.as_ref()) {
546 Equal => three_way_fork(
547 Node::full_witness_tree(&n.left, f),
548 match hi {
549 KeyBound::Exact(_) => f(n),
550 KeyBound::Neighbor(_) => n.witness_tree(),
551 },
552 n.right_hash_tree(),
553 ),
554 Greater => three_way_fork(
555 go(&n.left, hi, f),
556 pruned(n.data_hash()),
557 n.right_hash_tree(),
558 ),
559 Less => three_way_fork(
560 Node::full_witness_tree(&n.left, f),
561 f(n),
562 go(&n.right, hi, f),
563 ),
564 },
565 }
566 }
567 go(&self.root, hi, f)
568 }
569
570 fn witness_range_between<'a>(
571 &'a self,
572 lo: KeyBound<'a>,
573 hi: KeyBound<'a>,
574 f: fn(&'a Node<K, V>) -> HashTree,
575 ) -> HashTree {
576 debug_assert!(
577 lo.as_ref() <= hi.as_ref(),
578 "lo = {:?} > hi = {:?}",
579 lo.as_ref(),
580 hi.as_ref()
581 );
582 fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
583 n: &'a NodeRef<K, V>,
584 lo: KeyBound<'a>,
585 hi: KeyBound<'a>,
586 f: fn(&'a Node<K, V>) -> HashTree,
587 ) -> HashTree {
588 match n {
589 None => empty(),
590 Some(n) => {
591 let k = n.key.as_ref();
592 match (lo.as_ref().cmp(k), k.cmp(hi.as_ref())) {
593 (Less, Less) => {
594 three_way_fork(go(&n.left, lo, hi, f), f(n), go(&n.right, lo, hi, f))
595 }
596 (Equal, Equal) => three_way_fork(
597 n.left_hash_tree(),
598 match (lo, hi) {
599 (KeyBound::Exact(_), _) => f(n),
600 (_, KeyBound::Exact(_)) => f(n),
601 _ => n.witness_tree(),
602 },
603 n.right_hash_tree(),
604 ),
605 (_, Equal) => three_way_fork(
606 go(&n.left, lo, hi, f),
607 match hi {
608 KeyBound::Exact(_) => f(n),
609 KeyBound::Neighbor(_) => n.witness_tree(),
610 },
611 n.right_hash_tree(),
612 ),
613 (Equal, _) => three_way_fork(
614 n.left_hash_tree(),
615 match lo {
616 KeyBound::Exact(_) => f(n),
617 KeyBound::Neighbor(_) => n.witness_tree(),
618 },
619 go(&n.right, lo, hi, f),
620 ),
621 (Less, Greater) => three_way_fork(
622 go(&n.left, lo, hi, f),
623 pruned(n.data_hash()),
624 n.right_hash_tree(),
625 ),
626 (Greater, Less) => three_way_fork(
627 n.left_hash_tree(),
628 pruned(n.data_hash()),
629 go(&n.right, lo, hi, f),
630 ),
631 _ => pruned(n.subtree_hash),
632 }
633 }
634 }
635 }
636 go(&self.root, lo, hi, f)
637 }
638
639 fn lower_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
640 fn go<'a, K: 'static + AsRef<[u8]>, V>(
641 n: &'a NodeRef<K, V>,
642 key: &[u8],
643 ) -> Option<KeyBound<'a>> {
644 n.as_ref().and_then(|n| {
645 let node_key = n.key.as_ref();
646 match node_key.cmp(key) {
647 Less => go(&n.right, key).or(Some(KeyBound::Neighbor(node_key))),
648 Equal => Some(KeyBound::Exact(node_key)),
649 Greater => go(&n.left, key),
650 }
651 })
652 }
653 go(&self.root, key)
654 }
655
656 fn upper_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
657 fn go<'a, K: 'static + AsRef<[u8]>, V>(
658 n: &'a NodeRef<K, V>,
659 key: &[u8],
660 ) -> Option<KeyBound<'a>> {
661 n.as_ref().and_then(|n| {
662 let node_key = n.key.as_ref();
663 match node_key.cmp(key) {
664 Less => go(&n.right, key),
665 Equal => Some(KeyBound::Exact(node_key)),
666 Greater => go(&n.left, key).or(Some(KeyBound::Neighbor(node_key))),
667 }
668 })
669 }
670 go(&self.root, key)
671 }
672
673 fn right_prefix_neighbor(&self, prefix: &[u8]) -> Option<KeyBound<'_>> {
674 fn is_prefix_of(p: &[u8], x: &[u8]) -> bool {
675 if p.len() > x.len() {
676 return false;
677 }
678 &x[0..p.len()] == p
679 }
680 fn go<'a, K: 'static + AsRef<[u8]>, V>(
681 n: &'a NodeRef<K, V>,
682 prefix: &[u8],
683 ) -> Option<KeyBound<'a>> {
684 n.as_ref().and_then(|n| {
685 let node_key = n.key.as_ref();
686 match node_key.cmp(prefix) {
687 Greater if is_prefix_of(prefix, node_key) => go(&n.right, prefix),
688 Greater => go(&n.left, prefix).or(Some(KeyBound::Neighbor(node_key))),
689 Less | Equal => go(&n.right, prefix),
690 }
691 })
692 }
693 go(&self.root, prefix)
694 }
695
696 fn lookup_and_build_witness<'a>(
697 &'a self,
698 key: &[u8],
699 f: impl FnOnce(&'a V) -> HashTree,
700 ) -> Option<HashTree> {
701 fn go<'a, K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
702 n: &'a NodeRef<K, V>,
703 key: &[u8],
704 f: impl FnOnce(&'a V) -> HashTree,
705 ) -> Option<HashTree> {
706 n.as_ref().and_then(|n| match key.cmp(n.key.as_ref()) {
707 Equal => Some(three_way_fork(
708 n.left_hash_tree(),
709 n.subtree_with(f),
710 n.right_hash_tree(),
711 )),
712 Less => {
713 let subtree = go(&n.left, key, f)?;
714 Some(three_way_fork(
715 subtree,
716 pruned(n.data_hash()),
717 n.right_hash_tree(),
718 ))
719 }
720 Greater => {
721 let subtree = go(&n.right, key, f)?;
722 Some(three_way_fork(
723 n.left_hash_tree(),
724 pruned(n.data_hash()),
725 subtree,
726 ))
727 }
728 })
729 }
730 go(&self.root, key, f)
731 }
732
733 pub fn insert(&mut self, key: K, value: V) {
735 fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
736 h: NodeRef<K, V>,
737 k: K,
738 v: V,
739 ) -> Box<Node<K, V>> {
740 match h {
741 None => Node::new(k, v),
742 Some(mut h) => {
743 match k.as_ref().cmp(h.key.as_ref()) {
744 Equal => {
745 h.value = v;
746 }
747 Less => {
748 h.left = Some(go(h.left, k, v));
749 }
750 Greater => {
751 h.right = Some(go(h.right, k, v));
752 }
753 }
754 h.update_subtree_hash();
755 balance(h)
756 }
757 }
758 }
759 let mut root = go(self.root.take(), key, value);
760 root.color = Color::Black;
761 self.root = Some(root);
762
763 #[cfg(test)]
764 debug_assert!(
765 is_balanced(&self.root),
766 "the tree is not balanced:\n{:?}",
767 DebugView(&self.root)
768 );
769 }
770
771 pub fn delete(&mut self, key: &[u8]) {
773 fn move_red_left<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
774 mut h: Box<Node<K, V>>,
775 ) -> Box<Node<K, V>> {
776 flip_colors(&mut h);
777 if is_red(&h.right.as_ref().unwrap().left) {
778 h.right = Some(rotate_right(h.right.take().unwrap()));
779 h = rotate_left(h);
780 flip_colors(&mut h);
781 }
782 h
783 }
784
785 fn move_red_right<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
786 mut h: Box<Node<K, V>>,
787 ) -> Box<Node<K, V>> {
788 flip_colors(&mut h);
789 if is_red(&h.left.as_ref().unwrap().left) {
790 h = rotate_right(h);
791 flip_colors(&mut h);
792 }
793 h
794 }
795
796 #[inline]
797 fn min<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
798 mut h: &mut Box<Node<K, V>>,
799 ) -> &mut Box<Node<K, V>> {
800 while h.left.is_some() {
801 h = h.left.as_mut().unwrap();
802 }
803 h
804 }
805
806 fn delete_min<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
807 mut h: Box<Node<K, V>>,
808 ) -> NodeRef<K, V> {
809 if h.left.is_none() {
810 debug_assert!(h.right.is_none());
811 drop(h);
812 return None;
813 }
814 if !is_red(&h.left) && !is_red(&h.left.as_ref().unwrap().left) {
815 h = move_red_left(h);
816 }
817 h.left = delete_min(h.left.unwrap());
818 h.update_subtree_hash();
819 Some(balance(h))
820 }
821
822 fn go<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
823 mut h: Box<Node<K, V>>,
824 key: &[u8],
825 ) -> NodeRef<K, V> {
826 if key < h.key.as_ref() {
827 debug_assert!(h.left.is_some(), "the key must be present in the tree");
828 if !is_red(&h.left) && !is_red(&h.left.as_ref().unwrap().left) {
829 h = move_red_left(h);
830 }
831 h.left = go(h.left.take().unwrap(), key);
832 } else {
833 if is_red(&h.left) {
834 h = rotate_right(h);
835 }
836 if key == h.key.as_ref() && h.right.is_none() {
837 debug_assert!(h.left.is_none());
838 drop(h);
839 return None;
840 }
841
842 if !is_red(&h.right) && !is_red(&h.right.as_ref().unwrap().left) {
843 h = move_red_right(h);
844 }
845
846 if key == h.key.as_ref() {
847 let m = min(h.right.as_mut().unwrap());
848 std::mem::swap(&mut h.key, &mut m.key);
849 std::mem::swap(&mut h.value, &mut m.value);
850 h.right = delete_min(h.right.take().unwrap());
851 } else {
852 h.right = go(h.right.take().unwrap(), key);
853 }
854 }
855 h.update_subtree_hash();
856 Some(balance(h))
857 }
858
859 if self.get(key).is_none() {
860 return;
861 }
862
863 if !is_red(&self.root.as_ref().unwrap().left) && !is_red(&self.root.as_ref().unwrap().right)
864 {
865 self.root.as_mut().unwrap().color = Color::Red;
866 }
867 self.root = go(self.root.take().unwrap(), key);
868 if let Some(n) = self.root.as_mut() {
869 n.color = Color::Black;
870 }
871
872 #[cfg(test)]
873 debug_assert!(
874 is_balanced(&self.root),
875 "unbalanced map: {:?}",
876 DebugView(&self.root)
877 );
878
879 debug_assert!(self.get(key).is_none());
880 }
881}
882
883fn three_way_fork(l: HashTree, m: HashTree, r: HashTree) -> HashTree {
884 match (l.root, m.root, r.root) {
885 (HashTreeNode::Empty(), m, HashTreeNode::Empty()) => HashTree { root: m },
886 (l, m, HashTreeNode::Empty()) => fork(HashTree { root: l }, HashTree { root: m }),
887 (HashTreeNode::Empty(), m, r) => fork(HashTree { root: m }, HashTree { root: r }),
888 (HashTreeNode::Pruned(lhash), HashTreeNode::Pruned(mhash), HashTreeNode::Pruned(rhash)) => {
889 pruned(fork_hash(&lhash, &fork_hash(&mhash, &rhash)))
890 }
891 (l, HashTreeNode::Pruned(mhash), HashTreeNode::Pruned(rhash)) => {
892 fork(HashTree { root: l }, pruned(fork_hash(&mhash, &rhash)))
893 }
894 (l, m, r) => fork(
895 HashTree { root: l },
896 fork(HashTree { root: m }, HashTree { root: r }),
897 ),
898 }
899}
900
901fn is_red<K, V>(x: &NodeRef<K, V>) -> bool {
903 x.as_ref().map(|h| h.color == Color::Red).unwrap_or(false)
904}
905
906fn balance<K: AsRef<[u8]> + 'static, V: AsHashTree + 'static>(
907 mut h: Box<Node<K, V>>,
908) -> Box<Node<K, V>> {
909 if is_red(&h.right) && !is_red(&h.left) {
910 h = rotate_left(h);
911 }
912 if is_red(&h.left) && is_red(&h.left.as_ref().unwrap().left) {
913 h = rotate_right(h);
914 }
915 if is_red(&h.left) && is_red(&h.right) {
916 flip_colors(&mut h)
917 }
918 h
919}
920
921fn rotate_right<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
923 mut h: Box<Node<K, V>>,
924) -> Box<Node<K, V>> {
925 debug_assert!(is_red(&h.left));
926
927 let mut x = h.left.take().unwrap();
928 h.left = x.right.take();
929 h.update_subtree_hash();
930
931 x.right = Some(h);
932 x.color = x.right.as_ref().unwrap().color;
933 x.right.as_mut().unwrap().color = Color::Red;
934 x.update_subtree_hash();
935
936 x
937}
938
939fn rotate_left<K: 'static + AsRef<[u8]>, V: AsHashTree + 'static>(
940 mut h: Box<Node<K, V>>,
941) -> Box<Node<K, V>> {
942 debug_assert!(is_red(&h.right));
943
944 let mut x = h.right.take().unwrap();
945 h.right = x.left.take();
946 h.update_subtree_hash();
947
948 x.left = Some(h);
949 x.color = x.left.as_ref().unwrap().color;
950 x.left.as_mut().unwrap().color = Color::Red;
951 x.update_subtree_hash();
952
953 x
954}
955
956fn flip_colors<K, V>(h: &mut Box<Node<K, V>>) {
957 h.color.flip_assign();
958 h.left.as_mut().unwrap().color.flip_assign();
959 h.right.as_mut().unwrap().color.flip_assign();
960}
961
/// Test helper: checks the red-black invariants of the tree under `root`.
///
/// Every path from the root to an empty link must cross the same number of
/// black nodes. Returns `false` if it does not. A red node with a red child
/// triggers an assertion failure instead.
#[cfg(test)]
fn is_balanced<K, V>(root: &NodeRef<K, V>) -> bool {
    // Verifies that every path below `node` has exactly `remaining` black nodes.
    fn check<K, V>(node: &NodeRef<K, V>, remaining: usize) -> bool {
        let n = match node {
            None => return remaining == 0,
            Some(n) => n,
        };
        let remaining = if is_red(node) {
            // A red node must not have a red child.
            assert!(!is_red(&n.left));
            assert!(!is_red(&n.right));
            remaining
        } else {
            debug_assert!(remaining > 0);
            remaining - 1
        };
        check(&n.left, remaining) && check(&n.right, remaining)
    }

    // The black height of the leftmost path is the reference every other path must match.
    let expected = std::iter::successors(root.as_deref(), |n| n.left.as_deref())
        .filter(|n| n.color != Color::Red)
        .count();
    check(root, expected)
}
990
/// Test helper that renders a subtree as an indented outline.
///
/// Each node prints on its own line with its color (`R`/`B`) and key bytes.
/// Children are indented two more spaces than their parent.
#[cfg(test)]
struct DebugView<'a, K, V>(&'a NodeRef<K, V>);

#[cfg(test)]
impl<'a, K: AsRef<[u8]>, V> std::fmt::Debug for DebugView<'a, K, V> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Writes `node` and its descendants (pre-order) at indentation `depth`.
        fn render<K: AsRef<[u8]>, V>(
            f: &mut std::fmt::Formatter<'_>,
            node: &NodeRef<K, V>,
            depth: usize,
        ) -> std::fmt::Result {
            let h = match node {
                None => return writeln!(f, "{:width$}[B] <null>", "", width = depth),
                Some(h) => h,
            };
            let tag = if is_red(node) { "R" } else { "B" };
            writeln!(
                f,
                "{:width$}[{}] {:?}",
                "",
                tag,
                h.key.as_ref(),
                width = depth
            )?;
            render(f, &h.left, depth + 2)?;
            render(f, &h.right, depth + 2)
        }
        render(f, self.0, 0)
    }
}
1021
1022#[cfg(test)]
1023mod tests;