polkavm_common/
utils.rs

1#![allow(unsafe_code)]
2
3use core::mem::MaybeUninit;
4use core::ops::{Deref, Range};
5
6use crate::cast::cast;
7use crate::program::Reg;
8#[cfg(feature = "alloc")]
9use alloc::{borrow::Cow, sync::Arc, vec::Vec};
10
#[cfg(feature = "alloc")]
#[derive(Clone)]
// Keeps the storage backing an `ArcBytes` alive; cloning only bumps a refcount.
enum LifetimeObject {
    // Static or empty data; there is nothing to keep alive.
    None,
    // The bytes live inside this shared `Arc<[u8]>` buffer.
    Arc {
        _obj: Arc<[u8]>,
    },
    // Some other owner (e.g. a `Vec<u8>`); held only so it is dropped last.
    #[allow(dyn_drop)]
    Other {
        _obj: Arc<dyn Drop>,
    },
}
23
/// An immutable, cheaply cloneable byte buffer: similar to `Arc<[u8]>`, but it
/// can also point at `'static` data or at a subslice of a shared buffer.
#[derive(Clone)]
pub struct ArcBytes {
    // Start of the byte range; always non-null (dangling when `length == 0`).
    pointer: core::ptr::NonNull<u8>,
    // Number of bytes reachable through `pointer`.
    length: usize,

    // Keeps the pointed-to storage alive; without `alloc` only `'static` data is possible.
    #[cfg(feature = "alloc")]
    lifetime: LifetimeObject,
}
32
33impl core::fmt::Debug for ArcBytes {
34    fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {
35        fmt.debug_struct("ArcBytes").field("data", &self.deref()).finish()
36    }
37}
38
39impl Default for ArcBytes {
40    fn default() -> Self {
41        ArcBytes::empty()
42    }
43}
44
// These manual impls are needed because the raw `NonNull<u8>` field suppresses
// the automatic derivation of `Send`/`Sync`.

// SAFETY: It's always safe to send `ArcBytes` to another thread due to atomic refcounting.
unsafe impl Send for ArcBytes {}

