use crate::{
binary::{
empty_sum,
in_memory::NodesTable,
Node,
Primitive,
},
common::{
Bytes32,
Position,
ProofSet,
StorageMap,
Subtree,
},
storage::{
Mappable,
StorageInspect,
StorageInspectInfallible,
StorageMutate,
StorageMutateInfallible,
},
};
use alloc::vec::Vec;
use core::marker::PhantomData;
/// Errors returned by [`MerkleTree`] operations.
///
/// `StorageError` is the error type of the backing storage; such errors are
/// wrapped in the `StorageError` variant.
#[derive(Debug, Clone, derive_more::Display)]
pub enum MerkleTreeError<StorageError> {
    /// The index passed to `prove` does not refer to an existing leaf.
    #[display(fmt = "proof index {_0} is not valid")]
    InvalidProofIndex(u64),
    /// A node needed to rebuild the tree or assemble a proof was not found;
    /// the payload is the storage key (the node's in-order index).
    #[display(fmt = "cannot load node with key {_0}; the key is not found in storage")]
    LoadError(u64),
    /// An error reported by the backing storage, displayed via its own `Display`.
    #[display(fmt = "{}", _0)]
    StorageError(StorageError),
}
impl<StorageError> From<StorageError> for MerkleTreeError<StorageError> {
fn from(err: StorageError) -> MerkleTreeError<StorageError> {
MerkleTreeError::StorageError(err)
}
}
/// A binary Merkle tree built by appending leaves, with its nodes persisted in
/// `StorageType` under the `TableType` table.
///
/// Only the current subtree roots ("peaks") are kept in memory; leaves and the
/// internal nodes created while appending are written to `storage`.
#[derive(Debug, Clone)]
pub struct MerkleTree<TableType, StorageType> {
    // Backing storage for leaves and internal nodes, keyed by the node's
    // in-order index.
    storage: StorageType,
    // Linked list of peaks; the most recently created peak is at the front.
    // `None` when the tree has no leaves.
    head: Option<Subtree<Node>>,
    // Number of leaves pushed (or declared via `load`) so far.
    leaves_count: u64,
    // Binds the tree to its storage table type without storing a value of it.
    phantom_table: PhantomData<TableType>,
}
impl<TableType, StorageType> MerkleTree<TableType, StorageType> {
pub const fn empty_root() -> &'static Bytes32 {
empty_sum()
}
pub fn root(&self) -> Bytes32 {
let mut scratch_storage = StorageMap::<NodesTable>::new();
let root_node = self.root_node(&mut scratch_storage);
match root_node {
None => *Self::empty_root(),
Some(ref node) => *node.hash(),
}
}
fn head(&self) -> Option<&Subtree<Node>> {
self.head.as_ref()
}
pub fn leaves_count(&self) -> u64 {
self.leaves_count
}
fn root_node(&self, scratch_storage: &mut StorageMap<NodesTable>) -> Option<Node> {
self.head()
.map(|head| build_root_node(head, scratch_storage))
}
fn peak_positions(&self) -> Vec<Position> {
let leaves_count = self.leaves_count + 1;
let leaf_position = Position::from_leaf_index(leaves_count - 1);
let root_position = self.root_position();
let mut peaks_itr = root_position.path(&leaf_position, leaves_count).iter();
peaks_itr.next(); let (_, peaks): (Vec<_>, Vec<_>) = peaks_itr.unzip();
peaks
}
fn root_position(&self) -> Position {
let leaves_count = self.leaves_count + 1;
let root_index = leaves_count.next_power_of_two() - 1;
Position::from_in_order_index(root_index)
}
}
impl<TableType, StorageType, StorageError> MerkleTree<TableType, StorageType>
where
    TableType: Mappable<Key = u64, Value = Primitive, OwnedValue = Primitive>,
    StorageType: StorageInspect<TableType, Error = StorageError>,
{
    /// Creates an empty tree backed by `storage`. Nothing is read from `storage`.
    pub fn new(storage: StorageType) -> Self {
        Self {
            storage,
            head: None,
            leaves_count: 0,
            phantom_table: Default::default(),
        }
    }

    /// Reconstructs a tree with `leaves_count` leaves from nodes already
    /// present in `storage`, by loading the peak nodes.
    ///
    /// # Errors
    ///
    /// * [`MerkleTreeError::LoadError`] if a required peak is missing from `storage`.
    /// * [`MerkleTreeError::StorageError`] if reading `storage` fails.
    pub fn load(
        storage: StorageType,
        leaves_count: u64,
    ) -> Result<Self, MerkleTreeError<StorageError>> {
        let mut tree = Self {
            storage,
            head: None,
            leaves_count,
            phantom_table: Default::default(),
        };
        tree.build()?;
        Ok(tree)
    }

    /// Returns the Merkle root and the proof set for the leaf at `proof_index`.
    ///
    /// The proof set lists sibling hashes starting at the leaf level and moving
    /// toward the root.
    ///
    /// # Errors
    ///
    /// * [`MerkleTreeError::InvalidProofIndex`] if `proof_index >= leaves_count`.
    /// * [`MerkleTreeError::LoadError`] if a sibling node is found neither among
    ///   the nodes produced while computing the root nor in `storage`.
    /// * [`MerkleTreeError::StorageError`] if reading `storage` fails.
    pub fn prove(
        &self,
        proof_index: u64,
    ) -> Result<(Bytes32, ProofSet), MerkleTreeError<StorageError>> {
        // Written as `>=`: the former `proof_index + 1 > leaves_count` overflowed
        // for `u64::MAX` (panic in debug, wrap-around past the check in release).
        if proof_index >= self.leaves_count {
            return Err(MerkleTreeError::InvalidProofIndex(proof_index))
        }

        let root_position = self.root_position();
        let leaf_position = Position::from_leaf_index(proof_index);
        // The first path entry is emitted for the root itself and is not part of
        // the proof; the remaining sides are reversed so the proof runs from the
        // leaf level upward.
        let mut side_positions = root_position
            .path(&leaf_position, self.leaves_count)
            .iter()
            .skip(1)
            .map(|(_, side)| side)
            .collect::<Vec<_>>();
        side_positions.reverse();

        // Parents created while folding the peaks into the root exist only in
        // this scratch table, not in `self.storage`.
        let mut scratch_storage = StorageMap::<NodesTable>::new();
        let root_node = self
            .root_node(&mut scratch_storage)
            .expect("Root node must be present");

        let mut proof_set = ProofSet::new();
        for side_position in side_positions {
            let key = side_position.in_order_index();
            // Consult the scratch table first and only fall back to the backing
            // storage on a miss; `Option::or` evaluated (and propagated errors
            // from) the storage lookup unconditionally.
            let primitive = match StorageInspectInfallible::get(&scratch_storage, &key) {
                Some(primitive) => primitive.into_owned(),
                None => StorageInspect::get(&self.storage, &key)?
                    .ok_or(MerkleTreeError::LoadError(key))?
                    .into_owned(),
            };
            let node = Node::from(primitive);
            proof_set.push(*node.hash());
        }

        Ok((*root_node.hash(), proof_set))
    }

    /// Clears the in-memory state so the tree reports zero leaves.
    ///
    /// Nodes already written to storage are left untouched.
    pub fn reset(&mut self) {
        self.leaves_count = 0;
        self.head = None;
    }

    /// Loads the peaks for `self.leaves_count` from storage and rebuilds the
    /// in-memory peak list, with the last loaded peak at the front.
    fn build(&mut self) -> Result<(), MerkleTreeError<StorageError>> {
        let mut current_head = None;
        for peak in self.peak_positions() {
            let key = peak.in_order_index();
            let node = self
                .storage
                .get(&key)?
                .ok_or(MerkleTreeError::LoadError(key))?
                .into_owned()
                .into();
            current_head = Some(Subtree::new(node, current_head));
        }
        self.head = current_head;
        Ok(())
    }
}
impl<TableType, StorageType, StorageError> MerkleTree<TableType, StorageType>
where
TableType: Mappable<Key = u64, Value = Primitive, OwnedValue = Primitive>,
StorageType: StorageMutate<TableType, Error = StorageError>,
{
pub fn push(&mut self, data: &[u8]) -> Result<(), StorageError> {
let node = Node::create_leaf(self.leaves_count, data);
self.storage.insert(&node.key(), &node.as_ref().into())?;
let next = self.head.take();
let head = Subtree::new(node, next);
self.head = Some(head);
self.join_all_subtrees()?;
self.leaves_count += 1;
Ok(())
}
fn join_all_subtrees(&mut self) -> Result<(), StorageError> {
while {
if let Some((head, next)) = self
.head()
.and_then(|head| head.next().map(|next| (head, next)))
{
head.node().height() == next.node().height()
} else {
false
}
} {
let mut head = self.head.take().expect("Expected head to be present");
let mut head_next = head.take_next().expect("Expected next to be present");
let joined_head = join_subtrees(&mut head_next, &mut head);
self.storage
.insert(&joined_head.node().key(), &joined_head.node().into())?;
self.head = Some(joined_head);
}
Ok(())
}
}
/// Merges two adjacent subtrees into one.
///
/// The new subtree's node is the parent of `lhs` (left child) and `rhs` (right
/// child). It inherits whatever followed `lhs` in the peak list; `lhs` is
/// left without a successor.
fn join_subtrees(lhs: &mut Subtree<Node>, rhs: &mut Subtree<Node>) -> Subtree<Node> {
    let parent = Node::create_node(lhs.node(), rhs.node());
    let tail = lhs.take_next();
    Subtree::new(parent, tail)
}
/// Folds the peak list starting at `subtree` into a single root node.
///
/// `subtree` is cloned, not modified. Each parent created along the way is
/// inserted into `storage`, and the final (root) node is returned.
fn build_root_node<Table, Storage>(subtree: &Subtree<Node>, storage: &mut Storage) -> Node
where
    Table: Mappable<Key = u64, OwnedValue = Primitive, Value = Primitive>,
    Storage: StorageMutateInfallible<Table>,
{
    let mut acc = subtree.clone();
    loop {
        match acc.take_next() {
            None => break acc.node().clone(),
            Some(mut older) => {
                acc = join_subtrees(&mut older, &mut acc);
                let parent = acc.node();
                storage.insert(&parent.key(), &parent.into());
            }
        }
    }
}
#[cfg(test)]
mod test {
    //! Unit tests for the storage-backed binary Merkle tree. Setup `push`
    //! results are unwrapped so a storage failure fails the test at its source
    //! instead of surfacing later as a confusing hash mismatch.
    use super::{
        MerkleTree,
        MerkleTreeError,
    };
    use crate::{
        binary::{
            empty_sum,
            leaf_sum,
            node_sum,
            Node,
            Primitive,
        },
        common::StorageMap,
    };
    use fuel_merkle_test_helpers::TEST_DATA;
    use fuel_storage::{
        Mappable,
        StorageInspect,
    };

