// merkle_cbt/merkle_tree.rs

1use crate::{collections::VecDeque, vec, vec::Vec};
2use core::cmp::Reverse;
3use core::marker::PhantomData;
4
/// Strategy for combining two sibling nodes of a Merkle tree into their parent.
///
/// Argument order matters: `left` is the child at index `2i + 1` and `right`
/// the child at index `2i + 2` of parent `i`.
pub trait Merge {
    /// Node type stored in the tree.
    type Item;

    /// Returns the parent node of `left` and `right`.
    fn merge(left: &Self::Item, right: &Self::Item) -> Self::Item;
}
9
/// A complete binary Merkle tree stored as a flat array.
///
/// Node `i` has children `2i + 1` and `2i + 2`. For `n` leaves the array
/// holds `2n - 1` nodes: the root at index 0 and the leaves in the last `n`
/// slots. An empty tree has no nodes at all.
pub struct MerkleTree<T, M> {
    /// Every node of the tree, root first.
    nodes: Vec<T>,
    /// Binds the tree to its merge strategy without storing a value.
    merge: PhantomData<M>,
}
14
15impl<T, M> MerkleTree<T, M>
16where
17    T: Ord + Default + Clone,
18    M: Merge<Item = T>,
19{
20    /// `leaf_indices`: The indices of leaves
21    pub fn build_proof(&self, leaf_indices: &[u32]) -> Option<MerkleProof<T, M>> {
22        if self.nodes.is_empty() || leaf_indices.is_empty() {
23            return None;
24        }
25
26        let leaves_count = ((self.nodes.len() >> 1) + 1) as u32;
27        let mut indices = leaf_indices
28            .iter()
29            .map(|i| leaves_count + i - 1)
30            .collect::<Vec<_>>();
31
32        indices.sort_by_key(|i| Reverse(*i));
33        if indices[0] >= (leaves_count << 1) - 1 {
34            return None;
35        }
36
37        let mut lemmas = Vec::new();
38        let mut queue: VecDeque<u32> = indices.clone().into();
39
40        while let Some(index) = queue.pop_front() {
41            if index == 0 {
42                assert!(queue.is_empty());
43                break;
44            }
45            let sibling = index.sibling();
46            if Some(&sibling) == queue.front() {
47                queue.pop_front();
48            } else {
49                lemmas.push(self.nodes[sibling as usize].clone());
50            }
51
52            let parent = index.parent();
53            if parent != 0 {
54                queue.push_back(parent);
55            }
56        }
57
58        indices.sort_by_key(|i| &self.nodes[*i as usize]);
59
60        Some(MerkleProof {
61            indices,
62            lemmas,
63            merge: PhantomData,
64        })
65    }
66
67    pub fn root(&self) -> T {
68        if self.nodes.is_empty() {
69            T::default()
70        } else {
71            self.nodes[0].clone()
72        }
73    }
74
75    pub fn nodes(&self) -> &[T] {
76        &self.nodes
77    }
78}
79
/// A multi-leaf Merkle proof for a [`MerkleTree`].
///
/// Holds the node indices of the proven leaves (ordered by leaf value) and the
/// sibling nodes ("lemmas") the verifier cannot recompute itself, in the order
/// they are consumed while folding towards the root.
pub struct MerkleProof<T, M> {
    /// Node indices of the proven leaves.
    indices: Vec<u32>,
    /// Sibling nodes needed to recompute the root.
    lemmas: Vec<T>,
    /// Binds the proof to its merge strategy without storing a value.
    merge: PhantomData<M>,
}
85
86impl<T, M> MerkleProof<T, M>
87where
88    T: Ord + Default + Clone,
89    M: Merge<Item = T>,
90{
91    pub fn new(indices: Vec<u32>, lemmas: Vec<T>) -> Self {
92        Self {
93            indices,
94            lemmas,
95            merge: PhantomData,
96        }
97    }
98
99    pub fn root(&self, leaves: &[T]) -> Option<T> {
100        if leaves.len() != self.indices.len() || leaves.is_empty() {
101            return None;
102        }
103
104        let mut leaves = leaves.to_vec();
105        leaves.sort();
106
107        let mut pre = self
108            .indices
109            .iter()
110            .zip(leaves.into_iter())
111            .map(|(i, l)| (*i, l))
112            .collect::<Vec<_>>();
113        pre.sort_by_key(|i| Reverse(i.0));
114
115        let mut queue: VecDeque<(u32, T)> = pre.into();
116        let mut lemmas_iter = self.lemmas.iter();
117
118        while let Some((index, node)) = queue.pop_front() {
119            if index == 0 {
120                // ensure that all lemmas and leaves are consumed
121                if lemmas_iter.next().is_none() && queue.is_empty() {
122                    return Some(node);
123                } else {
124                    return None;
125                }
126            }
127
128            if let Some(sibling) = match queue.front() {
129                Some((front, _)) if *front == index.sibling() => queue.pop_front().map(|i| i.1),
130                _ => lemmas_iter.next().cloned(),
131            } {
132                let parent_node = if index.is_left() {
133                    M::merge(&node, &sibling)
134                } else {
135                    M::merge(&sibling, &node)
136                };
137
138                queue.push_back((index.parent(), parent_node));
139            }
140        }
141
142        None
143    }
144
145    pub fn verify(&self, root: &T, leaves: &[T]) -> bool {
146        match self.root(leaves) {
147            Some(r) => &r == root,
148            _ => false,
149        }
150    }
151
152    pub fn indices(&self) -> &[u32] {
153        &self.indices
154    }
155
156    pub fn lemmas(&self) -> &[T] {
157        &self.lemmas
158    }
159}
160
/// Namespace for building complete binary Merkle trees, roots and proofs.
///
/// Carries no data: `T` is the node type and `M` the [`Merge`] strategy.
#[derive(Default)]
pub struct CBMT<T, M> {
    /// Marks the node type.
    data_type: PhantomData<T>,
    /// Marks the merge strategy.
    merge: PhantomData<M>,
}
166
167impl<T, M> CBMT<T, M>
168where
169    T: Ord + Default + Clone,
170    M: Merge<Item = T>,
171{
172    pub fn build_merkle_root(leaves: &[T]) -> T {
173        if leaves.is_empty() {
174            return T::default();
175        }
176
177        let mut queue = VecDeque::with_capacity((leaves.len() + 1) >> 1);
178
179        let mut iter = leaves.rchunks_exact(2);
180        while let Some([leaf1, leaf2]) = iter.next() {
181            queue.push_back(M::merge(leaf1, leaf2))
182        }
183        if let [leaf] = iter.remainder() {
184            queue.push_front(leaf.clone())
185        }
186
187        while queue.len() > 1 {
188            let right = queue.pop_front().unwrap();
189            let left = queue.pop_front().unwrap();
190            queue.push_back(M::merge(&left, &right));
191        }
192
193        queue.pop_front().unwrap()
194    }
195
196    pub fn build_merkle_tree(leaves: &[T]) -> MerkleTree<T, M> {
197        let len = leaves.len();
198        if len > 0 {
199            let mut nodes = vec![T::default(); len - 1];
200            nodes.extend_from_slice(leaves);
201
202            (0..len - 1)
203                .rev()
204                .for_each(|i| nodes[i] = M::merge(&nodes[(i << 1) + 1], &nodes[(i << 1) + 2]));
205
206            MerkleTree {
207                nodes,
208                merge: PhantomData,
209            }
210        } else {
211            MerkleTree {
212                nodes: vec![],
213                merge: PhantomData,
214            }
215        }
216    }
217
218    pub fn build_merkle_proof(leaves: &[T], leaf_indices: &[u32]) -> Option<MerkleProof<T, M>> {
219        Self::build_merkle_tree(leaves).build_proof(leaf_indices)
220    }
221
222    /// retrieve that a proof points to leaves of a tree, returning `None` if the proof indices is empty or out of bounds
223    pub fn retrieve_leaves(leaves: &[T], proof: &MerkleProof<T, M>) -> Option<Vec<T>> {
224        if leaves.is_empty() || proof.indices().is_empty() {
225            return None;
226        }
227
228        let leaves_count = leaves.len() as u32;
229        let valid_indices_range = leaves_count - 1..(leaves_count << 1) - 1;
230        if proof
231            .indices()
232            .iter()
233            .all(|index| valid_indices_range.contains(index))
234        {
235            Some(
236                proof
237                    .indices()
238                    .iter()
239                    .map(|index| leaves[(index + 1 - leaves_count) as usize].clone())
240                    .collect(),
241            )
242        } else {
243            None
244        }
245    }
246}
247
/// Index arithmetic for a complete binary tree stored in an array, where node
/// `i` has a left child at `2i + 1` and a right child at `2i + 2`.
trait TreeIndex {
    /// The other child of this node's parent; the root (0) maps to itself.
    fn sibling(&self) -> Self;
    /// The parent of this node; the root (0) maps to itself.
    fn parent(&self) -> Self;
    /// Whether this node is a left child, i.e. has an odd index.
    fn is_left(&self) -> bool;
}

