1use crate::{collections::VecDeque, vec, vec::Vec};
2use core::cmp::Reverse;
3use core::marker::PhantomData;
4
/// Defines how two child nodes of the Merkle tree are combined into their parent.
pub trait Merge {
    /// Type of every tree node: leaves as well as merged internal nodes.
    type Item;
    /// Returns the parent of `left` and `right`. The tree builders and proof
    /// verification in this module always pass the left child first, so an
    /// implementation may be order-sensitive.
    fn merge(left: &Self::Item, right: &Self::Item) -> Self::Item;
}
9
10pub struct MerkleTree<T, M> {
11 nodes: Vec<T>,
12 merge: PhantomData<M>,
13}
14
15impl<T, M> MerkleTree<T, M>
16where
17 T: Ord + Default + Clone,
18 M: Merge<Item = T>,
19{
20 pub fn build_proof(&self, leaf_indices: &[u32]) -> Option<MerkleProof<T, M>> {
22 if self.nodes.is_empty() || leaf_indices.is_empty() {
23 return None;
24 }
25
26 let leaves_count = ((self.nodes.len() >> 1) + 1) as u32;
27 let mut indices = leaf_indices
28 .iter()
29 .map(|i| leaves_count + i - 1)
30 .collect::<Vec<_>>();
31
32 indices.sort_by_key(|i| Reverse(*i));
33 if indices[0] >= (leaves_count << 1) - 1 {
34 return None;
35 }
36
37 let mut lemmas = Vec::new();
38 let mut queue: VecDeque<u32> = indices.clone().into();
39
40 while let Some(index) = queue.pop_front() {
41 if index == 0 {
42 assert!(queue.is_empty());
43 break;
44 }
45 let sibling = index.sibling();
46 if Some(&sibling) == queue.front() {
47 queue.pop_front();
48 } else {
49 lemmas.push(self.nodes[sibling as usize].clone());
50 }
51
52 let parent = index.parent();
53 if parent != 0 {
54 queue.push_back(parent);
55 }
56 }
57
58 indices.sort_by_key(|i| &self.nodes[*i as usize]);
59
60 Some(MerkleProof {
61 indices,
62 lemmas,
63 merge: PhantomData,
64 })
65 }
66
67 pub fn root(&self) -> T {
68 if self.nodes.is_empty() {
69 T::default()
70 } else {
71 self.nodes[0].clone()
72 }
73 }
74
75 pub fn nodes(&self) -> &[T] {
76 &self.nodes
77 }
78}
79
80pub struct MerkleProof<T, M> {
81 indices: Vec<u32>,
82 lemmas: Vec<T>,
83 merge: PhantomData<M>,
84}
85
86impl<T, M> MerkleProof<T, M>
87where
88 T: Ord + Default + Clone,
89 M: Merge<Item = T>,
90{
91 pub fn new(indices: Vec<u32>, lemmas: Vec<T>) -> Self {
92 Self {
93 indices,
94 lemmas,
95 merge: PhantomData,
96 }
97 }
98
99 pub fn root(&self, leaves: &[T]) -> Option<T> {
100 if leaves.len() != self.indices.len() || leaves.is_empty() {
101 return None;
102 }
103
104 let mut leaves = leaves.to_vec();
105 leaves.sort();
106
107 let mut pre = self
108 .indices
109 .iter()
110 .zip(leaves.into_iter())
111 .map(|(i, l)| (*i, l))
112 .collect::<Vec<_>>();
113 pre.sort_by_key(|i| Reverse(i.0));
114
115 let mut queue: VecDeque<(u32, T)> = pre.into();
116 let mut lemmas_iter = self.lemmas.iter();
117
118 while let Some((index, node)) = queue.pop_front() {
119 if index == 0 {
120 if lemmas_iter.next().is_none() && queue.is_empty() {
122 return Some(node);
123 } else {
124 return None;
125 }
126 }
127
128 if let Some(sibling) = match queue.front() {
129 Some((front, _)) if *front == index.sibling() => queue.pop_front().map(|i| i.1),
130 _ => lemmas_iter.next().cloned(),
131 } {
132 let parent_node = if index.is_left() {
133 M::merge(&node, &sibling)
134 } else {
135 M::merge(&sibling, &node)
136 };
137
138 queue.push_back((index.parent(), parent_node));
139 }
140 }
141
142 None
143 }
144
145 pub fn verify(&self, root: &T, leaves: &[T]) -> bool {
146 match self.root(leaves) {
147 Some(r) => &r == root,
148 _ => false,
149 }
150 }
151
152 pub fn indices(&self) -> &[u32] {
153 &self.indices
154 }
155
156 pub fn lemmas(&self) -> &[T] {
157 &self.lemmas
158 }
159}
160
161#[derive(Default)]
162pub struct CBMT<T, M> {
163 data_type: PhantomData<T>,
164 merge: PhantomData<M>,
165}
166
167impl<T, M> CBMT<T, M>
168where
169 T: Ord + Default + Clone,
170 M: Merge<Item = T>,
171{
172 pub fn build_merkle_root(leaves: &[T]) -> T {
173 if leaves.is_empty() {
174 return T::default();
175 }
176
177 let mut queue = VecDeque::with_capacity((leaves.len() + 1) >> 1);
178
179 let mut iter = leaves.rchunks_exact(2);
180 while let Some([leaf1, leaf2]) = iter.next() {
181 queue.push_back(M::merge(leaf1, leaf2))
182 }
183 if let [leaf] = iter.remainder() {
184 queue.push_front(leaf.clone())
185 }
186
187 while queue.len() > 1 {
188 let right = queue.pop_front().unwrap();
189 let left = queue.pop_front().unwrap();
190 queue.push_back(M::merge(&left, &right));
191 }
192
193 queue.pop_front().unwrap()
194 }
195
196 pub fn build_merkle_tree(leaves: &[T]) -> MerkleTree<T, M> {
197 let len = leaves.len();
198 if len > 0 {
199 let mut nodes = vec![T::default(); len - 1];
200 nodes.extend_from_slice(leaves);
201
202 (0..len - 1)
203 .rev()
204 .for_each(|i| nodes[i] = M::merge(&nodes[(i << 1) + 1], &nodes[(i << 1) + 2]));
205
206 MerkleTree {
207 nodes,
208 merge: PhantomData,
209 }
210 } else {
211 MerkleTree {
212 nodes: vec![],
213 merge: PhantomData,
214 }
215 }
216 }
217
218 pub fn build_merkle_proof(leaves: &[T], leaf_indices: &[u32]) -> Option<MerkleProof<T, M>> {
219 Self::build_merkle_tree(leaves).build_proof(leaf_indices)
220 }
221
222 pub fn retrieve_leaves(leaves: &[T], proof: &MerkleProof<T, M>) -> Option<Vec<T>> {
224 if leaves.is_empty() || proof.indices().is_empty() {
225 return None;
226 }
227
228 let leaves_count = leaves.len() as u32;
229 let valid_indices_range = leaves_count - 1..(leaves_count << 1) - 1;
230 if proof
231 .indices()
232 .iter()
233 .all(|index| valid_indices_range.contains(index))
234 {
235 Some(
236 proof
237 .indices()
238 .iter()
239 .map(|index| leaves[(index + 1 - leaves_count) as usize].clone())
240 .collect(),
241 )
242 } else {
243 None
244 }
245 }
246}
247
/// Index arithmetic for a binary tree stored in a flat array, where node `i`
/// has its left child at `2i + 1` and its right child at `2i + 2`.
trait TreeIndex {
    /// Index of the other child of this node's parent (the root maps to itself).
    fn sibling(&self) -> Self;
    /// Index of this node's parent (the root maps to itself).
    fn parent(&self) -> Self;
    /// Whether this node is a left child, i.e. has an odd index.
    fn is_left(&self) -> bool;
}