// SAFETY: It's always safe to access `ArcBytes` from multiple threads due to atomic refcounting.
unsafe impl Sync for ArcBytes {}
50
51impl ArcBytes {
52    pub const fn empty() -> Self {
53        ArcBytes {
54            pointer: core::ptr::NonNull::dangling(),
55            length: 0,
56
57            #[cfg(feature = "alloc")]
58            lifetime: LifetimeObject::None,
59        }
60    }
61
62    pub const fn from_static(bytes: &'static [u8]) -> Self {
63        ArcBytes {
64            // SAFETY: `bytes` is always a valid slice, so its pointer is also always non-null and valid.
65            pointer: unsafe { core::ptr::NonNull::new_unchecked(bytes.as_ptr().cast_mut()) },
66            length: bytes.len(),
67
68            #[cfg(feature = "alloc")]
69            lifetime: LifetimeObject::None,
70        }
71    }
72
73    pub(crate) fn subslice(&self, subrange: Range<usize>) -> Self {
74        if subrange.start == subrange.end {
75            return Self::empty();
76        }
77
78        assert!(subrange.end >= subrange.start);
79        let length = subrange.end - subrange.start;
80        assert!(length <= self.length);
81
82        ArcBytes {
83            // TODO: Use `NonNull::add` once we migrate to Rust 1.80+.
84            // SAFETY: We've checked that the new subslice is valid with `assert`s.
85            pointer: unsafe { core::ptr::NonNull::new_unchecked(self.pointer.as_ptr().add(subrange.start)) },
86            length,
87
88            #[cfg(feature = "alloc")]
89            lifetime: self.lifetime.clone(),
90        }
91    }
92}
93
94impl Eq for ArcBytes {}
95
96impl PartialEq for ArcBytes {
97    fn eq(&self, rhs: &ArcBytes) -> bool {
98        self.deref() == rhs.deref()
99    }
100}
101
// Dereferences to the underlying bytes; this is the primary accessor and
// everything else (`AsRef`, `PartialEq`, `Debug`) goes through it.
impl Deref for ArcBytes {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        // SAFETY: `pointer` is always non-null and `length` is always valid.
        unsafe { core::slice::from_raw_parts(self.pointer.as_ptr(), self.length) }
    }
}
110
111impl AsRef<[u8]> for ArcBytes {
112    fn as_ref(&self) -> &[u8] {
113        self.deref()
114    }
115}
116
#[cfg(feature = "alloc")]
impl<'a> From<&'a [u8]> for ArcBytes {
    /// Copies the bytes into a fresh reference-counted allocation.
    fn from(data: &'a [u8]) -> Self {
        Arc::<[u8]>::from(data).into()
    }
}
124
125#[cfg(not(feature = "alloc"))]
126impl<'a> From<&'static [u8]> for ArcBytes {
127    fn from(data: &'static [u8]) -> Self {
128        ArcBytes::from_static(data)
129    }
130}
131
#[cfg(feature = "alloc")]
impl From<Vec<u8>> for ArcBytes {
    /// Takes ownership of the vector; its heap buffer is kept alive by an `Arc`.
    fn from(data: Vec<u8>) -> Self {
        let length = data.len();
        let pointer = core::ptr::NonNull::new(data.as_ptr().cast_mut()).unwrap();
        // Moving the `Vec` into the `Arc` does not move its heap buffer,
        // so `pointer` remains valid.
        ArcBytes {
            pointer,
            length,
            lifetime: LifetimeObject::Other { _obj: Arc::new(data) },
        }
    }
}
142
#[cfg(feature = "alloc")]
impl From<Arc<[u8]>> for ArcBytes {
    /// Shares the `Arc`'s buffer; cloning the result only bumps the refcount.
    fn from(data: Arc<[u8]>) -> Self {
        let length = data.len();
        // `as_ptr` resolves through `Deref` to the slice's pointer, which is never null.
        let pointer = core::ptr::NonNull::new(data.as_ptr().cast_mut()).unwrap();
        ArcBytes {
            pointer,
            length,
            lifetime: LifetimeObject::Arc { _obj: data },
        }
    }
}
153
#[cfg(feature = "alloc")]
impl<'a> From<Cow<'a, [u8]>> for ArcBytes {
    /// Converts either variant: borrowed data is copied, owned data is moved.
    fn from(cow: Cow<'a, [u8]>) -> Self {
        match cow {
            Cow::Borrowed(slice) => ArcBytes::from(slice),
            Cow::Owned(vec) => ArcBytes::from(vec),
        }
    }
}
163
macro_rules! define_align_to_next_page {
    ($name:ident, $type:ty) => {
        /// Aligns the `value` to the next `page_size`, or returns the `value` as-is if it's already aligned.
        ///
        /// Returns `None` if rounding up would overflow. Panics if `page_size` is not a power of two.
        #[inline]
        pub const fn $name(page_size: $type, value: $type) -> Option<$type> {
            assert!(
                page_size != 0 && (page_size & (page_size - 1)) == 0,
                "page size is not a power of two"
            );
            // Parenthesized explicitly: `&` binds looser than `-` but tighter than `==`,
            // so the original unparenthesized form parsed correctly yet was needlessly
            // hard to read (clippy::precedence).
            if (value & (page_size - 1)) == 0 {
                Some(value)
            } else if value <= <$type>::MAX - page_size {
                // The guard above guarantees this addition cannot overflow.
                Some((value + page_size) & !(page_size - 1))
            } else {
                None
            }
        }
    };
}