    use alloc::vec::Vec;

    #[derive(Debug)]
    struct TestTable;

    impl Mappable for TestTable {
        type Key = Self::OwnedKey;
        type OwnedKey = u64;
        type OwnedValue = Primitive;
        type Value = Self::OwnedValue;
    }

    #[test]
    fn test_push_builds_internal_tree_structure() {
        let mut storage_map = StorageMap::<TestTable>::new();
        let mut tree = MerkleTree::new(&mut storage_map);

        let data = &TEST_DATA[0..7];
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        let leaf_0 = leaf_sum(data[0]);
        let leaf_1 = leaf_sum(data[1]);
        let leaf_2 = leaf_sum(data[2]);
        let leaf_3 = leaf_sum(data[3]);
        let leaf_4 = leaf_sum(data[4]);
        let leaf_5 = leaf_sum(data[5]);
        let leaf_6 = leaf_sum(data[6]);
        let node_1 = node_sum(&leaf_0, &leaf_1);
        let node_5 = node_sum(&leaf_2, &leaf_3);
        let node_3 = node_sum(&node_1, &node_5);
        let node_9 = node_sum(&leaf_4, &leaf_5);

        let s_leaf_0 = storage_map.get(&0).unwrap().unwrap();
        let s_leaf_1 = storage_map.get(&2).unwrap().unwrap();
        let s_leaf_2 = storage_map.get(&4).unwrap().unwrap();
        let s_leaf_3 = storage_map.get(&6).unwrap().unwrap();
        let s_leaf_4 = storage_map.get(&8).unwrap().unwrap();
        let s_leaf_5 = storage_map.get(&10).unwrap().unwrap();
        let s_leaf_6 = storage_map.get(&12).unwrap().unwrap();
        let s_node_1 = storage_map.get(&1).unwrap().unwrap();
        let s_node_5 = storage_map.get(&5).unwrap().unwrap();
        let s_node_9 = storage_map.get(&9).unwrap().unwrap();
        let s_node_3 = storage_map.get(&3).unwrap().unwrap();

        assert_eq!(*Node::from(s_leaf_0.into_owned()).hash(), leaf_0);
        assert_eq!(*Node::from(s_leaf_1.into_owned()).hash(), leaf_1);
        assert_eq!(*Node::from(s_leaf_2.into_owned()).hash(), leaf_2);
        assert_eq!(*Node::from(s_leaf_3.into_owned()).hash(), leaf_3);
        assert_eq!(*Node::from(s_leaf_4.into_owned()).hash(), leaf_4);
        assert_eq!(*Node::from(s_leaf_5.into_owned()).hash(), leaf_5);
        assert_eq!(*Node::from(s_leaf_6.into_owned()).hash(), leaf_6);
        assert_eq!(*Node::from(s_node_1.into_owned()).hash(), node_1);
        assert_eq!(*Node::from(s_node_5.into_owned()).hash(), node_5);
        assert_eq!(*Node::from(s_node_9.into_owned()).hash(), node_9);
        assert_eq!(*Node::from(s_node_3.into_owned()).hash(), node_3);
    }

