1use crate::{NodeCodec, StorageProof};
24use codec::Encode;
25use hash_db::Hasher;
26use parking_lot::{Mutex, MutexGuard};
27use std::{
28 collections::{HashMap, HashSet},
29 marker::PhantomData,
30 mem,
31 ops::DerefMut,
32 sync::{
33 atomic::{AtomicUsize, Ordering},
34 Arc,
35 },
36};
37use trie_db::{RecordedForKey, TrieAccess};
38
/// `target` used by every `tracing::trace!` call in this module.
const LOG_TARGET: &str = "trie-recorder";
40
41#[derive(Default)]
43struct Transaction<H> {
44 recorded_keys: HashMap<H, HashMap<Arc<[u8]>, Option<RecordedForKey>>>,
49 accessed_nodes: HashSet<H>,
53}
54
55struct RecorderInner<H> {
57 recorded_keys: HashMap<H, HashMap<Arc<[u8]>, RecordedForKey>>,
61
62 transactions: Vec<Transaction<H>>,
64
65 accessed_nodes: HashMap<H, Vec<u8>>,
69}
70
71impl<H> Default for RecorderInner<H> {
72 fn default() -> Self {
73 Self {
74 recorded_keys: Default::default(),
75 accessed_nodes: Default::default(),
76 transactions: Vec::new(),
77 }
78 }
79}
80
81pub struct Recorder<H: Hasher> {
87 inner: Arc<Mutex<RecorderInner<H::Out>>>,
88 encoded_size_estimation: Arc<AtomicUsize>,
92}
93
94impl<H: Hasher> Default for Recorder<H> {
95 fn default() -> Self {
96 Self { inner: Default::default(), encoded_size_estimation: Arc::new(0.into()) }
97 }
98}
99
100impl<H: Hasher> Clone for Recorder<H> {
101 fn clone(&self) -> Self {
102 Self {
103 inner: self.inner.clone(),
104 encoded_size_estimation: self.encoded_size_estimation.clone(),
105 }
106 }
107}
108
109impl<H: Hasher> Recorder<H> {
110 pub fn recorded_keys(&self) -> HashMap<<H as Hasher>::Out, HashMap<Arc<[u8]>, RecordedForKey>> {
114 self.inner.lock().recorded_keys.clone()
115 }
116
117 #[inline]
124 pub fn as_trie_recorder(&self, storage_root: H::Out) -> TrieRecorder<'_, H> {
125 TrieRecorder::<H> {
126 inner: self.inner.lock(),
127 storage_root,
128 encoded_size_estimation: self.encoded_size_estimation.clone(),
129 _phantom: PhantomData,
130 }
131 }
132
133 pub fn drain_storage_proof(self) -> StorageProof {
142 let mut recorder = mem::take(&mut *self.inner.lock());
143 StorageProof::new(recorder.accessed_nodes.drain().map(|(_, v)| v))
144 }
145
146 pub fn to_storage_proof(&self) -> StorageProof {
153 let recorder = self.inner.lock();
154 StorageProof::new(recorder.accessed_nodes.values().cloned())
155 }
156
157 pub fn estimate_encoded_size(&self) -> usize {
162 self.encoded_size_estimation.load(Ordering::Relaxed)
163 }
164
165 pub fn reset(&self) {
169 mem::take(&mut *self.inner.lock());
170 self.encoded_size_estimation.store(0, Ordering::Relaxed);
171 }
172
173 pub fn start_transaction(&self) {
175 let mut inner = self.inner.lock();
176 inner.transactions.push(Default::default());
177 }
178
179 pub fn rollback_transaction(&self) -> Result<(), ()> {
183 let mut inner = self.inner.lock();
184
185 let mut new_encoded_size_estimation = self.encoded_size_estimation.load(Ordering::Relaxed);
188 let transaction = inner.transactions.pop().ok_or(())?;
189
190 transaction.accessed_nodes.into_iter().for_each(|n| {
191 if let Some(old) = inner.accessed_nodes.remove(&n) {
192 new_encoded_size_estimation =
193 new_encoded_size_estimation.saturating_sub(old.encoded_size());
194 }
195 });
196
197 transaction.recorded_keys.into_iter().for_each(|(storage_root, keys)| {
198 keys.into_iter().for_each(|(k, old_state)| {
199 if let Some(state) = old_state {
200 inner.recorded_keys.entry(storage_root).or_default().insert(k, state);
201 } else {
202 inner.recorded_keys.entry(storage_root).or_default().remove(&k);
203 }
204 });
205 });
206
207 self.encoded_size_estimation
208 .store(new_encoded_size_estimation, Ordering::Relaxed);
209
210 Ok(())
211 }
212
213 pub fn commit_transaction(&self) -> Result<(), ()> {
217 let mut inner = self.inner.lock();
218
219 let transaction = inner.transactions.pop().ok_or(())?;
220
221 if let Some(parent_transaction) = inner.transactions.last_mut() {
222 parent_transaction.accessed_nodes.extend(transaction.accessed_nodes);
223
224 transaction.recorded_keys.into_iter().for_each(|(storage_root, keys)| {
225 keys.into_iter().for_each(|(k, old_state)| {
226 parent_transaction
227 .recorded_keys
228 .entry(storage_root)
229 .or_default()
230 .entry(k)
231 .or_insert(old_state);
232 })
233 });
234 }
235
236 Ok(())
237 }
238}
239
240impl<H: Hasher> crate::ProofSizeProvider for Recorder<H> {
241 fn estimate_encoded_size(&self) -> usize {
242 Recorder::estimate_encoded_size(self)
243 }
244}
245
246pub struct TrieRecorder<'a, H: Hasher> {
248 inner: MutexGuard<'a, RecorderInner<H::Out>>,
249 storage_root: H::Out,
250 encoded_size_estimation: Arc<AtomicUsize>,
251 _phantom: PhantomData<H>,
252}
253
254impl<H: Hasher> crate::TrieRecorderProvider<H> for Recorder<H> {
255 type Recorder<'a>
256 = TrieRecorder<'a, H>
257 where
258 H: 'a;
259
260 fn drain_storage_proof(self) -> Option<StorageProof> {
261 Some(Recorder::drain_storage_proof(self))
262 }
263
264 fn as_trie_recorder(&self, storage_root: H::Out) -> Self::Recorder<'_> {
265 Recorder::as_trie_recorder(&self, storage_root)
266 }
267}
268
269impl<'a, H: Hasher> TrieRecorder<'a, H> {
270 fn update_recorded_keys(&mut self, full_key: &[u8], access: RecordedForKey) {
272 let inner = self.inner.deref_mut();
273
274 let entry =
275 inner.recorded_keys.entry(self.storage_root).or_default().entry(full_key.into());
276
277 let key = entry.key().clone();
278
279 let entry = if matches!(access, RecordedForKey::Value) {
282 entry.and_modify(|e| {
283 if let Some(tx) = inner.transactions.last_mut() {
284 tx.recorded_keys
286 .entry(self.storage_root)
287 .or_default()
288 .entry(key.clone())
289 .or_insert(Some(*e));
290 }
291
292 *e = access;
293 })
294 } else {
295 entry
296 };
297
298 entry.or_insert_with(|| {
299 if let Some(tx) = inner.transactions.last_mut() {
300 tx.recorded_keys
302 .entry(self.storage_root)
303 .or_default()
304 .entry(key)
305 .or_insert(None);
306 }
307
308 access
309 });
310 }
311}
312
313impl<'a, H: Hasher> trie_db::TrieRecorder<H::Out> for TrieRecorder<'a, H> {
314 fn record(&mut self, access: TrieAccess<H::Out>) {
315 let mut encoded_size_update = 0;
316
317 match access {
318 TrieAccess::NodeOwned { hash, node_owned } => {
319 tracing::trace!(
320 target: LOG_TARGET,
321 hash = ?hash,
322 "Recording node",
323 );
324
325 let inner = self.inner.deref_mut();
326
327 inner.accessed_nodes.entry(hash).or_insert_with(|| {
328 let node = node_owned.to_encoded::<NodeCodec<H>>();
329
330 encoded_size_update += node.encoded_size();
331
332 if let Some(tx) = inner.transactions.last_mut() {
333 tx.accessed_nodes.insert(hash);
334 }
335
336 node
337 });
338 },
339 TrieAccess::EncodedNode { hash, encoded_node } => {
340 tracing::trace!(
341 target: LOG_TARGET,
342 hash = ?hash,
343 "Recording node",
344 );
345
346 let inner = self.inner.deref_mut();
347
348 inner.accessed_nodes.entry(hash).or_insert_with(|| {
349 let node = encoded_node.into_owned();
350
351 encoded_size_update += node.encoded_size();
352
353 if let Some(tx) = inner.transactions.last_mut() {
354 tx.accessed_nodes.insert(hash);
355 }
356
357 node
358 });
359 },
360 TrieAccess::Value { hash, value, full_key } => {
361 tracing::trace!(
362 target: LOG_TARGET,
363 hash = ?hash,
364 key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
365 "Recording value",
366 );
367
368 let inner = self.inner.deref_mut();
369
370 inner.accessed_nodes.entry(hash).or_insert_with(|| {
371 let value = value.into_owned();
372
373 encoded_size_update += value.encoded_size();
374
375 if let Some(tx) = inner.transactions.last_mut() {
376 tx.accessed_nodes.insert(hash);
377 }
378
379 value
380 });
381
382 self.update_recorded_keys(full_key, RecordedForKey::Value);
383 },
384 TrieAccess::Hash { full_key } => {
385 tracing::trace!(
386 target: LOG_TARGET,
387 key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
388 "Recorded hash access for key",
389 );
390
391 self.update_recorded_keys(full_key, RecordedForKey::Hash);
394 },
395 TrieAccess::NonExisting { full_key } => {
396 tracing::trace!(
397 target: LOG_TARGET,
398 key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
399 "Recorded non-existing value access for key",
400 );
401
402 self.update_recorded_keys(full_key, RecordedForKey::Value);
406 },
407 TrieAccess::InlineValue { full_key } => {
408 tracing::trace!(
409 target: LOG_TARGET,
410 key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
411 "Recorded inline value access for key",
412 );
413
414 self.update_recorded_keys(full_key, RecordedForKey::Value);
417 },
418 };
419
420 self.encoded_size_estimation.fetch_add(encoded_size_update, Ordering::Relaxed);
421 }
422
423 fn trie_nodes_recorded_for_key(&self, key: &[u8]) -> RecordedForKey {
424 self.inner
425 .recorded_keys
426 .get(&self.storage_root)
427 .and_then(|k| k.get(key).copied())
428 .unwrap_or(RecordedForKey::None)
429 }
430}
431
#[cfg(test)]
mod tests {
	use super::*;
	use crate::tests::create_trie;
	use trie_db::{Trie, TrieDBBuilder, TrieRecorder};

