fuel_vm/constraints/
reg_key.rs

1//! Utilities for accessing register values and proving at compile time that
2//! the register index is valid.
3//!
4//! This module also provides utilities for mutably accessing multiple registers.
5use core::ops::{
6    Deref,
7    DerefMut,
8};
9
10use fuel_asm::{
11    PanicReason,
12    RegId,
13    RegisterId,
14    Word,
15};
16
17use crate::consts::{
18    VM_REGISTER_COUNT,
19    VM_REGISTER_PROGRAM_COUNT,
20    VM_REGISTER_SYSTEM_COUNT,
21};
22
23#[cfg(test)]
24mod tests;
25
#[derive(Debug, PartialEq, Eq)]
/// Mutable reference to a single register value.
///
/// The register's index is carried in the type through the `INDEX` const
/// parameter. The wrapped reference is a private field, so outside of this
/// module a value is obtained through the constructors and accessors defined
/// here.
pub struct RegMut<'r, const INDEX: u8>(&'r mut Word);
29
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// Immutable reference to a single register value.
///
/// The register's index is carried in the type through the `INDEX` const
/// parameter. The type is `Copy`, so it can be passed around by value.
pub struct Reg<'r, const INDEX: u8>(&'r Word);
33
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
/// A key to a writable register that is within
/// the bounds of the writable registers.
///
/// The wrapped index is a private field and holds the absolute register
/// index; `WriteRegKey::new` only produces a key after the index passed
/// `is_register_writable`.
pub struct WriteRegKey(usize);
38
39impl WriteRegKey {
40    /// Create a new writable register key if the index is within the bounds
41    /// of the writable registers.
42    pub fn new(k: impl Into<usize>) -> Result<Self, PanicReason> {
43        let k = k.into();
44        is_register_writable(&k)?;
45        Ok(Self(k))
46    }
47
48    /// Translate this key from an absolute register index
49    /// to a program register index.
50    ///
51    /// This subtracts the number of system registers from the key.
52    #[allow(clippy::arithmetic_side_effects)] // Safety: checked in constructor
53    fn translate(self) -> usize {
54        self.0 - VM_REGISTER_SYSTEM_COUNT
55    }
56}
57
58/// Check that the register is above the system registers and below the total
59/// number of registers.
60pub(crate) fn is_register_writable(r: &RegisterId) -> Result<(), PanicReason> {
61    const W_USIZE: usize = RegId::WRITABLE.to_u8() as usize;
62    const RANGE: core::ops::Range<usize> = W_USIZE..(W_USIZE + VM_REGISTER_PROGRAM_COUNT);
63    if RANGE.contains(r) {
64        Ok(())
65    } else {
66        Err(PanicReason::ReservedRegisterNotWritable)
67    }
68}
69
70impl<'r, const INDEX: u8> RegMut<'r, INDEX> {
71    /// Create a new mutable register reference.
72    pub fn new(reg: &'r mut Word) -> Self {
73        Self(reg)
74    }
75}
76
77impl<'r, const INDEX: u8> Reg<'r, INDEX> {
78    /// Create a new immutable register reference.
79    pub fn new(reg: &'r Word) -> Self {
80        Self(reg)
81    }
82}
83
84impl<const INDEX: u8> Deref for Reg<'_, INDEX> {
85    type Target = Word;
86
87    fn deref(&self) -> &Self::Target {
88        self.0
89    }
90}
91
92impl<const INDEX: u8> Deref for RegMut<'_, INDEX> {
93    type Target = Word;
94
95    fn deref(&self) -> &Self::Target {
96        self.0
97    }
98}
99
100impl<const INDEX: u8> DerefMut for RegMut<'_, INDEX> {
101    fn deref_mut(&mut self) -> &mut Self::Target {
102        self.0
103    }
104}
105
106impl<'a, const INDEX: u8> From<RegMut<'a, INDEX>> for Reg<'a, INDEX> {
107    fn from(reg: RegMut<'a, INDEX>) -> Self {
108        Self(reg.0)
109    }
110}
111
112impl<'r, const INDEX: u8> RegMut<'r, INDEX> {
113    /// Re-borrow the register as an immutable reference.
114    pub fn as_ref(&self) -> Reg<INDEX> {
115        Reg(self.0)
116    }
117}
118
119impl<'r, const INDEX: u8> RegMut<'r, INDEX> {
120    /// Re-borrow the register as a mutable reference.
121    pub fn as_mut(&mut self) -> RegMut<INDEX> {
122        RegMut(self.0)
123    }
124}
125
// Generates the register index constants and the by-name accessor traits.
//
// Each input row has the form `KEY, getter [, getter_mut]`:
// * `KEY` becomes a `pub const KEY: u8` taken from `RegId::KEY`;
// * `getter` becomes a method of `GetReg` returning `Reg<'_, KEY>`;
// * `getter_mut`, when present, becomes a method of `GetRegMut` returning
//   `RegMut<'_, KEY>`. Rows without it get no mutable accessor.
//
// Both traits are implemented for `[Word; VM_REGISTER_COUNT]`.
macro_rules! impl_keys {
    ( $($key:ident, $getter:ident $(,$getter_mut:ident)?)* ) => {
        $(
            #[doc = "Register index key for use with Reg and RegMut."]
            pub const $key: u8 = RegId::$key.to_u8();
        )*

        #[doc = "Get register reference by name."]
        pub trait GetReg {
            $(
                #[doc = "Get register reference for this key."]
                fn $getter(&self) -> Reg<'_, $key>;
            )*
        }

        #[doc = "Get register mutable reference by name."]
        pub trait GetRegMut {
            $($(
                #[doc = "Get mutable register reference for this key."]
                fn $getter_mut(&mut self) -> RegMut<'_, $key>;
            )?)*
        }

        impl GetReg for [Word; VM_REGISTER_COUNT] {
            $(
                fn $getter(&self) -> Reg<'_, $key> {
                    Reg::new(&self[usize::from($key)])
                }
            )*
        }

        impl GetRegMut for [Word; VM_REGISTER_COUNT] {
            $($(
                fn $getter_mut(&mut self) -> RegMut<'_, $key> {
                    RegMut::new(&mut self[usize::from($key)])
                }
            )?)*
        }
    };
}
166
// One row per system register: `KEY, getter [, getter_mut]`.
// `ZERO` and `ONE` are listed without a mutable getter, so `GetRegMut`
// exposes no method for them.
impl_keys! {
    ZERO, zero
    ONE, one
    OF, of, of_mut
    PC, pc, pc_mut
    SSP, ssp, ssp_mut
    SP, sp, sp_mut
    FP, fp, fp_mut
    HP, hp, hp_mut
    ERR, err, err_mut
    GGAS, ggas, ggas_mut
    CGAS, cgas, cgas_mut
    BAL, bal, bal_mut
    IS, is, is_mut
    RET, ret, ret_mut
    RETL, retl, retl_mut
    FLAG, flag, flag_mut
}
185
/// The set of system registers split into
/// individual mutable references.
///
/// Every field is a `RegMut` tagged with the index constant of the register
/// it refers to, so each register can be borrowed and passed around
/// independently of the others.
pub(crate) struct SystemRegisters<'a> {
    pub(crate) zero: RegMut<'a, ZERO>,
    pub(crate) one: RegMut<'a, ONE>,
    pub(crate) of: RegMut<'a, OF>,
    pub(crate) pc: RegMut<'a, PC>,
    pub(crate) ssp: RegMut<'a, SSP>,
    pub(crate) sp: RegMut<'a, SP>,
    pub(crate) fp: RegMut<'a, FP>,
    pub(crate) hp: RegMut<'a, HP>,
    pub(crate) err: RegMut<'a, ERR>,
    pub(crate) ggas: RegMut<'a, GGAS>,
    pub(crate) cgas: RegMut<'a, CGAS>,
    pub(crate) bal: RegMut<'a, BAL>,
    pub(crate) is: RegMut<'a, IS>,
    pub(crate) ret: RegMut<'a, RET>,
    pub(crate) retl: RegMut<'a, RETL>,
    pub(crate) flag: RegMut<'a, FLAG>,
}
206
/// Same as `SystemRegisters` but with immutable references.
///
/// Every field is a `Reg` tagged with the index constant of the register it
/// refers to.
pub(crate) struct SystemRegistersRef<'a> {
    pub(crate) zero: Reg<'a, ZERO>,
    pub(crate) one: Reg<'a, ONE>,
    pub(crate) of: Reg<'a, OF>,
    pub(crate) pc: Reg<'a, PC>,
    pub(crate) ssp: Reg<'a, SSP>,
    pub(crate) sp: Reg<'a, SP>,
    pub(crate) fp: Reg<'a, FP>,
    pub(crate) hp: Reg<'a, HP>,
    pub(crate) err: Reg<'a, ERR>,
    pub(crate) ggas: Reg<'a, GGAS>,
    pub(crate) cgas: Reg<'a, CGAS>,
    pub(crate) bal: Reg<'a, BAL>,
    pub(crate) is: Reg<'a, IS>,
    pub(crate) ret: Reg<'a, RET>,
    pub(crate) retl: Reg<'a, RETL>,
    pub(crate) flag: Reg<'a, FLAG>,
}
226
/// The set of program registers split from the system registers.
///
/// Wraps a mutable borrow of a fixed-size array of
/// `VM_REGISTER_PROGRAM_COUNT` words.
pub(crate) struct ProgramRegisters<'a>(pub &'a mut [Word; VM_REGISTER_PROGRAM_COUNT]);
229
/// Same as `ProgramRegisters` but with immutable references.
///
/// Wraps a shared borrow of a fixed-size array of
/// `VM_REGISTER_PROGRAM_COUNT` words.
pub(crate) struct ProgramRegistersRef<'a>(pub &'a [Word; VM_REGISTER_PROGRAM_COUNT]);
232
233/// Split the registers into system and program registers.
234///
235/// This allows multiple mutable references to registers.
236pub(crate) fn split_registers(
237    registers: &mut [Word; VM_REGISTER_COUNT],
238) -> (SystemRegisters<'_>, ProgramRegisters<'_>) {
239    let [zero, one, of, pc, ssp, sp, fp, hp, err, ggas, cgas, bal, is, ret, retl, flag, rest @ ..] =
240        registers;
241    let r = SystemRegisters {
242        zero: RegMut(zero),
243        one: RegMut(one),
244        of: RegMut(of),
245        pc: RegMut(pc),
246        ssp: RegMut(ssp),
247        sp: RegMut(sp),
248        fp: RegMut(fp),
249        hp: RegMut(hp),
250        err: RegMut(err),
251        ggas: RegMut(ggas),
252        cgas: RegMut(cgas),
253        bal: RegMut(bal),
254        is: RegMut(is),
255        ret: RegMut(ret),
256        retl: RegMut(retl),
257        flag: RegMut(flag),
258    };
259    (r, ProgramRegisters(rest))
260}
261
262/// Copy the system and program registers into a single array.
263pub(crate) fn copy_registers(
264    system_registers: &SystemRegistersRef<'_>,
265    program_registers: &ProgramRegistersRef<'_>,
266) -> [Word; VM_REGISTER_COUNT] {
267    let mut out = [0u64; VM_REGISTER_COUNT];
268    out[..VM_REGISTER_SYSTEM_COUNT]
269        .copy_from_slice(&<[Word; VM_REGISTER_SYSTEM_COUNT]>::from(system_registers));
270    out[VM_REGISTER_SYSTEM_COUNT..].copy_from_slice(program_registers.0);
271    out
272}
273
274impl<'r> ProgramRegisters<'r> {
275    /// Get two mutable references to program registers.
276    /// Note they cannot be the same register.
277    pub fn get_mut_two(
278        &mut self,
279        a: WriteRegKey,
280        b: WriteRegKey,
281    ) -> Option<(&mut Word, &mut Word)> {
282        if a == b {
283            // Cannot mutably borrow the same register twice.
284            return None
285        }
286
287        // Order registers
288        let swap = a > b;
289        let (a, b) = if swap { (b, a) } else { (a, b) };
290
291        // Translate the absolute register indices to a program register indeces.
292        let a = a.translate();
293
294        // Subtract a + 1 because because we split the array at `a`.
295        let b = b
296            .translate()
297            .checked_sub(a.saturating_add(1))
298            .expect("Cannot underflow as the values are ordered");
299
300        // Split the array at the first register which is a.
301        let [i, rest @ ..] = &mut self.0[a..] else {
302            return None
303        };
304
305        // Translate the higher absolute register index to a program register index.
306        // Get the `b` register.
307        let j = &mut rest[b];
308
309        Some(if swap { (j, i) } else { (i, j) })
310    }
311}
312
313impl<'a> From<&'a SystemRegisters<'_>> for SystemRegistersRef<'a> {
314    fn from(value: &'a SystemRegisters<'_>) -> Self {
315        Self {
316            zero: Reg(value.zero.0),
317            one: Reg(value.one.0),
318            of: Reg(value.of.0),
319            pc: Reg(value.pc.0),
320            ssp: Reg(value.ssp.0),
321            sp: Reg(value.sp.0),
322            fp: Reg(value.fp.0),
323            hp: Reg(value.hp.0),
324            err: Reg(value.err.0),
325            ggas: Reg(value.ggas.0),
326            cgas: Reg(value.cgas.0),
327            bal: Reg(value.bal.0),
328            is: Reg(value.is.0),
329            ret: Reg(value.ret.0),
330            retl: Reg(value.retl.0),
331            flag: Reg(value.flag.0),
332        }
333    }
334}
335
336impl<'a> From<SystemRegisters<'a>> for SystemRegistersRef<'a> {
337    fn from(value: SystemRegisters<'a>) -> Self {
338        Self {
339            zero: Reg(value.zero.0),
340            one: Reg(value.one.0),
341            of: Reg(value.of.0),
342            pc: Reg(value.pc.0),
343            ssp: Reg(value.ssp.0),
344            sp: Reg(value.sp.0),
345            fp: Reg(value.fp.0),
346            hp: Reg(value.hp.0),
347            err: Reg(value.err.0),
348            ggas: Reg(value.ggas.0),
349            cgas: Reg(value.cgas.0),
350            bal: Reg(value.bal.0),
351            is: Reg(value.is.0),
352            ret: Reg(value.ret.0),
353            retl: Reg(value.retl.0),
354            flag: Reg(value.flag.0),
355        }
356    }
357}
358
359impl<'a> From<&'a ProgramRegisters<'_>> for ProgramRegistersRef<'a> {
360    fn from(value: &'a ProgramRegisters<'_>) -> Self {
361        Self(value.0)
362    }
363}
364
365impl<'a> From<ProgramRegisters<'a>> for ProgramRegistersRef<'a> {
366    fn from(value: ProgramRegisters<'a>) -> Self {
367        Self(value.0)
368    }
369}
370
371impl TryFrom<RegisterId> for WriteRegKey {
372    type Error = PanicReason;
373
374    fn try_from(r: RegisterId) -> Result<Self, Self::Error> {
375        Self::new(r)
376    }
377}
378
379impl core::ops::Index<WriteRegKey> for ProgramRegisters<'_> {
380    type Output = Word;
381
382    fn index(&self, index: WriteRegKey) -> &Self::Output {
383        &self.0[index.translate()]
384    }
385}
386
387impl core::ops::IndexMut<WriteRegKey> for ProgramRegisters<'_> {
388    fn index_mut(&mut self, index: WriteRegKey) -> &mut Self::Output {
389        &mut self.0[index.translate()]
390    }
391}
392
393impl<'a> From<&SystemRegistersRef<'a>> for [Word; VM_REGISTER_SYSTEM_COUNT] {
394    fn from(value: &SystemRegistersRef<'a>) -> Self {
395        let SystemRegistersRef {
396            zero,
397            one,
398            of,
399            pc,
400            ssp,
401            sp,
402            fp,
403            hp,
404            err,
405            ggas,
406            cgas,
407            bal,
408            is,
409            ret,
410            retl,
411            flag,
412        } = value;
413        [
414            *zero.0, *one.0, *of.0, *pc.0, *ssp.0, *sp.0, *fp.0, *hp.0, *err.0, *ggas.0,
415            *cgas.0, *bal.0, *is.0, *ret.0, *retl.0, *flag.0,
416        ]
417    }
418}
419
/// Selects one of the two parts of the program registers returned by
/// `ProgramRegisters::segment` and `ProgramRegisters::segment_mut`.
#[derive(Debug, Clone, Copy)]
pub(crate) enum ProgramRegistersSegment {
    /// Registers 16..40
    Low,
    /// Registers 40..64
    High,
}
427
428impl<'r> ProgramRegisters<'r> {
429    /// Returns the registers corresponding to the segment, always 24 elements.
430    pub(crate) fn segment(&self, segment: ProgramRegistersSegment) -> &[Word] {
431        match segment {
432            ProgramRegistersSegment::Low => &self.0[..24],
433            ProgramRegistersSegment::High => &self.0[24..],
434        }
435    }
436
437    /// Returns the registers corresponding to the segment, always 24 elements.
438    pub(crate) fn segment_mut(
439        &mut self,
440        segment: ProgramRegistersSegment,
441    ) -> &mut [Word] {
442        match segment {
443            ProgramRegistersSegment::Low => &mut self.0[..24],
444            ProgramRegistersSegment::High => &mut self.0[24..],
445        }
446    }
447}