    #[test]
    fn load_returns_a_valid_tree() {
        const LEAVES_COUNT: u64 = 2u64.pow(16) - 1;

        let mut storage_map = StorageMap::<TestTable>::new();

        let expected_root = {
            let mut tree = MerkleTree::new(&mut storage_map);
            let data = (0u64..LEAVES_COUNT)
                .map(|i| i.to_be_bytes())
                .collect::<Vec<_>>();
            for datum in data.iter() {
                tree.push(datum).unwrap();
            }
            tree.root()
        };

        let root = {
            let tree = MerkleTree::load(&mut storage_map, LEAVES_COUNT).unwrap();
            tree.root()
        };

        assert_eq!(expected_root, root);
    }

    #[test]
    fn load_returns_empty_tree_for_0_leaves() {
        const LEAVES_COUNT: u64 = 0;

        let expected_root = *MerkleTree::<(), ()>::empty_root();

        let root = {
            let mut storage_map = StorageMap::<TestTable>::new();
            let tree = MerkleTree::load(&mut storage_map, LEAVES_COUNT).unwrap();
            tree.root()
        };

        assert_eq!(expected_root, root);
    }

    #[test]
    fn load_returns_a_load_error_if_the_storage_is_not_valid_for_the_leaves_count() {
        const LEAVES_COUNT: u64 = 5;

        let mut storage_map = StorageMap::<TestTable>::new();

        let mut tree = MerkleTree::new(&mut storage_map);
        let data = (0u64..LEAVES_COUNT)
            .map(|i| i.to_be_bytes())
            .collect::<Vec<_>>();
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        let err = MerkleTree::load(&mut storage_map, LEAVES_COUNT * 2)
            .expect_err("Expected load() to return Error; got Ok");
        assert!(matches!(err, MerkleTreeError::LoadError(_)));
    }

