1use super::{ProofToHashes, ProvingTrie, TrieError};
28use crate::{Decode, DispatchError, Encode};
29use codec::MaxEncodedLen;
30use sp_std::vec::Vec;
31use sp_trie::{
32 trie_types::{TrieDBBuilder, TrieDBMutBuilderV1},
33 LayoutV1, MemoryDB, Trie, TrieMut,
34};
35
36pub struct BasicProvingTrie<Hashing, Key, Value>
40where
41 Hashing: sp_core::Hasher,
42{
43 db: MemoryDB<Hashing>,
44 root: Hashing::Out,
45 _phantom: core::marker::PhantomData<(Key, Value)>,
46}
47
48impl<Hashing, Key, Value> BasicProvingTrie<Hashing, Key, Value>
49where
50 Hashing: sp_core::Hasher,
51 Key: Encode,
52{
53 pub fn create_multi_proof(&self, keys: &[Key]) -> Result<Vec<u8>, DispatchError> {
59 sp_trie::generate_trie_proof::<LayoutV1<Hashing>, _, _, _>(
60 &self.db,
61 self.root,
62 &keys.into_iter().map(|k| k.encode()).collect::<Vec<Vec<u8>>>(),
63 )
64 .map_err(|err| TrieError::from(*err).into())
65 .map(|structured_proof| structured_proof.encode())
66 }
67}
68
69impl<Hashing, Key, Value> ProvingTrie<Hashing, Key, Value> for BasicProvingTrie<Hashing, Key, Value>
70where
71 Hashing: sp_core::Hasher,
72 Key: Encode,
73 Value: Encode + Decode,
74{
75 fn generate_for<I>(items: I) -> Result<Self, DispatchError>
77 where
78 I: IntoIterator<Item = (Key, Value)>,
79 {
80 let mut db = MemoryDB::default();
81 let mut root = Default::default();
82
83 {
84 let mut trie = TrieDBMutBuilderV1::new(&mut db, &mut root).build();
85 for (key, value) in items.into_iter() {
86 key.using_encoded(|k| value.using_encoded(|v| trie.insert(k, v)))
87 .map_err(|_| "failed to insert into trie")?;
88 }
89 }
90
91 Ok(Self { db, root, _phantom: Default::default() })
92 }
93
94 fn root(&self) -> &Hashing::Out {
96 &self.root
97 }
98
99 fn query(&self, key: &Key) -> Option<Value> {
102 let trie = TrieDBBuilder::new(&self.db, &self.root).build();
103 key.using_encoded(|s| trie.get(s))
104 .ok()?
105 .and_then(|raw| Value::decode(&mut &*raw).ok())
106 }
107
108 fn create_proof(&self, key: &Key) -> Result<Vec<u8>, DispatchError> {
110 sp_trie::generate_trie_proof::<LayoutV1<Hashing>, _, _, _>(
111 &self.db,
112 self.root,
113 &[key.encode()],
114 )
115 .map_err(|err| TrieError::from(*err).into())
116 .map(|structured_proof| structured_proof.encode())
117 }
118
119 fn verify_proof(
121 root: &Hashing::Out,
122 proof: &[u8],
123 key: &Key,
124 value: &Value,
125 ) -> Result<(), DispatchError> {
126 verify_proof::<Hashing, Key, Value>(root, proof, key, value)
127 }
128}
129
130impl<Hashing, Key, Value> ProofToHashes for BasicProvingTrie<Hashing, Key, Value>
131where
132 Hashing: sp_core::Hasher,
133 Hashing::Out: MaxEncodedLen,
134{
135 type Proof = [u8];
137 fn proof_to_hashes(proof: &[u8]) -> Result<u32, DispatchError> {
140 use codec::DecodeLength;
141 let depth =
142 <Vec<Vec<u8>> as DecodeLength>::len(proof).map_err(|_| TrieError::DecodeError)?;
143 Ok(depth as u32)
144 }
145}
146
147pub fn verify_proof<Hashing, Key, Value>(
149 root: &Hashing::Out,
150 proof: &[u8],
151 key: &Key,
152 value: &Value,
153) -> Result<(), DispatchError>
154where
155 Hashing: sp_core::Hasher,
156 Key: Encode,
157 Value: Encode,
158{
159 let structured_proof: Vec<Vec<u8>> =
160 Decode::decode(&mut &proof[..]).map_err(|_| TrieError::DecodeError)?;
161 sp_trie::verify_trie_proof::<LayoutV1<Hashing>, _, _, _>(
162 &root,
163 &structured_proof,
164 &[(key.encode(), Some(value.encode()))],
165 )
166 .map_err(|err| TrieError::from(err).into())
167}
168
169pub fn verify_multi_proof<Hashing, Key, Value>(
171 root: &Hashing::Out,
172 proof: &[u8],
173 items: &[(Key, Value)],
174) -> Result<(), DispatchError>
175where
176 Hashing: sp_core::Hasher,
177 Key: Encode,
178 Value: Encode,
179{
180 let structured_proof: Vec<Vec<u8>> =
181 Decode::decode(&mut &proof[..]).map_err(|_| TrieError::DecodeError)?;
182 let items_encoded = items
183 .into_iter()
184 .map(|(key, value)| (key.encode(), Some(value.encode())))
185 .collect::<Vec<(Vec<u8>, Option<Vec<u8>>)>>();
186
187 sp_trie::verify_trie_proof::<LayoutV1<Hashing>, _, _, _>(
188 &root,
189 &structured_proof,
190 &items_encoded,
191 )
192 .map_err(|err| TrieError::from(err).into())
193}
194
#[cfg(test)]
mod tests {
	use super::*;
	use crate::traits::BlakeTwo256;
	use sp_core::H256;
	use sp_std::collections::btree_map::BTreeMap;