define_align_to_next_page!(align_to_next_page_u32, u32);
define_align_to_next_page!(align_to_next_page_u64, u64);
define_align_to_next_page!(align_to_next_page_usize, usize);
189
#[test]
fn test_align_to_next_page() {
    // Already-aligned values (including zero) come back unchanged.
    for aligned in [0, 4096] {
        assert_eq!(align_to_next_page_u64(4096, aligned), Some(aligned));
    }
    // Anything in between rounds up to the next page boundary.
    assert_eq!(align_to_next_page_u64(4096, 1), Some(4096));
    assert_eq!(align_to_next_page_u64(4096, 4095), Some(4096));
    assert_eq!(align_to_next_page_u64(4096, 4097), Some(8192));
    // The largest page-aligned u64 works, but one past it would overflow.
    let max_aligned = u64::MAX - 4095;
    assert_eq!(align_to_next_page_u64(4096, max_aligned), Some(max_aligned));
    assert_eq!(align_to_next_page_u64(4096, max_aligned + 1), None);
}
201
/// A helper trait for abstracting over mutable byte buffers which may be uninitialized.
pub trait AsUninitSliceMut {
    /// Returns the buffer as a slice of possibly-uninitialized bytes.
    fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<u8>];
}
205
impl AsUninitSliceMut for [MaybeUninit<u8>] {
    fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        // Already the right type; no conversion needed.
        self
    }
}
211
impl AsUninitSliceMut for [u8] {
    fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        #[allow(unsafe_code)]
        // SAFETY: `MaybeUninit<T>` is guaranteed to have the same representation as `T`,
        //         so casting `[T]` into `[MaybeUninit<T>]` is safe.
        unsafe {
            core::slice::from_raw_parts_mut(self.as_mut_ptr().cast(), self.len())
        }
    }
}
222
impl<const N: usize> AsUninitSliceMut for MaybeUninit<[u8; N]> {
    fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        #[allow(unsafe_code)]
        // SAFETY: `MaybeUninit<T>` is guaranteed to have the same representation as `T`,
        //         so casting `[T; N]` into `[MaybeUninit<T>]` is safe.
        unsafe {
            core::slice::from_raw_parts_mut(self.as_mut_ptr().cast(), N)
        }
    }
}
233
234impl<const N: usize> AsUninitSliceMut for [u8; N] {
235    fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<u8>] {
236        let slice: &mut [u8] = &mut self[..];
237        slice.as_uninit_slice_mut()
238    }
239}
240
// Copied from `MaybeUninit::slice_assume_init_mut`.
// TODO: Remove this once this API is stabilized.
/// Converts a slice of `MaybeUninit<T>` into a slice of `T`.
///
/// # Safety
///
/// Every element of `slice` must have been fully initialized before calling this.
#[allow(clippy::missing_safety_doc)]
#[allow(unsafe_code)]
pub unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    // SAFETY: The caller is responsible for making sure the `slice` was properly initialized.
    unsafe { &mut *(slice as *mut [MaybeUninit<T>] as *mut [T]) }
}
249
250#[allow(unsafe_code)]
251pub fn byte_slice_init<'dst>(dst: &'dst mut [MaybeUninit<u8>], src: &[u8]) -> &'dst mut [u8] {
252    assert_eq!(dst.len(), src.len());
253
254    let length = dst.len();
255    let src_ptr: *const u8 = src.as_ptr();
256    let dst_ptr: *mut u8 = dst.as_mut_ptr().cast::<u8>();
257
258    // SAFETY: Both pointers are valid and are guaranteed to point to a region of memory
259    // at least `length` bytes big.
260    unsafe {
261        core::ptr::copy_nonoverlapping(src_ptr, dst_ptr, length);
262    }
263
264    // SAFETY: We've just initialized this slice.
265    unsafe { slice_assume_init_mut(dst) }
266}
267
/// Parses a 32-bit immediate from assembly text.
///
/// Accepts hexadecimal (`0x…`), binary (`0b…`), and decimal forms; decimal
/// values may be written either as signed (`-1`) or unsigned (`4294967295`),
/// both producing the same 32-bit pattern.
pub fn parse_imm(text: &str) -> Option<i32> {
    let text = text.trim();
    if let Some(hex) = text.strip_prefix("0x") {
        u32::from_str_radix(hex, 16).ok().map(|value| value as i32)
    } else if let Some(bin) = text.strip_prefix("0b") {
        u32::from_str_radix(bin, 2).ok().map(|value| value as i32)
    } else {
        // Try a signed parse first, then fall back to reinterpreting an
        // unsigned value's bits.
        text.parse::<i32>()
            .ok()
            .or_else(|| text.parse::<u32>().ok().map(|value| value as i32))
    }
}
286
/// An immediate parsed by `parse_immediate`, classified by how many bits are
/// needed to represent it.
#[derive(Debug, PartialEq)]
pub enum ParsedImmediate {
    /// Representable in 32 bits (the full value is recovered by sign-extension).
    U32(u32),
    /// Requires all 64 bits.
    U64(u64),
}
292
293impl TryFrom<ParsedImmediate> for u32 {
294    type Error = &'static str;
295
296    fn try_from(value: ParsedImmediate) -> Result<Self, Self::Error> {
297        match value {
298            ParsedImmediate::U32(v) => Ok(v),
299            ParsedImmediate::U64(_) => Err("value is too large for u32"),
300        }
301    }
302}
303
impl From<ParsedImmediate> for u64 {
    /// Widens to 64 bits; `U32` values are sign-extended rather than zero-extended.
    fn from(value: ParsedImmediate) -> Self {
        match value {
            ParsedImmediate::U32(v) => cast(v).to_u64_sign_extend(),
            ParsedImmediate::U64(v) => v,
        }
    }
}
312
/// Parses an immediate from assembly text, choosing between a 32-bit and a 64-bit form.
///
/// An explicit `i64 ` prefix forces the 64-bit variant. Otherwise the value is
/// returned as `U32` whenever sign-extending its low 32 bits reproduces the full
/// 64-bit value, and as `U64` only when it does not.
pub fn parse_immediate(text: &str) -> Option<ParsedImmediate> {
    let text = text.trim();

    let (force_imm64, text) = if let Some(text) = text.strip_prefix("i64 ") {
        (true, text.trim())
    } else {
        (false, text)
    };

    // Parse the raw 64-bit value; decimal input may be negative.
    let value = if let Some(text) = text.strip_prefix("0x") {
        u64::from_str_radix(text, 16).ok()?
    } else if let Some(text) = text.strip_prefix("0b") {
        u64::from_str_radix(text, 2).ok()?
    } else {
        match text.parse::<i64>() {
            Ok(signed) => signed as u64,
            Err(_) => return None,
        }
    };

    if force_imm64 {
        return Some(ParsedImmediate::U64(value));
    }

    // The first comparison is a fast path for small non-negative values; the
    // second checks whether truncating to 32 bits and sign-extending round-trips.
    if value < 0x7fffffff || cast(cast(value).truncate_to_u32()).to_u64_sign_extend() == value {
        Some(ParsedImmediate::U32(cast(value).truncate_to_u32()))
    } else {
        Some(ParsedImmediate::U64(value))
    }
}
343
#[test]
fn test_parse_immediate() {
    use ParsedImmediate::{U32, U64};

    let cases = [
        // Values whose low 32 bits do not sign-extend back must stay 64-bit.
        ("0xffffffff", U64(0xffffffff)),
        ("0x80000075", U64(0x80000075)),
        ("0xdeadbeef", U64(0xdeadbeef)),
        ("0xffffffff00000000", U64(0xffffffff00000000)),
        // Values which round-trip through sign-extension shrink to 32-bit.
        ("0xffffffff87654321", U32(0x87654321)),
        ("0x1234", U32(0x1234)),
        ("0x12345678", U32(0x12345678)),
        ("-1", U32(0xffffffff)),
        ("-2", U32(0xfffffffe)),
        // Too wide for 32 bits.
        ("0x1234567890", U64(0x1234567890)),
        // An explicit prefix always forces the 64-bit variant.
        ("i64 0xffffffff", U64(0xffffffff)),
    ];

    for (text, expected) in cases {
        assert_eq!(parse_immediate(text), Some(expected), "input: {}", text);
    }

    // Conversion traits: sign-extending widen, and a fallible narrow.
    assert_eq!(parse_immediate("0xf000000e").map(Into::into), Some(0xf000000eu64));
    assert_eq!(parse_immediate("0x80000075").and_then(|imm| imm.try_into().ok()), None::<u32>);
}
365
366pub fn parse_reg(text: &str) -> Option<Reg> {
367    const REG_NAME_ALT: [&str; 13] = ["r0", "r1", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", "r11", "r12"];
368
369    let text = text.trim();
370    for (reg, name_alt) in Reg::ALL.into_iter().zip(REG_NAME_ALT) {
371        if text == reg.name() || text == name_alt {
372            return Some(reg);
373        }
374    }
375
376    None
377}
378
#[test]
fn test_arc_bytes() {
    // The empty value dereferences to an empty slice, and its (dangling) pointer is stable.
    assert_eq!(&*ArcBytes::empty(), b"");
    assert_eq!(ArcBytes::empty().as_ptr(), ArcBytes::empty().as_ptr());

    #[cfg(feature = "alloc")]
    #[allow(clippy::redundant_clone)]
    {
        // Clones and subslices share the same backing buffer instead of copying it.
        let ab = ArcBytes::from(alloc::vec![1, 2, 3, 4]);
        assert_eq!(ab.as_ptr(), ab.as_ptr());
        assert_eq!(ab.clone().as_ptr(), ab.as_ptr());
        assert_eq!(&*ab, &[1, 2, 3, 4]);
        assert_eq!(&*ab.subslice(0..4), &[1, 2, 3, 4]);
        assert_eq!(&*ab.subslice(0..3), &[1, 2, 3]);
        assert_eq!(&*ab.subslice(1..4), &[2, 3, 4]);

        // Wrapping an `Arc` bumps its refcount (so `get_mut` fails while the
        // wrapper is alive), and dropping the wrapper releases the reference.
        let mut arc = Arc::<[u8]>::from(alloc::vec![1, 2, 3, 4]);
        assert!(Arc::get_mut(&mut arc).is_some());
        let ab2 = ArcBytes::from(Arc::clone(&arc));
        assert!(Arc::get_mut(&mut arc).is_none());
        assert_eq!(ab2.as_ptr(), ab2.as_ptr());
        assert_eq!(ab2.clone().as_ptr(), ab2.as_ptr());
        core::mem::drop(ab2);
        assert!(Arc::get_mut(&mut arc).is_some());
    }
}