macro_rules! impl_tree_index {
    ($t: ty) => {
        impl TreeIndex for $t {
            fn sibling(&self) -> $t {
                match *self {
                    0 => 0,
                    // A left child (odd) is followed directly by its right sibling.
                    i if i % 2 == 1 => i + 1,
                    i => i - 1,
                }
            }

            fn parent(&self) -> $t {
                match *self {
                    0 => 0,
                    i => (i - 1) / 2,
                }
            }

            fn is_left(&self) -> bool {
                *self % 2 == 1
            }
        }
    };
}

impl_tree_index!(u32);
impl_tree_index!(usize);
282
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::collection::vec;
    use proptest::num::i32;
    use proptest::prelude::*;
    use proptest::proptest;
    use proptest::sample::subsequence;

    /// Order-sensitive toy merge (`right - left`, wrapping) so that swapped
    /// children change the result.
    struct MergeI32;

    impl Merge for MergeI32 {
        type Item = i32;
        fn merge(left: &Self::Item, right: &Self::Item) -> Self::Item {
            right.wrapping_sub(*left)
        }
    }

    type CBMTI32 = CBMT<i32, MergeI32>;
    type CBMTI32Proof = MerkleProof<i32, MergeI32>;

    /// Returns the leaves found at the given leaf positions.
    fn pick(leaves: &[i32], positions: &[u32]) -> Vec<i32> {
        positions.iter().map(|&p| leaves[p as usize]).collect()
    }

    #[test]
    fn build_empty() {
        let tree = CBMTI32::build_merkle_tree(&[]);
        assert!(tree.nodes().is_empty());
        assert_eq!(tree.root(), i32::default());
    }

    #[test]
    fn build_one() {
        let tree = CBMTI32::build_merkle_tree(&[1]);
        assert_eq!(tree.nodes(), &[1][..]);
    }

    #[test]
    fn build_two() {
        let tree = CBMTI32::build_merkle_tree(&[1, 2]);
        assert_eq!(tree.nodes(), &[1, 1, 2][..]);
    }

    #[test]
    fn build_five() {
        let tree = CBMTI32::build_merkle_tree(&[2, 3, 5, 7, 11]);
        assert_eq!(tree.nodes(), &[4, -2, 2, 4, 2, 3, 5, 7, 11][..]);
    }

    #[test]
    fn build_root_directly() {
        assert_eq!(CBMTI32::build_merkle_root(&[2, 3, 5, 7, 11]), 4);
    }

    #[test]
    fn rebuild_proof() {
        let tree = CBMTI32::build_merkle_tree(&[2, 3, 5, 7, 11]);
        let root = tree.root();

        let original = tree.build_proof(&[0, 3]).unwrap();
        let needed_leaves: Vec<i32> = original
            .indices()
            .iter()
            .map(|&i| tree.nodes()[i as usize])
            .collect();

        // A proof reassembled from its raw parts must behave like the original.
        let rebuilt = CBMTI32Proof::new(original.indices().to_vec(), original.lemmas().to_vec());
        assert!(rebuilt.verify(&root, &needed_leaves));
        assert_eq!(rebuilt.root(&needed_leaves), Some(root));
    }

    fn _build_root_is_same_as_tree_root(leaves: Vec<i32>) {
        let tree = CBMTI32::build_merkle_tree(&leaves);
        assert_eq!(tree.root(), CBMTI32::build_merkle_root(&leaves));
    }

    proptest! {
        #[test]
        fn build_root_is_same_as_tree_root(leaves in vec(i32::ANY, 0..1000)) {
            _build_root_is_same_as_tree_root(leaves);
        }
    }

    #[test]
    fn build_proof() {
        let leaves = [2i32, 3, 5, 7, 11, 13];
        let positions = [0u32, 5];
        let proof = CBMTI32::build_merkle_proof(&leaves, &positions).unwrap();
        assert_eq!(proof.lemmas, vec![11, 3, 2]);
        assert_eq!(proof.root(&pick(&leaves, &positions)), Some(1));

        // A single-leaf tree needs no lemmas: the leaf itself is the root.
        let leaves = [2i32];
        let positions = [0u32];
        let proof = CBMTI32::build_merkle_proof(&leaves, &positions).unwrap();
        assert!(proof.lemmas.is_empty());
        assert_eq!(proof.root(&pick(&leaves, &positions)), Some(2));
    }

    fn _tree_root_is_same_as_proof_root(leaves: Vec<i32>, leaf_indices: Vec<u32>) {
        let proof = CBMTI32::build_merkle_proof(&leaves, &leaf_indices).unwrap();
        let expected = CBMTI32::build_merkle_root(&leaves);
        assert_eq!(proof.root(&pick(&leaves, &leaf_indices)).unwrap(), expected);
    }

    proptest! {
        #[test]
        fn tree_root_is_same_as_proof_root(input in vec(i32::ANY, 2..1000)
            .prop_flat_map(|leaves| {
                let positions = (0..leaves.len() as u32).collect::<Vec<u32>>();
                let upper = leaves.len();
                (Just(leaves), subsequence(positions, 1..upper))
            })
        ) {
            _tree_root_is_same_as_proof_root(input.0, input.1);
        }
    }

    #[test]
    fn verify_retrieve_leaves() {
        let leaves = vec![2i32, 3, 5, 7, 11, 13];
        let mut proof = CBMTI32::build_merkle_proof(&leaves, &[0, 3]).unwrap();

        let retrieved = CBMTI32::retrieve_leaves(&leaves, &proof);
        assert_eq!(retrieved, Some(vec![2, 7]));
        assert_eq!(
            proof.root(&retrieved.unwrap()),
            Some(CBMTI32::build_merkle_root(&leaves))
        );

        // Empty or out-of-range indices cannot be mapped back to leaves.
        for bad in vec![vec![], vec![4], vec![11]] {
            proof.indices = bad;
            assert_eq!(CBMTI32::retrieve_leaves(&leaves, &proof), None);
        }
    }
}
439}