use crate::{
common::{error::DeserializeError, AsPathIterator, Bytes32, ChildError},
sparse::{primitive::Primitive, zero_sum, Node, StorageNode, StorageNodeError},
storage::{Mappable, StorageInspect, StorageMutate},
};
use alloc::{string::String, vec::Vec};
use core::{cmp, iter, marker::PhantomData};
/// Errors returned by [`MerkleTree`] operations.
///
/// `StorageError` is the error type reported by the backing storage.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
pub enum MerkleTreeError<StorageError> {
/// `MerkleTree::load` found no entry under the requested root; the payload is
/// the hex encoding of that root key.
#[cfg_attr(
feature = "std",
error("cannot load node with key {0}; the key is not found in storage")
)]
LoadError(String),
/// An error raised by the backing storage itself.
#[cfg_attr(feature = "std", error(transparent))]
StorageError(StorageError),
/// A value read from storage could not be converted into a tree node.
#[cfg_attr(feature = "std", error(transparent))]
DeserializeError(DeserializeError),
/// Walking the path from the root towards a leaf failed while loading a
/// child node.
#[cfg_attr(feature = "std", error(transparent))]
ChildError(ChildError<Bytes32, StorageNodeError<StorageError>>),
}
impl<StorageError> From<StorageError> for MerkleTreeError<StorageError> {
fn from(err: StorageError) -> MerkleTreeError<StorageError> {
MerkleTreeError::StorageError(err)
}
}
/// A sparse Merkle tree whose nodes are persisted in `storage`.
///
/// Only the current root node is held in memory; every other node is read
/// from and written to `storage` through the `TableType` table.
#[derive(Debug)]
pub struct MerkleTree<TableType, StorageType> {
// Current root node; a placeholder while the tree is empty.
root_node: Node,
// Backing store: internal nodes are keyed by hash, leaves by hash and by
// leaf key.
storage: StorageType,
// Carries the `TableType` parameter without storing a value of it.
phantom_table: PhantomData<TableType>,
}
impl<TableType, StorageType> MerkleTree<TableType, StorageType> {
    /// Returns the root hash of a tree that holds no leaves.
    ///
    /// This is the value produced by `zero_sum()`.
    pub const fn empty_root() -> Bytes32 {
        let zero = zero_sum();
        *zero
    }

    /// Returns the hash of the current root node.
    pub fn root(&self) -> Bytes32 {
        let node = self.root_node();
        node.hash()
    }

    /// Borrows the current root node.
    fn root_node(&self) -> &Node {
        let Self { root_node, .. } = self;
        root_node
    }