	type MemoryDB = crate::MemoryDB<sp_core::Blake2Hasher>;
	type Layout = crate::LayoutV1<sp_core::Blake2Hasher>;
	type Recorder = super::Recorder<sp_core::Blake2Hasher>;

	/// Four distinct keys, each with a 64 byte value.
	const TEST_DATA: &[(&[u8], &[u8])] =
		&[(b"key1", &[1; 64]), (b"key2", &[2; 64]), (b"key3", &[3; 64]), (b"key4", &[4; 64])];

	/// Reading a key through a recorder yields a proof from which the key can be read again.
	#[test]
	fn recorder_works() {
		let (db, root) = create_trie::<Layout>(TEST_DATA);

		let recorder = Recorder::default();

		{
			let mut trie_recorder = recorder.as_trie_recorder(root);
			let trie = TrieDBBuilder::<Layout>::new(&db, &root)
				.with_recorder(&mut trie_recorder)
				.build();
			assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
		}

		let storage_proof = recorder.drain_storage_proof();
		let memory_db: MemoryDB = storage_proof.into_memory_db();

		// The proof alone must be enough to read the key again.
		let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();
		assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
	}

	/// Snapshot of a recorder's observable state, used to compare states across transactions.
	#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
	struct RecorderStats {
		accessed_nodes: usize,
		recorded_keys: usize,
		estimated_size: usize,
	}