	/// Trie mapping `u32` keys to `u128` balances.
	type BalanceTrie = BasicProvingTrie<BlakeTwo256, u32, u128>;

	/// Root of a trie that holds no entries.
	fn empty_root() -> H256 {
		sp_trie::empty_trie_root::<LayoutV1<BlakeTwo256>>()
	}

	/// Build a trie holding `n -> n` for every `n` in `0..100`, sanity-checking a few lookups.
	fn create_balance_trie() -> BalanceTrie {
		let balances: BTreeMap<u32, u128> =
			(0..100u32).map(|who| (who, u128::from(who))).collect();

		let trie = BalanceTrie::generate_for(balances).unwrap();

		// Inserting data must move the root away from the empty root.
		assert_ne!(*trie.root(), empty_root());

		for present in [6u32, 9, 69] {
			assert_eq!(trie.query(&present), Some(u128::from(present)));
		}
		assert_eq!(trie.query(&6969u32), None);

		trie
	}

	#[test]
	fn empty_trie_works() {
		let trie = BalanceTrie::generate_for(Vec::new()).unwrap();
		assert_eq!(*trie.root(), empty_root());
	}

	#[test]
	fn basic_end_to_end_single_value() {
		let trie = create_balance_trie();
		let root = *trie.root();
		let proof = trie.create_proof(&6u32).unwrap();

		for who in 0..200u32 {
			let verify = |balance: u128| {
				verify_proof::<BlakeTwo256, _, _>(&root, &proof, &who, &balance)
			};
			if who == 6 {
				// The proven pair verifies; the same key with a different value does not.
				assert_eq!(verify(u128::from(who)), Ok(()));
				assert_eq!(verify(u128::from(who + 1)), Err(TrieError::RootMismatch.into()));
			} else {
				// Keys the proof was not created for never verify.
				assert!(verify(u128::from(who)).is_err());
			}
		}
	}

	#[test]
	fn basic_end_to_end_multi() {
		let trie = create_balance_trie();
		let root = *trie.root();
		let items = [(6u32, 6u128), (9u32, 9u128), (69u32, 69u128)];

		let proof = trie.create_multi_proof(&[6u32, 9u32, 69u32]).unwrap();

		assert_eq!(verify_multi_proof::<BlakeTwo256, _, _>(&root, &proof, &items), Ok(()));
	}

	#[test]
	fn proof_fails_with_bad_data() {
		let trie = create_balance_trie();
		let root = *trie.root();
		let proof = trie.create_proof(&6u32).unwrap();

		// Baseline: the untouched proof verifies.
		assert_eq!(verify_proof::<BlakeTwo256, _, _>(&root, &proof, &6u32, &6u128), Ok(()));

		// Same proof against the wrong root.
		let wrong_root: H256 = Default::default();
		assert_eq!(
			verify_proof::<BlakeTwo256, _, _>(&wrong_root, &proof, &6u32, &6u128),
			Err(TrieError::RootMismatch.into())
		);

		// A valid proof, but for a different key.
		let proof_for_other_key = trie.create_proof(&99u32).unwrap();
		assert_eq!(
			verify_proof::<BlakeTwo256, _, _>(&root, &proof_for_other_key, &6u32, &6u128),
			Err(TrieError::ExtraneousHashReference.into())
		);
	}

	#[test]
	fn proof_to_hashes() {
		// ceil(log16(x)).
		let ceil_log16 = |x: u32| -> u32 { ((x as f64).ln() / 16_f64.ln()).ceil() as u32 };

		// Trie sizes 1, 10, 100, ..., 1_000_000 (every power of ten below 10_000_000).
		for exponent in 0..7u32 {
			let size = 10u32.pow(exponent);
			let trie = BalanceTrie::generate_for((0..size).map(|n| (n, u128::from(n)))).unwrap();
			let proof = trie.create_proof(&0).unwrap();
			let hashes = BalanceTrie::proof_to_hashes(&proof).unwrap();

			assert_eq!(hashes, ceil_log16(size).max(1));
		}
	}
}