/// Implements [`TreeIndex`] for the given integer type (invoked for `u32` and `usize`).
macro_rules! impl_tree_index {
    ($t: ty) => {
        impl TreeIndex for $t {
            fn sibling(&self) -> $t {
                match *self {
                    0 => 0,
                    // Left children have odd indices; their sibling follows them.
                    left if left % 2 == 1 => left + 1,
                    right => right - 1,
                }
            }

            fn parent(&self) -> $t {
                // Saturating keeps the root mapped to itself instead of underflowing.
                self.saturating_sub(1) / 2
            }

            fn is_left(&self) -> bool {
                *self % 2 == 1
            }
        }
    };
}

impl_tree_index!(u32);
impl_tree_index!(usize);
282
// Unit tests with hand-computed expectations plus property tests (proptest)
// comparing the root, tree and proof code paths against each other.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::collection::vec;
    use proptest::num::i32;
    use proptest::prelude::*;
    use proptest::proptest;
    use proptest::sample::subsequence;

    // Order-sensitive merge (right - left, wrapping), so swapping a left and a
    // right child changes the result and shows up in the expected values below.
    struct MergeI32 {}

    impl Merge for MergeI32 {
        type Item = i32;
        fn merge(left: &Self::Item, right: &Self::Item) -> Self::Item {
            right.wrapping_sub(*left)
        }
    }

    type CBMTI32 = CBMT<i32, MergeI32>;
    type CBMTI32Proof = MerkleProof<i32, MergeI32>;

    // An empty leaf set yields no nodes and the default root.
    #[test]
    fn build_empty() {
        let leaves = vec![];
        let tree = CBMTI32::build_merkle_tree(&leaves);
        assert!(tree.nodes().is_empty());
        assert_eq!(tree.root(), i32::default());
    }

    // A single leaf is also the root.
    #[test]
    fn build_one() {
        let leaves = vec![1i32];
        let tree = CBMTI32::build_merkle_tree(&leaves);
        assert_eq!(vec![1], tree.nodes());
    }

    #[test]
    fn build_two() {
        let leaves = vec![1i32, 2];
        let tree = CBMTI32::build_merkle_tree(&leaves);
        assert_eq!(vec![1, 1, 2], tree.nodes());
    }

    // Non-power-of-two leaf count: internal nodes first, then the five leaves.
    #[test]
    fn build_five() {
        let leaves = vec![2i32, 3, 5, 7, 11];
        let tree = CBMTI32::build_merkle_tree(&leaves);
        assert_eq!(vec![4, -2, 2, 4, 2, 3, 5, 7, 11], tree.nodes());
    }

    #[test]
    fn build_root_directly() {
        let leaves = vec![2i32, 3, 5, 7, 11];
        assert_eq!(4, CBMTI32::build_merkle_root(&leaves));
    }

    // A proof rebuilt from its exported indices and lemmas still verifies.
    #[test]
    fn rebuild_proof() {
        let leaves = vec![2i32, 3, 5, 7, 11];
        let tree = CBMTI32::build_merkle_tree(&leaves);
        let root = tree.root();

        let proof = tree.build_proof(&[0, 3]).unwrap();
        let lemmas = proof.lemmas();
        let indices = proof.indices();

        // `indices` are node indices, so the leaves are read from the node array.
        let needed_leaves: Vec<i32> = indices
            .iter()
            .map(|i| tree.nodes()[*i as usize].clone())
            .collect();
        let rebuild_proof = CBMTI32Proof::new(indices.to_vec(), lemmas.to_vec());
        assert_eq!(rebuild_proof.verify(&root, &needed_leaves), true);
        assert_eq!(root, rebuild_proof.root(&needed_leaves).unwrap());
    }

    // The direct root computation must agree with the full tree's root.
    fn _build_root_is_same_as_tree_root(leaves: Vec<i32>) {
        let root = CBMTI32::build_merkle_root(&leaves);
        let tree = CBMTI32::build_merkle_tree(&leaves);
        assert_eq!(root, tree.root());
    }

    proptest! {
        #[test]
        fn build_root_is_same_as_tree_root(leaves in vec(i32::ANY, 0..1000)) {
            _build_root_is_same_as_tree_root(leaves);
        }
    }

    // Lemmas are emitted in consumption order; also covers the one-leaf tree.
    #[test]
    fn build_proof() {
        let leaves = vec![2i32, 3, 5, 7, 11, 13];
        let leaf_indices = vec![0u32, 5u32];
        let proof_leaves = leaf_indices
            .iter()
            .map(|i| leaves[*i as usize].clone())
            .collect::<Vec<_>>();
        let proof = CBMTI32::build_merkle_proof(&leaves, &leaf_indices).unwrap();

        assert_eq!(vec![11, 3, 2], proof.lemmas);
        assert_eq!(Some(1), proof.root(&proof_leaves));

        let leaves = vec![2i32];
        let leaf_indices = vec![0u32];
        let proof_leaves = leaf_indices
            .iter()
            .map(|i| leaves[*i as usize].clone())
            .collect::<Vec<_>>();
        let proof = CBMTI32::build_merkle_proof(&leaves, &leaf_indices).unwrap();
        assert!(proof.lemmas.is_empty());
        assert_eq!(Some(2), proof.root(&proof_leaves));
    }

    // The root recomputed from any non-empty subset of leaves must equal the tree root.
    fn _tree_root_is_same_as_proof_root(leaves: Vec<i32>, leaf_indices: Vec<u32>) {
        let proof_leaves = leaf_indices
            .iter()
            .map(|i| leaves[*i as usize].clone())
            .collect::<Vec<_>>();

        let proof = CBMTI32::build_merkle_proof(&leaves, &leaf_indices).unwrap();
        let root = CBMTI32::build_merkle_root(&leaves);
        assert_eq!(root, proof.root(&proof_leaves).unwrap());
    }

    // Generates a leaf vector plus a non-empty subsequence of its positions.
    proptest! {
        #[test]
        fn tree_root_is_same_as_proof_root(input in vec(i32::ANY, 2..1000)
            .prop_flat_map(|leaves| (Just(leaves.clone()), subsequence((0..leaves.len() as u32).collect::<Vec<u32>>(), 1..leaves.len())))
        ) {
            _tree_root_is_same_as_proof_root(input.0, input.1);
        }
    }

    // For six leaves, the leaf node indices are 5..=10: index 4 is an internal
    // node and 11 is past the end, so both must be rejected.
    #[test]
    fn verify_retrieve_leaves() {
        let leaves = vec![2i32, 3, 5, 7, 11, 13];
        let leaf_indices = vec![0u32, 3];
        let mut proof = CBMTI32::build_merkle_proof(&leaves, &leaf_indices).unwrap();
        let retrieved_leaves = CBMTI32::retrieve_leaves(&leaves, &proof);
        assert_eq!(Some(vec![2, 7]), retrieved_leaves);
        assert_eq!(
            proof.root(&retrieved_leaves.unwrap()).unwrap(),
            CBMTI32::build_merkle_root(&leaves)
        );

        proof.indices = vec![];
        assert_eq!(None, CBMTI32::retrieve_leaves(&leaves, &proof));

        proof.indices = vec![4];
        assert_eq!(None, CBMTI32::retrieve_leaves(&leaves, &proof));

        proof.indices = vec![11];
        assert_eq!(None, CBMTI32::retrieve_leaves(&leaves, &proof));
    }
}
439}