	impl RecorderStats {
		/// Takes a snapshot of `recorder` (locks its shared state while counting).
		fn extract(recorder: &Recorder) -> Self {
			let inner = recorder.inner.lock();

			// Number of recorded keys over all storage roots.
			let recorded_keys =
				inner.recorded_keys.iter().flat_map(|(_, keys)| keys.keys()).count();

			Self {
				recorded_keys,
				accessed_nodes: inner.accessed_nodes.len(),
				estimated_size: recorder.estimate_encoded_size(),
			}
		}
	}

	/// Rolling back nested transactions one by one restores the previous states step by step.
	#[test]
	fn recorder_transactions_rollback_work() {
		let (db, root) = create_trie::<Layout>(TEST_DATA);

		let recorder = Recorder::default();
		// `stats[0]` is the empty state; `stats[i + 1]` is the state after reading key `i`.
		let mut stats = vec![RecorderStats::default()];

		// Read key `i` inside its own, nested transaction.
		for i in 0..4 {
			recorder.start_transaction();
			{
				let mut trie_recorder = recorder.as_trie_recorder(root);
				let trie = TrieDBBuilder::<Layout>::new(&db, &root)
					.with_recorder(&mut trie_recorder)
					.build();

				assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
			}
			stats.push(RecorderStats::extract(&recorder));
		}

		assert_eq!(4, recorder.inner.lock().transactions.len());

		// After `i` rollbacks the state must equal `stats[4 - i]`, and the proof must only
		// contain the keys read in the transactions that are still open.
		for i in 0..5 {
			assert_eq!(stats[4 - i], RecorderStats::extract(&recorder));

			let storage_proof = recorder.to_storage_proof();
			let memory_db: MemoryDB = storage_proof.into_memory_db();

			let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();

			for a in 0..4 {
				if a < 4 - i {
					assert_eq!(TEST_DATA[a].1.to_vec(), trie.get(TEST_DATA[a].0).unwrap().unwrap());
				} else {
					// Data of rolled back reads is missing from the proof.
					assert!(trie.get(TEST_DATA[a].0).is_err());
				}
			}

			if i < 4 {
				recorder.rollback_transaction().unwrap();
			}
		}

		assert_eq!(0, recorder.inner.lock().transactions.len());
	}

