fuel_merkle/binary/
root_calculator.rs1use core::convert::Infallible;
2
3use crate::{
4 binary::{
5 empty_sum,
6 Node,
7 },
8 common::Bytes32,
9};
10
11use crate::alloc::borrow::ToOwned;
12use alloc::vec::Vec;
13
/// Failure modes of `MerkleRootCalculator::push_with_callback`.
#[derive(Debug)]
pub(crate) enum NodeStackPushError<E> {
    /// A parent position could not be computed while merging two subtrees;
    /// reported as the tree having grown too large.
    TooLarge,
    /// The user-supplied `node_created` callback returned an error.
    Callback(E),
}
19
20#[derive(Default, Debug, Clone, PartialEq)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct MerkleRootCalculator {
23 stack: Vec<Node>,
24}
25
26impl MerkleRootCalculator {
27 pub fn new() -> Self {
28 Self { stack: Vec::new() }
29 }
30
31 pub fn new_with_stack(stack: Vec<Node>) -> Self {
32 Self { stack }
33 }
34
35 pub fn new_from_existing_leaves(leaf_hashes: impl Iterator<Item = Bytes32>) -> Self {
36 let mut calculator = Self::new();
37 leaf_hashes.for_each(|leaf| {
38 calculator
39 .push_with_callback::<_, Infallible>(
40 Node::create_leaf_with_hash(0, leaf)
41 .expect("Zero is a valid index for a leaf; qed"),
42 |_| Ok(()),
43 )
44 .expect("Tree too large; qed");
45 });
46 calculator
47 }
48
49 pub fn clear(&mut self) {
50 self.stack.clear();
51 }
52
53 pub(crate) fn push_with_callback<F, E>(
56 &mut self,
57 node: Node,
58 mut node_created: F,
59 ) -> Result<(), NodeStackPushError<E>>
60 where
61 F: FnMut(&Node) -> Result<(), E>,
62 {
63 node_created(&node).map_err(NodeStackPushError::Callback)?;
64 self.stack.push(node);
65
66 #[allow(clippy::arithmetic_side_effects)] while self.stack.len() > 1 {
69 let rhs = &self.stack[self.stack.len() - 1];
70 let lhs = &self.stack[self.stack.len() - 2];
71 if rhs.height() != lhs.height() {
72 break;
73 }
74
75 let parent_pos = lhs
76 .position()
77 .parent()
78 .map_err(|_| NodeStackPushError::TooLarge)?;
79 let new = Node::create_node(parent_pos, lhs, rhs);
80 node_created(&new).map_err(NodeStackPushError::Callback)?;
81 let _ = self.stack.pop();
82 let _ = self.stack.pop();
83 self.stack.push(new);
84 }
85
86 Ok(())
87 }
88
89 pub fn push(&mut self, data: &[u8]) {
93 let node = Node::create_leaf(0, data).expect("Zero is a valid index for a leaf");
94 self.push_with_callback::<_, Infallible>(node, |_| Ok(()))
95 .expect("Tree too large");
96 }
97
98 pub fn root(mut self) -> Bytes32 {
99 if self.stack.is_empty() {
100 return empty_sum().to_owned()
101 }
102 while self.stack.len() > 1 {
103 let right_child = self.stack.pop().expect("Checked in loop bound");
104 let left_child = self.stack.pop().expect("Checked in loop bound");
105 let merged_pos = left_child
106 .position()
107 .parent()
108 .expect("Left child has no parent");
109 let merged_node = Node::create_node(merged_pos, &left_child, &right_child);
110 self.stack.push(merged_node);
111 }
112 self.stack.pop().unwrap().hash().to_owned()
113 }
114
115 pub fn root_from_iterator<I: Iterator<Item = T>, T: AsRef<[u8]>>(
116 self,
117 iterator: I,
118 ) -> Bytes32 {
119 let mut calculator = MerkleRootCalculator::new();
120
121 for data in iterator {
122 calculator.push(data.as_ref());
123 }
124
125 calculator.root()
126 }
127
128 pub fn stack(&self) -> &Vec<Node> {
129 &self.stack
130 }
131}
132
#[cfg(test)]
mod test {
    use super::*;
    use crate::binary::{
        in_memory::MerkleTree,
        leaf_sum,
    };
    use fuel_merkle_test_helpers::TEST_DATA;
    // NOTE(review): `serde_json` is only exercised under the `serde` feature;
    // this import presumably keeps the dev-dependency "used" otherwise.
    use serde_json as _;

    #[test]
    fn root_returns_the_empty_root_for_0_leaves() {
        let tree = MerkleTree::new();
        let calculate_root = MerkleRootCalculator::new();

        assert_eq!(tree.root(), calculate_root.root());
    }

    #[test]
    fn root_returns_the_merkle_root_for_1_leaf() {
        let mut tree = MerkleTree::new();
        let mut calculate_root = MerkleRootCalculator::new();

        let data = &TEST_DATA[0..1];
        for datum in data.iter() {
            tree.push(datum);
            calculate_root.push(datum);
        }

        assert_eq!(tree.root(), calculate_root.root());
    }

    #[test]
    fn root_returns_the_merkle_root_for_7_leaves() {
        let mut tree = MerkleTree::new();
        let mut calculate_root = MerkleRootCalculator::new();

        let data = &TEST_DATA[0..7];
        for datum in data.iter() {
            tree.push(datum);
            calculate_root.push(datum);
        }

        assert_eq!(tree.root(), calculate_root.root());
    }

    // The test name matches the number of leaves actually pushed (10_000).
    #[test]
    fn root_returns_the_merkle_root_for_10000_leaves() {
        let mut tree = MerkleTree::new();
        let mut calculate_root = MerkleRootCalculator::new();

        for value in 0..10000u64 {
            let data = value.to_le_bytes();
            tree.push(&data);
            calculate_root.push(&data);
        }

        assert_eq!(tree.root(), calculate_root.root());
    }

    #[test]
    fn root_returns_the_merkle_root_from_iterator() {
        let mut tree = MerkleTree::new();
        let calculate_root = MerkleRootCalculator::new();

        let data = &TEST_DATA[0..7];
        for datum in data.iter() {
            tree.push(datum);
        }

        let root = calculate_root.root_from_iterator(data.iter());

        assert_eq!(tree.root(), root);
    }

    #[test]
    fn root_returns_the_merkle_root_correct_new_existing() {
        let calculate_root = MerkleRootCalculator::new();
        let data = &TEST_DATA[0..7];
        let root = calculate_root.root_from_iterator(data.iter());

        let new_calculate_root = MerkleRootCalculator::new_from_existing_leaves(
            data.iter().map(|d| leaf_sum(d)),
        );
        assert_eq!(new_calculate_root.root(), root);
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_serialize_deserialize() {
        let mut calculator = MerkleRootCalculator::new();

        let data = &TEST_DATA[0..7];
        for datum in data.iter() {
            calculator.push(datum);
        }
        let json = serde_json::to_string(&calculator).unwrap();

        let deserialized_calculator: MerkleRootCalculator =
            serde_json::from_str(&json).expect("Unable to read from str");

        assert_eq!(calculator.root(), deserialized_calculator.root());
    }
}