1use super::{ProofToHashes, ProvingTrie, TrieError};
24use crate::{Decode, DispatchError, Encode};
25use binary_merkle_tree::{merkle_proof, merkle_root, MerkleProof};
26use codec::MaxEncodedLen;
27use sp_std::{collections::btree_map::BTreeMap, vec::Vec};
28
29pub struct BasicProvingTrie<Hashing, Key, Value>
32where
33 Hashing: sp_core::Hasher,
34{
35 db: BTreeMap<Key, Value>,
36 root: Hashing::Out,
37 _phantom: core::marker::PhantomData<(Key, Value)>,
38}
39
40impl<Hashing, Key, Value> ProvingTrie<Hashing, Key, Value> for BasicProvingTrie<Hashing, Key, Value>
41where
42 Hashing: sp_core::Hasher,
43 Hashing::Out: Encode + Decode,
44 Key: Encode + Decode + Ord,
45 Value: Encode + Decode + Clone,
46{
47 fn generate_for<I>(items: I) -> Result<Self, DispatchError>
49 where
50 I: IntoIterator<Item = (Key, Value)>,
51 {
52 let mut db = BTreeMap::default();
53 for (key, value) in items.into_iter() {
54 db.insert(key, value);
55 }
56 let root = merkle_root::<Hashing, _>(db.iter().map(|item| item.encode()));
57 Ok(Self { db, root, _phantom: Default::default() })
58 }
59
60 fn root(&self) -> &Hashing::Out {
62 &self.root
63 }
64
65 fn query(&self, key: &Key) -> Option<Value> {
68 self.db.get(&key).cloned()
69 }
70
71 fn create_proof(&self, key: &Key) -> Result<Vec<u8>, DispatchError> {
74 let mut encoded = Vec::with_capacity(self.db.len());
75 let mut found_index = None;
76
77 for (i, (k, v)) in self.db.iter().enumerate() {
79 if k == key {
81 found_index = Some(i);
82 }
83
84 encoded.push((k, v).encode());
85 }
86
87 let index = found_index.ok_or(TrieError::IncompleteDatabase)?;
88 let proof = merkle_proof::<Hashing, Vec<Vec<u8>>, Vec<u8>>(encoded, index as u32);
89 Ok(proof.encode())
90 }
91
92 fn verify_proof(
94 root: &Hashing::Out,
95 proof: &[u8],
96 key: &Key,
97 value: &Value,
98 ) -> Result<(), DispatchError> {
99 verify_proof::<Hashing, Key, Value>(root, proof, key, value)
100 }
101}
102
103impl<Hashing, Key, Value> ProofToHashes for BasicProvingTrie<Hashing, Key, Value>
104where
105 Hashing: sp_core::Hasher,
106 Hashing::Out: MaxEncodedLen + Decode,
107 Key: Decode,
108 Value: Decode,
109{
110 type Proof = [u8];
112 fn proof_to_hashes(proof: &[u8]) -> Result<u32, DispatchError> {
116 let decoded_proof: MerkleProof<Hashing::Out, Vec<u8>> =
117 Decode::decode(&mut &proof[..]).map_err(|_| TrieError::IncompleteProof)?;
118 let depth = decoded_proof.proof.len();
119 Ok(depth as u32)
120 }
121}
122
123pub fn verify_proof<Hashing, Key, Value>(
125 root: &Hashing::Out,
126 proof: &[u8],
127 key: &Key,
128 value: &Value,
129) -> Result<(), DispatchError>
130where
131 Hashing: sp_core::Hasher,
132 Hashing::Out: Decode,
133 Key: Encode + Decode,
134 Value: Encode + Decode,
135{
136 let decoded_proof: MerkleProof<Hashing::Out, Vec<u8>> =
137 Decode::decode(&mut &proof[..]).map_err(|_| TrieError::IncompleteProof)?;
138 if *root != decoded_proof.root {
139 return Err(TrieError::RootMismatch.into());
140 }
141
142 if (key, value).encode() != decoded_proof.leaf {
143 return Err(TrieError::ValueMismatch.into());
144 }
145
146 if binary_merkle_tree::verify_proof::<Hashing, _, _>(
147 &decoded_proof.root,
148 decoded_proof.proof,
149 decoded_proof.number_of_leaves,
150 decoded_proof.leaf_index,
151 &decoded_proof.leaf,
152 ) {
153 Ok(())
154 } else {
155 Err(TrieError::IncompleteProof.into())
156 }
157}
158
#[cfg(test)]
mod tests {
	use super::*;
	use crate::traits::BlakeTwo256;
	use sp_core::H256;
	use sp_std::collections::btree_map::BTreeMap;