	/// Committing all nested transactions keeps everything that was recorded.
	#[test]
	fn recorder_transactions_commit_work() {
		let (db, root) = create_trie::<Layout>(TEST_DATA);

		let recorder = Recorder::default();

		// Read key `i` inside its own, nested transaction.
		for i in 0..4 {
			recorder.start_transaction();
			{
				let mut trie_recorder = recorder.as_trie_recorder(root);
				let trie = TrieDBBuilder::<Layout>::new(&db, &root)
					.with_recorder(&mut trie_recorder)
					.build();

				assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
			}
		}

		let stats = RecorderStats::extract(&recorder);
		assert_eq!(4, recorder.inner.lock().transactions.len());

		for _ in 0..4 {
			recorder.commit_transaction().unwrap();
		}
		assert_eq!(0, recorder.inner.lock().transactions.len());
		// Committing must not change what was recorded.
		assert_eq!(stats, RecorderStats::extract(&recorder));

		let storage_proof = recorder.to_storage_proof();
		let memory_db: MemoryDB = storage_proof.into_memory_db();

		let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();

		// Every key read in a committed transaction can be read from the proof.
		for i in 0..4 {
			assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
		}
	}

	/// Mixing rollbacks and commits keeps exactly the data of the committed transactions.
	#[test]
	fn recorder_transactions_commit_and_rollback_work() {
		let (db, root) = create_trie::<Layout>(TEST_DATA);

		let recorder = Recorder::default();

		// Read keys 0 and 1, each in its own nested transaction.
		for i in 0..2 {
			recorder.start_transaction();
			{
				let mut trie_recorder = recorder.as_trie_recorder(root);
				let trie = TrieDBBuilder::<Layout>::new(&db, &root)
					.with_recorder(&mut trie_recorder)
					.build();

				assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
			}
		}

		// Undo the read of key 1.
		recorder.rollback_transaction().unwrap();

		// Read keys 2 and 3, each in its own nested transaction.
		for i in 2..4 {
			recorder.start_transaction();
			{
				let mut trie_recorder = recorder.as_trie_recorder(root);
				let trie = TrieDBBuilder::<Layout>::new(&db, &root)
					.with_recorder(&mut trie_recorder)
					.build();

				assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
			}
		}

		// Undo the read of key 3.
		recorder.rollback_transaction().unwrap();

		// The transactions of key 0 and key 2 are still open.
		assert_eq!(2, recorder.inner.lock().transactions.len());

		for _ in 0..2 {
			recorder.commit_transaction().unwrap();
		}

		assert_eq!(0, recorder.inner.lock().transactions.len());

		let storage_proof = recorder.to_storage_proof();
		let memory_db: MemoryDB = storage_proof.into_memory_db();

		let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();

		// Only keys 0 and 2 (the committed reads) are present in the proof.
		for i in 0..4 {
			if i % 2 == 0 {
				assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
			} else {
				assert!(trie.get(TEST_DATA[i].0).is_err());
			}
		}
	}