    #[test]
    fn root_returns_the_empty_root_for_0_leaves() {
        let mut storage_map = StorageMap::<TestTable>::new();
        let tree = MerkleTree::new(&mut storage_map);

        let root = tree.root();
        assert_eq!(root, *empty_sum());
    }

    #[test]
    fn root_returns_the_merkle_root_for_1_leaf() {
        let mut storage_map = StorageMap::<TestTable>::new();
        let mut tree = MerkleTree::new(&mut storage_map);

        let data = &TEST_DATA[0..1];
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        let leaf_0 = leaf_sum(data[0]);

        let root = tree.root();
        assert_eq!(root, leaf_0);
    }

    #[test]
    fn root_returns_the_merkle_root_for_7_leaves() {
        let mut storage_map = StorageMap::<TestTable>::new();
        let mut tree = MerkleTree::new(&mut storage_map);

        let data = &TEST_DATA[0..7];
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        let leaf_0 = leaf_sum(data[0]);
        let leaf_1 = leaf_sum(data[1]);
        let leaf_2 = leaf_sum(data[2]);
        let leaf_3 = leaf_sum(data[3]);
        let leaf_4 = leaf_sum(data[4]);
        let leaf_5 = leaf_sum(data[5]);
        let leaf_6 = leaf_sum(data[6]);
        let node_1 = node_sum(&leaf_0, &leaf_1);
        let node_5 = node_sum(&leaf_2, &leaf_3);
        let node_3 = node_sum(&node_1, &node_5);
        let node_9 = node_sum(&leaf_4, &leaf_5);
        let node_11 = node_sum(&node_9, &leaf_6);
        let node_7 = node_sum(&node_3, &node_11);

        let root = tree.root();
        assert_eq!(root, node_7);
    }

