fuel_merkle/binary/
root_calculator.rs

1use core::convert::Infallible;
2
3use crate::{
4    binary::{
5        empty_sum,
6        Node,
7    },
8    common::Bytes32,
9};
10
11use crate::alloc::borrow::ToOwned;
12use alloc::vec::Vec;
13
/// Error returned by [`MerkleRootCalculator::push_with_callback`].
#[derive(Debug)]
pub(crate) enum NodeStackPushError<E> {
    /// The caller-supplied `node_created` callback returned an error.
    Callback(E),
    /// The parent position of a node on the stack could not be computed.
    TooLarge,
}
19
20#[derive(Default, Debug, Clone, PartialEq)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct MerkleRootCalculator {
23    stack: Vec<Node>,
24}
25
26impl MerkleRootCalculator {
27    pub fn new() -> Self {
28        Self { stack: Vec::new() }
29    }
30
31    pub fn new_with_stack(stack: Vec<Node>) -> Self {
32        Self { stack }
33    }
34
35    pub fn clear(&mut self) {
36        self.stack.clear();
37    }
38
39    /// Push a leaf to stack of nodes, propagating changes through the tree.
40    /// Calls `node_created` for each new node created, stopping on first error.
41    pub(crate) fn push_with_callback<F, E>(
42        &mut self,
43        node: Node,
44        mut node_created: F,
45    ) -> Result<(), NodeStackPushError<E>>
46    where
47        F: FnMut(&Node) -> Result<(), E>,
48    {
49        node_created(&node).map_err(NodeStackPushError::Callback)?;
50        self.stack.push(node);
51
52        // Propagate changes through the tree.
53        #[allow(clippy::arithmetic_side_effects)] // ensured by loop condition
54        while self.stack.len() > 1 {
55            let rhs = &self.stack[self.stack.len() - 1];
56            let lhs = &self.stack[self.stack.len() - 2];
57            if rhs.height() != lhs.height() {
58                break;
59            }
60
61            let parent_pos = lhs
62                .position()
63                .parent()
64                .map_err(|_| NodeStackPushError::TooLarge)?;
65            let new = Node::create_node(parent_pos, lhs, rhs);
66            node_created(&new).map_err(NodeStackPushError::Callback)?;
67            let _ = self.stack.pop();
68            let _ = self.stack.pop();
69            self.stack.push(new);
70        }
71
72        Ok(())
73    }
74
75    /// Push a new leaf node.
76    /// Panics if the tree would be too large to compute the root for.
77    /// In practice this never occurs, as you'd run out of memory first.
78    pub fn push(&mut self, data: &[u8]) {
79        let node = Node::create_leaf(0, data).expect("Zero is a valid index for a leaf");
80        self.push_with_callback::<_, Infallible>(node, |_| Ok(()))
81            .expect("Tree too large");
82    }
83
84    pub fn root(mut self) -> Bytes32 {
85        if self.stack.is_empty() {
86            return empty_sum().to_owned()
87        }
88        while self.stack.len() > 1 {
89            let right_child = self.stack.pop().expect("Checked in loop bound");
90            let left_child = self.stack.pop().expect("Checked in loop bound");
91            let merged_pos = left_child
92                .position()
93                .parent()
94                .expect("Left child has no parent");
95            let merged_node = Node::create_node(merged_pos, &left_child, &right_child);
96            self.stack.push(merged_node);
97        }
98        self.stack.pop().unwrap().hash().to_owned()
99    }
100
101    pub fn root_from_iterator<I: Iterator<Item = T>, T: AsRef<[u8]>>(
102        self,
103        iterator: I,
104    ) -> Bytes32 {
105        let mut calculator = MerkleRootCalculator::new();
106
107        for data in iterator {
108            calculator.push(data.as_ref());
109        }
110
111        calculator.root()
112    }
113
114    pub fn stack(&self) -> &Vec<Node> {
115        &self.stack
116    }
117}
118
#[cfg(test)]
mod test {
    use super::*;
    use crate::binary::in_memory::MerkleTree;
    use fuel_merkle_test_helpers::TEST_DATA;
    // Referenced so the dependency counts as used even when the `serde`
    // feature (and with it the serialization test) is disabled.
    use serde_json as _;

    /// With no leaves, the calculator must agree with the in-memory tree.
    #[test]
    fn root_returns_the_empty_root_for_0_leaves() {
        let tree = MerkleTree::new();
        let calculate_root = MerkleRootCalculator::new();

        assert_eq!(tree.root(), calculate_root.root());
    }

    /// A single leaf: no merge happens on push.
    #[test]
    fn root_returns_the_merkle_root_for_1_leaf() {
        let mut tree = MerkleTree::new();
        let mut calculate_root = MerkleRootCalculator::new();

        let data = &TEST_DATA[0..1]; // 1 leaf
        for datum in data.iter() {
            tree.push(datum);
            calculate_root.push(datum);
        }

        assert_eq!(tree.root(), calculate_root.root());
    }

    /// Seven leaves: not a power of two, so `root` has to fold several
    /// nodes left on the stack.
    #[test]
    fn root_returns_the_merkle_root_for_7_leaves() {
        let mut tree = MerkleTree::new();
        let mut calculate_root = MerkleRootCalculator::new();

        let data = &TEST_DATA[0..7];
        for datum in data.iter() {
            tree.push(datum);
            calculate_root.push(datum);
        }

        assert_eq!(tree.root(), calculate_root.root());
    }

    /// Larger input; the test name now matches the number of leaves pushed.
    #[test]
    fn root_returns_the_merkle_root_for_10000_leaves() {
        let mut tree = MerkleTree::new();
        let mut calculate_root = MerkleRootCalculator::new();

        for value in 0..10_000u64 {
            let data = value.to_le_bytes();
            tree.push(&data);
            calculate_root.push(&data);
        }

        assert_eq!(tree.root(), calculate_root.root());
    }

    /// `root_from_iterator` must match pushing the same leaves one by one
    /// into the in-memory tree.
    #[test]
    fn root_returns_the_merkle_root_from_iterator() {
        let mut tree = MerkleTree::new();
        let calculate_root = MerkleRootCalculator::new();

        let data = &TEST_DATA[0..7];
        for datum in data.iter() {
            tree.push(datum);
        }

        let root = calculate_root.root_from_iterator(data.iter());

        assert_eq!(tree.root(), root);
    }

    /// A JSON round trip must preserve the state needed to compute the root.
    #[test]
    #[cfg(feature = "serde")]
    fn test_serialize_deserialize() {
        let mut calculator = MerkleRootCalculator::new();

        let data = &TEST_DATA[0..7];
        for datum in data.iter() {
            calculator.push(datum);
        }
        let json = serde_json::to_string(&calculator).unwrap();

        let deserialized_calculator: MerkleRootCalculator =
            serde_json::from_str(&json).expect("Unable to read from str");

        assert_eq!(calculator.root(), deserialized_calculator.root());
    }
}