	/// `trie_nodes_recorded_for_key` follows hash/value accesses and is restored on rollback.
	#[test]
	fn recorder_transaction_accessed_keys_works() {
		let key = TEST_DATA[0].0;
		let (db, root) = create_trie::<Layout>(TEST_DATA);

		let recorder = Recorder::default();

		// Nothing recorded yet.
		{
			let trie_recorder = recorder.as_trie_recorder(root);
			assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::None));
		}

		// A hash access records `Hash`.
		recorder.start_transaction();
		{
			let mut trie_recorder = recorder.as_trie_recorder(root);
			let trie = TrieDBBuilder::<Layout>::new(&db, &root)
				.with_recorder(&mut trie_recorder)
				.build();

			assert_eq!(
				sp_core::Blake2Hasher::hash(TEST_DATA[0].1),
				trie.get_hash(TEST_DATA[0].0).unwrap().unwrap()
			);
			assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::Hash));
		}

		// A value access in a nested transaction upgrades it to `Value`.
		recorder.start_transaction();
		{
			let mut trie_recorder = recorder.as_trie_recorder(root);
			let trie = TrieDBBuilder::<Layout>::new(&db, &root)
				.with_recorder(&mut trie_recorder)
				.build();

			assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
			assert!(matches!(
				trie_recorder.trie_nodes_recorded_for_key(key),
				RecordedForKey::Value,
			));
		}

		// Rolling back the value access restores `Hash`...
		recorder.rollback_transaction().unwrap();
		{
			let trie_recorder = recorder.as_trie_recorder(root);
			assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::Hash));
		}

		// ...and rolling back the hash access restores `None`.
		recorder.rollback_transaction().unwrap();
		{
			let trie_recorder = recorder.as_trie_recorder(root);
			assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::None));
		}

		// Now the other order: a value access first...
		recorder.start_transaction();
		{
			let mut trie_recorder = recorder.as_trie_recorder(root);
			let trie = TrieDBBuilder::<Layout>::new(&db, &root)
				.with_recorder(&mut trie_recorder)
				.build();

			assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
			assert!(matches!(
				trie_recorder.trie_nodes_recorded_for_key(key),
				RecordedForKey::Value,
			));
		}

		// ...then a hash access, which must not downgrade `Value`.
		recorder.start_transaction();
		{
			let mut trie_recorder = recorder.as_trie_recorder(root);
			let trie = TrieDBBuilder::<Layout>::new(&db, &root)
				.with_recorder(&mut trie_recorder)
				.build();

			assert_eq!(
				sp_core::Blake2Hasher::hash(TEST_DATA[0].1),
				trie.get_hash(TEST_DATA[0].0).unwrap().unwrap()
			);
			assert!(matches!(
				trie_recorder.trie_nodes_recorded_for_key(key),
				RecordedForKey::Value
			));
		}

		// Rolling back the hash-only transaction keeps `Value`...
		recorder.rollback_transaction().unwrap();
		{
			let trie_recorder = recorder.as_trie_recorder(root);
			assert!(matches!(
				trie_recorder.trie_nodes_recorded_for_key(key),
				RecordedForKey::Value
			));
		}

		// ...and rolling back the value access restores `None`.
		recorder.rollback_transaction().unwrap();
		{
			let trie_recorder = recorder.as_trie_recorder(root);
			assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::None));
		}
	}
}