    #[test]
    fn prove_returns_invalid_proof_index_error_for_0_leaves() {
        let mut storage_map = StorageMap::<TestTable>::new();
        let tree = MerkleTree::new(&mut storage_map);

        let err = tree
            .prove(0)
            .expect_err("Expected prove() to return Error; got Ok");
        assert!(matches!(err, MerkleTreeError::InvalidProofIndex(0)));
    }

    #[test]
    fn prove_returns_invalid_proof_index_error_when_index_is_greater_than_number_of_leaves(
    ) {
        let mut storage_map = StorageMap::<TestTable>::new();
        let mut tree = MerkleTree::new(&mut storage_map);

        let data = &TEST_DATA[0..5];
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        let err = tree
            .prove(10)
            .expect_err("Expected prove() to return Error; got Ok");
        assert!(matches!(err, MerkleTreeError::InvalidProofIndex(10)))
    }

    #[test]
    fn prove_returns_the_merkle_root_and_proof_set_for_1_leaf() {
        let mut storage_map = StorageMap::<TestTable>::new();
        let mut tree = MerkleTree::new(&mut storage_map);

        let data = &TEST_DATA[0..1];
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        let leaf_0 = leaf_sum(data[0]);

        {
            let (root, proof_set) = tree.prove(0).unwrap();
            assert_eq!(root, leaf_0);
            assert!(proof_set.is_empty());
        }
    }

    #[test]
    fn prove_returns_the_merkle_root_and_proof_set_for_4_leaves() {
        let mut storage_map = StorageMap::<TestTable>::new();
        let mut tree = MerkleTree::new(&mut storage_map);

        let data = &TEST_DATA[0..4];
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        let leaf_0 = leaf_sum(data[0]);
        let leaf_1 = leaf_sum(data[1]);
        let leaf_2 = leaf_sum(data[2]);
        let leaf_3 = leaf_sum(data[3]);
        let node_1 = node_sum(&leaf_0, &leaf_1);
        let node_5 = node_sum(&leaf_2, &leaf_3);
        let node_3 = node_sum(&node_1, &node_5);

        {
            let (root, proof_set) = tree.prove(0).unwrap();
            assert_eq!(root, node_3);
            assert_eq!(proof_set[0], leaf_1);
            assert_eq!(proof_set[1], node_5);
        }
        {
            let (root, proof_set) = tree.prove(1).unwrap();
            assert_eq!(root, node_3);
            assert_eq!(proof_set[0], leaf_0);
            assert_eq!(proof_set[1], node_5);
        }
        {
            let (root, proof_set) = tree.prove(2).unwrap();
            assert_eq!(root, node_3);
            assert_eq!(proof_set[0], leaf_3);
            assert_eq!(proof_set[1], node_1);
        }
        {
            let (root, proof_set) = tree.prove(3).unwrap();
            assert_eq!(root, node_3);
            assert_eq!(proof_set[0], leaf_2);
            assert_eq!(proof_set[1], node_1);
        }
    }

