1use alloc::{
8 sync::Arc,
9 vec::Vec,
10};
11use core::{
12 any::Any,
13 fmt::Debug,
14 hash::Hash,
15 ops::AddAssign,
16};
17use hashbrown::{
18 HashMap,
19 HashSet,
20};
21
22use fuel_asm::Word;
23use fuel_storage::{
24 Mappable,
25 StorageInspect,
26 StorageMutate,
27};
28use fuel_tx::{
29 Contract,
30 Receipt,
31};
32use fuel_types::AssetId;
33
34use crate::{
35 call::CallFrame,
36 context::Context,
37 storage::{
38 ContractsAssets,
39 ContractsRawCode,
40 ContractsState,
41 },
42};
43
44use super::{
45 balances::Balance,
46 receipts::ReceiptsCtx,
47 ExecutableTransaction,
48 Interpreter,
49 Memory,
50 PanicContext,
51};
52use crate::interpreter::memory::MemoryRollbackData;
53use storage::*;
54
55mod storage;
56
57#[cfg(test)]
58mod tests;
59
60#[derive(Debug, Clone)]
61pub struct Diff<T: VmStateCapture + Clone> {
70 changes: Vec<Change<T>>,
71}
72
73#[derive(Debug, Clone)]
74enum Change<T: VmStateCapture + Clone> {
75 Register(T::State<VecState<Word>>),
77 Memory(MemoryRollbackData),
79 Storage(T::State<StorageState>),
81 Frame(T::State<VecState<Option<CallFrame>>>),
83 Receipt(T::State<VecState<Option<Receipt>>>),
85 Balance(T::State<MapState<AssetId, Option<Balance>>>),
87 Context(T::State<Context>),
89 PanicContext(T::State<PanicContext>),
91 Txn(T::State<Arc<dyn AnyDebug>>),
93}
94
/// Object-safe combination of [`Any`] and [`Debug`], used to store a value
/// behind `dyn AnyDebug` while still being able to downcast it later.
pub trait AnyDebug: Debug + Any {
    /// Returns `self` as a `&dyn Any` so callers can downcast it.
    fn as_any_ref(&self) -> &dyn Any;
}
100
101impl<T> AnyDebug for T
102where
103 T: Any + Debug,
104{
105 fn as_any_ref(&self) -> &dyn Any {
106 self
107 }
108}
109
/// Selects how a captured piece of state `S` is stored inside a [`Diff`].
pub trait VmStateCapture {
    /// The wrapper type that holds a captured state of type `S`.
    type State<S: Debug + Clone>: Debug + Clone;
}
117
/// Capture mode that stores each change as a [`Delta`], i.e. both sides of
/// the change.
#[derive(Clone, Debug)]
pub struct Deltas;
123
124impl VmStateCapture for Deltas {
125 type State<S: core::fmt::Debug + Clone> = Delta<S>;
126}
127
/// Both sides of one change to a piece of state.
#[derive(Clone, Debug)]
pub struct Delta<S> {
    /// The value on the first side of the comparison. `rollback_to` fills
    /// this from the interpreter it is called on.
    from: S,
    /// The value on the second side of the comparison. `rollback_to` fills
    /// this from its `desired_state` argument.
    to: S,
}
136
/// Capture mode that stores a single value per change, wrapped in
/// [`Previous`].
#[derive(Clone, Debug)]
pub struct InitialVmState;
141
142impl VmStateCapture for InitialVmState {
143 type State<S: core::fmt::Debug + Clone> = Previous<S>;
144}
/// A single captured value: the one written back into the interpreter when a
/// change is inverted.
#[derive(Clone, Debug)]
pub struct Previous<S>(S);
148
/// The value of one slot of an indexed collection.
#[derive(Clone, Debug)]
struct VecState<T> {
    /// Position of the slot inside the collection.
    index: usize,
    /// Value stored at that position.
    value: T,
}
157
/// The value of one entry of a keyed collection.
#[derive(Clone, Debug)]
struct MapState<K: Hash, V: PartialEq> {
    /// Key of the entry.
    key: K,
    /// Value stored under that key.
    value: V,
}
170
171fn capture_buffer_state<'iter, I, T>(
172 a: I,
173 b: I,
174 change: fn(Delta<VecState<T>>) -> Change<Deltas>,
175) -> impl Iterator<Item = Change<Deltas>> + 'iter
176where
177 T: 'static + core::cmp::PartialEq + Clone,
178 I: Iterator<Item = &'iter T> + 'iter,
179{
180 a.enumerate()
181 .zip(b)
182 .filter(|&(a, b)| (a.1 != b))
183 .map(move |(a, b)| {
184 change(Delta {
185 from: VecState {
186 index: a.0,
187 value: a.1.clone(),
188 },
189 to: VecState {
190 index: a.0,
191 value: b.clone(),
192 },
193 })
194 })
195}
196
/// A function that wraps a [`Delta`] over state `S` into a [`Change`] of the
/// [`Deltas`] capture mode (in this file, the `Change` variant constructors).
type ChangeDeltaVariant<S> = fn(Delta<S>) -> Change<Deltas>;
198
199fn capture_map_state<'iter, K, V>(
200 a: &'iter HashMap<K, V>,
201 b: &'iter HashMap<K, V>,
202 change: ChangeDeltaVariant<MapState<K, Option<V>>>,
203) -> Vec<Change<Deltas>>
204where
205 K: 'static + PartialEq + Eq + Clone + Hash + Debug,
206 V: 'static + core::cmp::PartialEq + Clone + Debug,
207{
208 let a_keys: HashSet<_> = a.keys().collect();
209 let b_keys: HashSet<_> = b.keys().collect();
210 capture_map_state_inner(a, &a_keys, b, &b_keys)
211 .map(change)
212 .collect()
213}
214
215fn capture_map_state_inner<'iter, K, V>(
216 a: &'iter HashMap<K, V>,
217 a_keys: &'iter HashSet<&K>,
218 b: &'iter HashMap<K, V>,
219 b_keys: &'iter HashSet<&K>,
220) -> impl Iterator<Item = Delta<MapState<K, Option<V>>>> + 'iter
221where
222 K: 'static + PartialEq + Eq + Clone + Hash + Debug,
223 V: 'static + core::cmp::PartialEq + Clone + Debug,
224{
225 let a_diff = a_keys.difference(b_keys).map(|k| Delta {
226 from: MapState {
227 key: (*k).clone(),
228 value: Some(a[*k].clone()),
229 },
230 to: MapState {
231 key: (*k).clone(),
232 value: None,
233 },
234 });
235 let b_diff = b_keys.difference(a_keys).map(|k| Delta {
236 from: MapState {
237 key: (*k).clone(),
238 value: None,
239 },
240 to: MapState {
241 key: (*k).clone(),
242 value: Some(b[*k].clone()),
243 },
244 });
245 let intersection = a_keys.intersection(b_keys).filter_map(|k| {
246 let value_a = &a[*k];
247 let value_b = &b[*k];
248 (value_a != value_b).then(|| Delta {
249 from: MapState {
250 key: (*k).clone(),
251 value: Some(value_a.clone()),
252 },
253 to: MapState {
254 key: (*k).clone(),
255 value: Some(value_b.clone()),
256 },
257 })
258 });
259
260 a_diff.chain(b_diff).chain(intersection)
261}
262
263fn capture_vec_state<'iter, I, T>(
264 a: I,
265 b: I,
266 change: ChangeDeltaVariant<VecState<Option<T>>>,
267) -> impl Iterator<Item = Change<Deltas>> + 'iter
268where
269 T: 'static + core::cmp::PartialEq + Clone,
270 I: Iterator<Item = &'iter T> + 'iter,
271{
272 capture_vec_state_inner(a, b).map(move |(index, a, b)| {
273 change(Delta {
274 from: VecState { index, value: a },
275 to: VecState { index, value: b },
276 })
277 })
278}
/// Walks `a` and `b` in lockstep and yields `(index, value_in_a, value_in_b)`
/// for every index at which the two sequences differ.
///
/// The shorter input is treated as if it were padded with `None`, so an index
/// present in only one input is always reported. Iteration ends once both
/// inputs are exhausted.
fn capture_vec_state_inner<'iter, I, T>(
    a: I,
    b: I,
) -> impl Iterator<Item = (usize, Option<T>, Option<T>)> + 'iter
where
    T: 'static + PartialEq + Clone,
    I: Iterator<Item = &'iter T> + 'iter,
{
    // Fusing guarantees an exhausted input keeps reporting `None` without
    // being polled again.
    let mut left = a.fuse();
    let mut right = b.fuse();

    (0usize..)
        .map(move |index| (index, left.next(), right.next()))
        .take_while(|(_, left, right)| left.is_some() || right.is_some())
        .filter(|(_, left, right)| match (left, right) {
            // Present on both sides: only report real differences.
            (Some(left), Some(right)) => left != right,
            // Present on one side only: always a difference.
            _ => true,
        })
        .map(|(index, left, right)| (index, left.cloned(), right.cloned()))
}
295
296impl<M, S, Tx, Ecal> Interpreter<M, S, Tx, Ecal>
297where
298 M: Memory,
299{
300 pub fn rollback_to(&self, desired_state: &Self) -> Diff<Deltas>
304 where
305 Tx: PartialEq + Clone + Debug + 'static,
306 {
307 let mut diff = Diff {
308 changes: Vec::new(),
309 };
310 let registers = capture_buffer_state(
311 self.registers.iter(),
312 desired_state.registers.iter(),
313 Change::Register,
314 );
315 diff.changes.extend(registers);
316 let frames = capture_vec_state(
317 self.frames.iter(),
318 desired_state.frames.iter(),
319 Change::Frame,
320 );
321 diff.changes.extend(frames);
322 let receipts = capture_vec_state(
323 self.receipts.as_ref().iter(),
324 desired_state.receipts.as_ref().iter(),
325 Change::Receipt,
326 );
327 diff.changes.extend(receipts);
328 let balances = capture_map_state(
329 self.balances.as_ref(),
330 desired_state.balances.as_ref(),
331 Change::Balance,
332 );
333 diff.changes.extend(balances);
334
335 let memory_rollback_data =
336 self.memory().collect_rollback_data(desired_state.memory());
337
338 if let Some(memory_rollback_data) = memory_rollback_data {
339 diff.changes.push(Change::Memory(memory_rollback_data));
340 }
341
342 if self.context != desired_state.context {
343 diff.changes.push(Change::Context(Delta {
344 from: self.context.clone(),
345 to: desired_state.context.clone(),
346 }))
347 }
348
349 if self.panic_context != desired_state.panic_context {
350 diff.changes.push(Change::PanicContext(Delta {
351 from: self.panic_context.clone(),
352 to: desired_state.panic_context.clone(),
353 }))
354 }
355
356 if self.tx != desired_state.tx {
357 let from: Arc<dyn AnyDebug> = Arc::new(self.tx.clone());
358 let to: Arc<dyn AnyDebug> = Arc::new(desired_state.tx.clone());
359 diff.changes.push(Change::Txn(Delta { from, to }))
360 }
361
362 diff
363 }
364}
365
366impl<M, S, Tx, Ecal> Interpreter<M, S, Tx, Ecal>
367where
368 M: Memory,
369{
370 fn inverse_inner(&mut self, change: &Change<InitialVmState>)
371 where
372 Tx: Clone + 'static,
373 {
374 match change {
375 Change::Register(Previous(VecState { index, value })) => {
376 self.registers[*index] = *value
377 }
378 Change::Frame(Previous(value)) => invert_vec(&mut self.frames, value),
379 Change::Receipt(Previous(value)) => {
380 invert_receipts_ctx(&mut self.receipts, value)
381 }
382 Change::Balance(Previous(value)) => invert_map(self.balances.as_mut(), value),
383 Change::Memory(memory_rollback_data) => {
384 self.memory_mut().rollback(memory_rollback_data)
385 }
386 Change::Context(Previous(value)) => self.context = value.clone(),
387 Change::PanicContext(Previous(value)) => self.panic_context = value.clone(),
388 Change::Txn(Previous(tx)) => {
389 self.tx = AsRef::<dyn AnyDebug>::as_ref(tx)
390 .as_any_ref()
391 .downcast_ref::<Tx>()
392 .unwrap()
393 .clone();
394 }
395 Change::Storage(_) => (),
396 }
397 }
398}
399
400fn invert_vec<T: Clone>(vector: &mut Vec<T>, value: &VecState<Option<T>>) {
401 use core::cmp::Ordering;
402 match (&value, value.index.cmp(&vector.len())) {
403 (
404 VecState {
405 index,
406 value: Some(value),
407 },
408 Ordering::Equal | Ordering::Greater,
409 ) => {
410 vector.resize((*index).saturating_add(1), value.clone());
411 vector[*index] = value.clone();
412 }
413 (
414 VecState {
415 index,
416 value: Some(value),
417 },
418 Ordering::Less,
419 ) => vector[*index] = value.clone(),
420 (VecState { value: None, .. }, Ordering::Equal | Ordering::Greater) => (),
421 (VecState { index, value: None }, Ordering::Less) => vector.truncate(*index),
422 }
423}
424
425fn invert_map<K: Hash + PartialEq + Eq + Clone, V: Clone + PartialEq>(
426 map: &mut HashMap<K, V>,
427 value: &MapState<K, Option<V>>,
428) {
429 match value {
430 MapState {
431 key,
432 value: Some(value),
433 } => {
434 map.insert(key.clone(), value.clone());
435 }
436 MapState { key, value: None } => {
437 map.remove(key);
438 }
439 }
440}
441
442fn invert_receipts_ctx(ctx: &mut ReceiptsCtx, value: &VecState<Option<Receipt>>) {
443 let mut ctx_mut = ctx.lock();
444 invert_vec(ctx_mut.receipts_mut(), value);
445}
446
447impl<M, S, Tx, Ecal> PartialEq for Interpreter<M, S, Tx, Ecal>
448where
449 M: Memory,
450 Tx: PartialEq,
451{
452 fn eq(&self, other: &Self) -> bool {
454 self.registers == other.registers
455 && self.memory.as_ref() == other.memory.as_ref()
456 && self.frames == other.frames
457 && self.receipts == other.receipts
458 && self.tx == other.tx
459 && self.initial_balances == other.initial_balances
460 && self.context == other.context
461 && self.balances == other.balances
462 && self.interpreter_params == other.interpreter_params
463 && self.panic_context == other.panic_context
464 }
465}
466
467impl From<Diff<Deltas>> for Diff<InitialVmState> {
468 fn from(d: Diff<Deltas>) -> Self {
469 Self {
470 changes: d
471 .changes
472 .into_iter()
473 .map(|c| match c {
474 Change::Register(v) => Change::Register(v.into()),
475 Change::Memory(v) => Change::Memory(v),
476 Change::Storage(v) => Change::Storage(v.into()),
477 Change::Frame(v) => Change::Frame(v.into()),
478 Change::Receipt(v) => Change::Receipt(v.into()),
479 Change::Balance(v) => Change::Balance(v.into()),
480 Change::Context(v) => Change::Context(v.into()),
481 Change::PanicContext(v) => Change::PanicContext(v.into()),
482 Change::Txn(v) => Change::Txn(v.into()),
483 })
484 .collect(),
485 }
486 }
487}
488
489impl<T> From<Delta<T>> for Previous<T> {
490 fn from(d: Delta<T>) -> Self {
491 Self(d.to)
492 }
493}
494
495impl<T: VmStateCapture + Clone> AddAssign for Diff<T> {
496 fn add_assign(&mut self, rhs: Self) {
497 self.changes.extend(rhs.changes);
498 }
499}