    /// Replaces the root node.
    ///
    /// Debug builds assert that the new root is either a leaf or a node whose
    /// height equals `Node::max_height()`.
    fn set_root_node(&mut self, node: Node) {
        debug_assert!(node.height() == Node::max_height() as u32 || node.is_leaf());
        self.root_node = node;
    }
}
impl<TableType, StorageType, StorageError> MerkleTree<TableType, StorageType>
where
TableType: Mappable<Key = Bytes32, Value = Primitive, OwnedValue = Primitive>,
StorageType: StorageInspect<TableType, Error = StorageError>,
{
pub fn new(storage: StorageType) -> Self {
Self {
root_node: Node::create_placeholder(),
storage,
phantom_table: Default::default(),
}
}
pub fn load(storage: StorageType, root: &Bytes32) -> Result<Self, MerkleTreeError<StorageError>> {
let primitive = storage
.get(root)?
.ok_or_else(|| MerkleTreeError::LoadError(hex::encode(root)))?
.into_owned();
let tree = Self {
root_node: primitive.try_into().map_err(MerkleTreeError::DeserializeError)?,
storage,
phantom_table: Default::default(),
};
Ok(tree)
}
fn path_set(&self, leaf_node: Node) -> Result<(Vec<Node>, Vec<Node>), MerkleTreeError<StorageError>> {
let root_node = self.root_node().clone();
let root_storage_node = StorageNode::new(&self.storage, root_node);
let leaf_storage_node = StorageNode::new(&self.storage, leaf_node);
let (mut path_nodes, mut side_nodes): (Vec<Node>, Vec<Node>) = root_storage_node
.as_path_iter(&leaf_storage_node)
.map(|(path_node, side_node)| {
Ok((
path_node.map_err(MerkleTreeError::ChildError)?.into_node(),
side_node.map_err(MerkleTreeError::ChildError)?.into_node(),
))
})
.collect::<Result<Vec<_>, MerkleTreeError<StorageError>>>()?
.into_iter()
.unzip();
path_nodes.reverse();
side_nodes.reverse();
side_nodes.pop(); Ok((path_nodes, side_nodes))
}
}
impl<TableType, StorageType, StorageError> MerkleTree<TableType, StorageType>
where
TableType: Mappable<Key = Bytes32, Value = Primitive, OwnedValue = Primitive>,
StorageType: StorageMutate<TableType, Error = StorageError>,
{
pub fn update(&mut self, key: &Bytes32, data: &[u8]) -> Result<(), MerkleTreeError<StorageError>> {
if data.is_empty() {
self.delete(key)?;
return Ok(());
}
let leaf_node = Node::create_leaf(key, data);
self.storage.insert(&leaf_node.hash(), &leaf_node.as_ref().into())?;
self.storage.insert(leaf_node.leaf_key(), &leaf_node.as_ref().into())?;
if self.root_node().is_placeholder() {
self.set_root_node(leaf_node);
} else {
let (path_nodes, side_nodes) = self.path_set(leaf_node.clone())?;
self.update_with_path_set(&leaf_node, path_nodes.as_slice(), side_nodes.as_slice())?;
}
Ok(())
}
pub fn delete(&mut self, key: &Bytes32) -> Result<(), MerkleTreeError<StorageError>> {
if self.root() == Self::empty_root() {
return Ok(());
}
if let Some(primitive) = self.storage.get(key)? {
let primitive = primitive.into_owned();
let leaf_node: Node = primitive.try_into().map_err(MerkleTreeError::DeserializeError)?;
let (path_nodes, side_nodes): (Vec<Node>, Vec<Node>) = self.path_set(leaf_node.clone())?;
self.delete_with_path_set(&leaf_node, path_nodes.as_slice(), side_nodes.as_slice())?;
}
Ok(())
}
fn update_with_path_set(
&mut self,
requested_leaf_node: &Node,
path_nodes: &[Node],
side_nodes: &[Node],
) -> Result<(), StorageError> {
let path = requested_leaf_node.leaf_key();
let actual_leaf_node = &path_nodes[0];
let mut current_node = requested_leaf_node.clone();
if requested_leaf_node.leaf_key() != actual_leaf_node.leaf_key() {
if !actual_leaf_node.is_placeholder() {
current_node = Node::create_node_on_path(path, ¤t_node, actual_leaf_node);
self.storage
.insert(¤t_node.hash(), ¤t_node.as_ref().into())?;
}
let ancestor_depth = requested_leaf_node.common_path_length(actual_leaf_node);
let stale_depth = cmp::max(side_nodes.len(), ancestor_depth);
let placeholders_count = stale_depth - side_nodes.len();
let placeholders = iter::repeat(Node::create_placeholder()).take(placeholders_count);
for placeholder in placeholders {
current_node = Node::create_node_on_path(path, ¤t_node, &placeholder);
self.storage
.insert(¤t_node.hash(), ¤t_node.as_ref().into())?;
}
}
for side_node in side_nodes {
current_node = Node::create_node_on_path(path, ¤t_node, side_node);
self.storage
.insert(¤t_node.hash(), ¤t_node.as_ref().into())?;
}
self.set_root_node(current_node);
Ok(())
}
fn delete_with_path_set(
&mut self,
requested_leaf_node: &Node,
path_nodes: &[Node],
side_nodes: &[Node],
) -> Result<(), StorageError> {
for node in path_nodes {
self.storage.remove(&node.hash())?;
}
let path = requested_leaf_node.leaf_key();
let mut side_nodes_iter = side_nodes.iter();
let mut current_node = Node::create_placeholder();
if let Some(first_side_node) = side_nodes.first() {
if first_side_node.is_leaf() {
side_nodes_iter.next();
current_node = first_side_node.clone();
if let Some(side_node) = side_nodes_iter.find(|side_node| !side_node.is_placeholder()) {
current_node = Node::create_node_on_path(path, ¤t_node, side_node);
self.storage
.insert(¤t_node.hash(), ¤t_node.as_ref().into())?;
}
}
}
for side_node in side_nodes_iter {
current_node = Node::create_node_on_path(path, ¤t_node, side_node);
self.storage
.insert(¤t_node.hash(), ¤t_node.as_ref().into())?;
}
self.set_root_node(current_node);
Ok(())
}
}
// Unit tests for `MerkleTree`. Each test drives a fresh in-memory `StorageMap`
// and compares the resulting root, hex-encoded, against a fixed expected value.
// Keys are produced by `sum` over 4-byte big-endian integers.
// NOTE(review): the origin of the expected root vectors is not visible here;
// confirm it before regenerating any of them.
#[cfg(test)]
mod test {
use crate::{
common::{Bytes32, StorageMap},
sparse::{hash::sum, MerkleTree, MerkleTreeError, Primitive},
};
use fuel_storage::Mappable;
use hex;
// Minimal table mapping `Bytes32` keys to `Primitive` values, matching the
// bounds `MerkleTree` requires.
#[derive(Debug)]
struct TestTable;
impl Mappable for TestTable {
type Key = Self::OwnedKey;
type OwnedKey = Bytes32;
type Value = Self::OwnedValue;
type OwnedValue = Primitive;
}
// A freshly created tree reports the all-zero root.
#[test]
fn test_empty_root() {
let mut storage = StorageMap::<TestTable>::new();
let tree = MerkleTree::new(&mut storage);
let root = tree.root();
let expected_root = "0000000000000000000000000000000000000000000000000000000000000000";
assert_eq!(hex::encode(root), expected_root);
}
#[test]
fn test_update_1() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
let root = tree.root();
let expected_root = "39f36a7cb4dfb1b46f03d044265df6a491dffc1034121bc1071a34ddce9bb14b";
assert_eq!(hex::encode(root), expected_root);
}
#[test]
fn test_update_2() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x01"), b"DATA").unwrap();
let root = tree.root();
let expected_root = "8d0ae412ca9ca0afcb3217af8bcd5a673e798bd6fd1dfacad17711e883f494cb";
assert_eq!(hex::encode(root), expected_root);
}
#[test]
fn test_update_3() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x01"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x02"), b"DATA").unwrap();
let root = tree.root();
let expected_root = "52295e42d8de2505fdc0cc825ff9fead419cbcf540d8b30c7c4b9c9b94c268b7";
assert_eq!(hex::encode(root), expected_root);
}
// Keys 0..5. This root is also the expected result of several delete tests below.
#[test]
fn test_update_5() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x01"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x02"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x03"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x04"), b"DATA").unwrap();
let root = tree.root();
let expected_root = "108f731f2414e33ae57e584dc26bd276db07874436b2264ca6e520c658185c6b";
assert_eq!(hex::encode(root), expected_root);
}
#[test]
fn test_update_10() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
for i in 0_u32..10 {
let key = sum(i.to_be_bytes());
tree.update(&key, b"DATA").unwrap();
}
let root = tree.root();
let expected_root = "21ca4917e99da99a61de93deaf88c400d4c082991cb95779e444d43dd13e8849";
assert_eq!(hex::encode(root), expected_root);
}
#[test]
fn test_update_100() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
for i in 0_u32..100 {
let key = sum(i.to_be_bytes());
tree.update(&key, b"DATA").unwrap();
}
let root = tree.root();
let expected_root = "82bf747d455a55e2f7044a03536fc43f1f55d43b855e72c0110c986707a23e4d";
assert_eq!(hex::encode(root), expected_root);
}
// Re-inserting an identical key/value pair leaves the root unchanged (same
// expected root as `test_update_1`).
#[test]
fn test_update_with_repeated_inputs() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
let root = tree.root();
let expected_root = "39f36a7cb4dfb1b46f03d044265df6a491dffc1034121bc1071a34ddce9bb14b";
assert_eq!(hex::encode(root), expected_root);
}
// Overwriting a key with different data produces a root that differs from
// `test_update_1`.
#[test]
fn test_update_overwrite_key() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x00"), b"CHANGE").unwrap();
let root = tree.root();
let expected_root = "dd97174c80e5e5aa3a31c61b05e279c1495c8a07b2a08bca5dbc9fb9774f9457";
assert_eq!(hex::encode(root), expected_root);
}
// Three disjoint key ranges: 0..5, 10..15, 20..25.
#[test]
fn test_update_union() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
for i in 0_u32..5 {
let key = sum(i.to_be_bytes());
tree.update(&key, b"DATA").unwrap();
}
for i in 10_u32..15 {
let key = sum(i.to_be_bytes());
tree.update(&key, b"DATA").unwrap();
}
for i in 20_u32..25 {
let key = sum(i.to_be_bytes());
tree.update(&key, b"DATA").unwrap();
}
let root = tree.root();
let expected_root = "7e6643325042cfe0fc76626c043b97062af51c7e9fc56665f12b479034bce326";
assert_eq!(hex::encode(root), expected_root);
}
// Even keys 0, 2, 4, 6, 8.
#[test]
fn test_update_sparse_union() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x02"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x04"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x06"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x08"), b"DATA").unwrap();
let root = tree.root();
let expected_root = "e912e97abc67707b2e6027338292943b53d01a7fbd7b244674128c7e468dd696";
assert_eq!(hex::encode(root), expected_root);
}
// Writing empty data to an empty tree keeps the zero root.
#[test]
fn test_update_with_empty_data() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"").unwrap();
let root = tree.root();
let expected_root = "0000000000000000000000000000000000000000000000000000000000000000";
assert_eq!(hex::encode(root), expected_root);
}
// Writing empty data to an existing key removes it.
#[test]
fn test_update_with_empty_performs_delete() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x00"), b"").unwrap();
let root = tree.root();
let expected_root = "0000000000000000000000000000000000000000000000000000000000000000";
assert_eq!(hex::encode(root), expected_root);
}
#[test]
fn test_update_1_delete_1() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.delete(&sum(b"\x00\x00\x00\x00")).unwrap();
let root = tree.root();
let expected_root = "0000000000000000000000000000000000000000000000000000000000000000";
assert_eq!(hex::encode(root), expected_root);
}
// Deleting the second leaf restores the single-leaf root of `test_update_1`.
#[test]
fn test_update_2_delete_1() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x01"), b"DATA").unwrap();
tree.delete(&sum(b"\x00\x00\x00\x01")).unwrap();
let root = tree.root();
let expected_root = "39f36a7cb4dfb1b46f03d044265df6a491dffc1034121bc1071a34ddce9bb14b";
assert_eq!(hex::encode(root), expected_root);
}
// Inserting 0..10 then deleting 5..10 matches the root of `test_update_5`.
#[test]
fn test_update_10_delete_5() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
for i in 0_u32..10 {
let key = sum(i.to_be_bytes());
tree.update(&key, b"DATA").unwrap();
}
for i in 5_u32..10 {
let key = sum(i.to_be_bytes());
tree.delete(&key).unwrap();
}
let root = tree.root();
let expected_root = "108f731f2414e33ae57e584dc26bd276db07874436b2264ca6e520c658185c6b";
assert_eq!(hex::encode(root), expected_root);
}
// Deleting a key that was never inserted leaves the `test_update_5` root intact.
#[test]
fn test_delete_non_existent_key() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x01"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x02"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x03"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x04"), b"DATA").unwrap();
tree.delete(&sum(b"\x00\x00\x04\x00")).unwrap();
let root = tree.root();
let expected_root = "108f731f2414e33ae57e584dc26bd276db07874436b2264ca6e520c658185c6b";
assert_eq!(hex::encode(root), expected_root);
}
// The interleaved updates and deletes leave exactly the keys 0..5, 10..15 and
// 20..25, so the root matches `test_update_union`.
#[test]
fn test_interleaved_update_delete() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
for i in 0_u32..10 {
let key = sum(i.to_be_bytes());
tree.update(&key, b"DATA").unwrap();
}
for i in 5_u32..15 {
let key = sum(i.to_be_bytes());
tree.delete(&key).unwrap();
}
for i in 10_u32..20 {
let key = sum(i.to_be_bytes());
tree.update(&key, b"DATA").unwrap();
}
for i in 15_u32..25 {
let key = sum(i.to_be_bytes());
tree.delete(&key).unwrap();
}
for i in 20_u32..30 {
let key = sum(i.to_be_bytes());
tree.update(&key, b"DATA").unwrap();
}
for i in 25_u32..35 {
let key = sum(i.to_be_bytes());
tree.delete(&key).unwrap();
}
let root = tree.root();
let expected_root = "7e6643325042cfe0fc76626c043b97062af51c7e9fc56665f12b479034bce326";
assert_eq!(hex::encode(root), expected_root);
}
// Inserting 0..10 then deleting the odd keys leaves the even keys, matching
// `test_update_sparse_union`.
#[test]
fn test_delete_sparse_union() {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
for i in 0_u32..10 {
let key = sum(i.to_be_bytes());
tree.update(&key, b"DATA").unwrap();
}
for i in 0_u32..5 {
let key = sum((i * 2 + 1).to_be_bytes());
tree.delete(&key).unwrap();
}
let root = tree.root();
let expected_root = "e912e97abc67707b2e6027338292943b53d01a7fbd7b244674128c7e468dd696";
assert_eq!(hex::encode(root), expected_root);
}
// Loading a tree from storage and continuing to update it yields the same
// root as building the full tree in one session.
#[test]
fn test_load_returns_a_valid_tree() {
let (mut storage_to_load, root_to_load) = {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x01"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x02"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x03"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x04"), b"DATA").unwrap();
let root = tree.root();
(storage, root)
};
let expected_root = {
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x01"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x02"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x03"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x04"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x05"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x06"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x07"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x08"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x09"), b"DATA").unwrap();
tree.root()
};
let root = {
let mut tree = MerkleTree::load(&mut storage_to_load, &root_to_load).unwrap();
tree.update(&sum(b"\x00\x00\x00\x05"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x06"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x07"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x08"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x09"), b"DATA").unwrap();
tree.root()
};
assert_eq!(root, expected_root);
}
// Loading with a root key that has no storage entry reports `LoadError`.
#[test]
fn test_load_returns_a_load_error_if_the_storage_is_not_valid_for_the_root() {
let mut storage = StorageMap::<TestTable>::new();
{
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x01"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x02"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x03"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x04"), b"DATA").unwrap();
}
let root = &sum(b"\xff\xff\xff\xff");
let err = MerkleTree::load(&mut storage, root).expect_err("Expected load() to return Error; got Ok");
assert!(matches!(err, MerkleTreeError::LoadError(_)));
}
// Overwriting the stored root entry with an all-0xff primitive makes `load`
// report `DeserializeError`.
#[test]
fn test_load_returns_a_deserialize_error_if_the_storage_is_corrupted() {
use fuel_storage::StorageMutate;
let mut storage = StorageMap::<TestTable>::new();
let mut tree = MerkleTree::new(&mut storage);
tree.update(&sum(b"\x00\x00\x00\x00"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x01"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x02"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x03"), b"DATA").unwrap();
tree.update(&sum(b"\x00\x00\x00\x04"), b"DATA").unwrap();
let root = tree.root();
let primitive = (0xff, 0xff, [0xff; 32], [0xff; 32]);
storage.insert(&root, &primitive).unwrap();
let err = MerkleTree::load(&mut storage, &root).expect_err("Expected load() to return Error; got Ok");
assert!(matches!(err, MerkleTreeError::DeserializeError(_)));
}
}