	/// Trie mapping `u32` keys to `u128` balances, hashed with `BlakeTwo256`.
	type BalanceTrie = BasicProvingTrie<BlakeTwo256, u32, u128>;

	/// Root of a trie built from no items; used as a baseline.
	fn empty_root() -> H256 {
		*BalanceTrie::generate_for(Vec::new()).unwrap().root()
	}

	/// Builds a trie holding `n -> n` for every `n` in `0..100` and spot-checks it.
	fn create_balance_trie() -> BalanceTrie {
		let balances: BTreeMap<u32, u128> =
			(0..100u32).map(|account| (account, u128::from(account))).collect();

		let trie = BalanceTrie::generate_for(balances).unwrap();

		// A populated trie must not commit to the same root as an empty one.
		assert!(*trie.root() != empty_root());

		for account in [6u32, 9, 69].iter() {
			assert_eq!(trie.query(account), Some(u128::from(*account)));
		}

		trie
	}

	#[test]
	fn empty_trie_works() {
		let trie = BalanceTrie::generate_for(Vec::new()).unwrap();
		assert_eq!(*trie.root(), empty_root());
	}

	#[test]
	fn basic_end_to_end_single_value() {
		let trie = create_balance_trie();
		let root = *trie.root();
		let proof = trie.create_proof(&6u32).unwrap();

		for candidate in 0..200u32 {
			let outcome =
				verify_proof::<BlakeTwo256, _, _>(&root, &proof, &candidate, &u128::from(candidate));

			if candidate != 6 {
				// The proof is bound to key 6; every other key must be rejected.
				assert!(outcome.is_err());
				continue;
			}

			assert_eq!(outcome, Ok(()));
			// Right key, wrong value.
			assert_eq!(
				verify_proof::<BlakeTwo256, _, _>(
					&root,
					&proof,
					&candidate,
					&u128::from(candidate + 1)
				),
				Err(TrieError::ValueMismatch.into())
			);
		}
	}

	#[test]
	fn proof_fails_with_bad_data() {
		/// Verifies the fixed pair `(6, 6)` against the given root and proof bytes.
		fn check(root: &H256, proof: &[u8]) -> Result<(), DispatchError> {
			verify_proof::<BlakeTwo256, _, _>(root, proof, &6u32, &6u128)
		}

		let trie = create_balance_trie();
		let root = *trie.root();
		let proof = trie.create_proof(&6u32).unwrap();

		// The untampered proof verifies.
		assert_eq!(check(&root, &proof), Ok(()));
		// A different trusted root is rejected.
		assert_eq!(check(&Default::default(), &proof), Err(TrieError::RootMismatch.into()));
		// Bytes that do not decode are rejected.
		assert_eq!(check(&root, &[]), Err(TrieError::IncompleteProof.into()));
	}

	/// The encoded proof is a plain `MerkleProof` whose leaf is the encoded `(key, value)`.
	#[test]
	fn assert_structure_of_merkle_proof() {
		let trie = create_balance_trie();
		let proof = trie.create_proof(&6u32).unwrap();
		let decoded: MerkleProof<H256, Vec<u8>> = Decode::decode(&mut &proof[..]).unwrap();

		let expected = MerkleProof::<H256, Vec<u8>> {
			root: *trie.root(),
			proof: decoded.proof.clone(),
			number_of_leaves: 100,
			leaf_index: 6,
			leaf: (6u32, 6u128).encode(),
		};
		assert_eq!(expected, decoded);
	}

	#[test]
	fn proof_to_hashes() {
		// Trie sizes 1, 10, 100, ..., 1_000_000.
		let sizes = core::iter::successors(Some(1u32), |size| Some(size * 10))
			.take_while(|size| *size < 10_000_000);

		for size in sizes {
			let trie =
				BalanceTrie::generate_for((0..size).map(|key| (key, u128::from(key)))).unwrap();
			let proof = trie.create_proof(&0).unwrap();
			// Expected: one sibling hash per tree level, i.e. ceil(log2(size)).
			let expected_depth = (size as f64).log2().ceil() as u32;

			assert_eq!(BalanceTrie::proof_to_hashes(&proof).unwrap(), expected_depth);
		}
	}
}