    #[test]
    fn prove_returns_the_merkle_root_and_proof_set_for_5_leaves() {
        let mut storage_map = StorageMap::<TestTable>::new();
        let mut tree = MerkleTree::new(&mut storage_map);

        let data = &TEST_DATA[0..5];
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        let leaf_0 = leaf_sum(data[0]);
        let leaf_1 = leaf_sum(data[1]);
        let leaf_2 = leaf_sum(data[2]);
        let leaf_3 = leaf_sum(data[3]);
        let leaf_4 = leaf_sum(data[4]);
        let node_1 = node_sum(&leaf_0, &leaf_1);
        let node_5 = node_sum(&leaf_2, &leaf_3);
        let node_3 = node_sum(&node_1, &node_5);
        let node_7 = node_sum(&node_3, &leaf_4);

        {
            let (root, proof_set) = tree.prove(0).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], leaf_1);
            assert_eq!(proof_set[1], node_5);
            assert_eq!(proof_set[2], leaf_4);
        }
        {
            let (root, proof_set) = tree.prove(1).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], leaf_0);
            assert_eq!(proof_set[1], node_5);
            assert_eq!(proof_set[2], leaf_4);
        }
        {
            let (root, proof_set) = tree.prove(2).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], leaf_3);
            assert_eq!(proof_set[1], node_1);
            assert_eq!(proof_set[2], leaf_4);
        }
        {
            let (root, proof_set) = tree.prove(3).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], leaf_2);
            assert_eq!(proof_set[1], node_1);
            assert_eq!(proof_set[2], leaf_4);
        }
        {
            let (root, proof_set) = tree.prove(4).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], node_3);
        }
    }

    #[test]
    fn prove_returns_the_merkle_root_and_proof_set_for_7_leaves() {
        let mut storage_map = StorageMap::<TestTable>::new();
        let mut tree = MerkleTree::new(&mut storage_map);

        let data = &TEST_DATA[0..7];
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        let leaf_0 = leaf_sum(data[0]);
        let leaf_1 = leaf_sum(data[1]);
        let leaf_2 = leaf_sum(data[2]);
        let leaf_3 = leaf_sum(data[3]);
        let leaf_4 = leaf_sum(data[4]);
        let leaf_5 = leaf_sum(data[5]);
        let leaf_6 = leaf_sum(data[6]);
        let node_1 = node_sum(&leaf_0, &leaf_1);
        let node_5 = node_sum(&leaf_2, &leaf_3);
        let node_3 = node_sum(&node_1, &node_5);
        let node_9 = node_sum(&leaf_4, &leaf_5);
        let node_11 = node_sum(&node_9, &leaf_6);
        let node_7 = node_sum(&node_3, &node_11);

        {
            let (root, proof_set) = tree.prove(0).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], leaf_1);
            assert_eq!(proof_set[1], node_5);
            assert_eq!(proof_set[2], node_11);
        }
        {
            let (root, proof_set) = tree.prove(1).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], leaf_0);
            assert_eq!(proof_set[1], node_5);
            assert_eq!(proof_set[2], node_11);
        }
        {
            let (root, proof_set) = tree.prove(2).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], leaf_3);
            assert_eq!(proof_set[1], node_1);
            assert_eq!(proof_set[2], node_11);
        }
        {
            let (root, proof_set) = tree.prove(3).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], leaf_2);
            assert_eq!(proof_set[1], node_1);
            assert_eq!(proof_set[2], node_11);
        }
        {
            let (root, proof_set) = tree.prove(4).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], leaf_5);
            assert_eq!(proof_set[1], leaf_6);
            assert_eq!(proof_set[2], node_3);
        }
        {
            let (root, proof_set) = tree.prove(5).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], leaf_4);
            assert_eq!(proof_set[1], leaf_6);
            assert_eq!(proof_set[2], node_3);
        }
        {
            let (root, proof_set) = tree.prove(6).unwrap();
            assert_eq!(root, node_7);
            assert_eq!(proof_set[0], node_9);
            assert_eq!(proof_set[1], node_3);
        }
    }

    #[test]
    fn reset_reverts_tree_to_empty_state() {
        let mut storage_map = StorageMap::<TestTable>::new();
        let mut tree = MerkleTree::new(&mut storage_map);

        let data = &TEST_DATA[0..4];
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        tree.reset();

        let root = tree.root();
        let expected_root = *MerkleTree::<(), ()>::empty_root();
        assert_eq!(root, expected_root);

        let data = &TEST_DATA[0..4];
        for datum in data.iter() {
            tree.push(datum).unwrap();
        }

        let leaf_0 = leaf_sum(data[0]);
        let leaf_1 = leaf_sum(data[1]);
        let leaf_2 = leaf_sum(data[2]);
        let leaf_3 = leaf_sum(data[3]);
        let node_1 = node_sum(&leaf_0, &leaf_1);
        let node_5 = node_sum(&leaf_2, &leaf_3);
        let node_3 = node_sum(&node_1, &node_5);

        let root = tree.root();
        let expected_root = node_3;
        assert_eq!(root, expected_root);
    }
}