solana_bpf_loader_program/syscalls/mem_ops.rs

use {
    super::*,
    solana_program_runtime::invoke_context::SerializedAccountMetadata,
    solana_sbpf::{error::EbpfError, memory_region::MemoryRegion},
    std::slice,
};

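// Charge compute units for a memory operation over `n` bytes: the greater of the
// flat `mem_op_base_cost` and the per-byte cost (`n / cpi_bytes_per_unit`).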
fn mem_op_consume(invoke_context: &mut InvokeContext, n: u64) -> Result<(), Error> {
    let compute_budget = invoke_context.get_compute_budget();
    let cost = compute_budget.mem_op_base_cost.max(
        n.checked_div(compute_budget.cpi_bytes_per_unit)
            .unwrap_or(u64::MAX),
    );
    consume_compute_meter(invoke_context, cost)
}

declare_builtin_function!(
    /// memcpy
    SyscallMemcpy,
    fn rust(
        invoke_context: &mut InvokeContext,
        dst_addr: u64,
        src_addr: u64,
        n: u64,
        _arg4: u64,
        _arg5: u64,
        memory_mapping: &mut MemoryMapping,
    ) -> Result<u64, Error> {
        mem_op_consume(invoke_context, n)?;

        if !is_nonoverlapping(src_addr, n, dst_addr, n) {
            return Err(SyscallError::CopyOverlapping.into());
        }

        // host addresses can overlap so we always invoke memmove
        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
    }
);

declare_builtin_function!(
    /// memmove
    SyscallMemmove,
    fn rust(
        invoke_context: &mut InvokeContext,
        dst_addr: u64,
        src_addr: u64,
        n: u64,
        _arg4: u64,
        _arg5: u64,
        memory_mapping: &mut MemoryMapping,
    ) -> Result<u64, Error> {
        mem_op_consume(invoke_context, n)?;

        memmove(invoke_context, dst_addr, src_addr, n, memory_mapping)
    }
);

declare_builtin_function!(
    /// memcmp
    SyscallMemcmp,
    fn rust(
        invoke_context: &mut InvokeContext,
        s1_addr: u64,
        s2_addr: u64,
        n: u64,
        cmp_result_addr: u64,
        _arg5: u64,
        memory_mapping: &mut MemoryMapping,
    ) -> Result<u64, Error> {
        mem_op_consume(invoke_context, n)?;

        if invoke_context
            .get_feature_set()
            .is_active(&solana_feature_set::bpf_account_data_direct_mapping::id())
        {
            let cmp_result = translate_type_mut::<i32>(
                memory_mapping,
                cmp_result_addr,
                invoke_context.get_check_aligned(),
            )?;
            let syscall_context = invoke_context.get_syscall_context()?;

            *cmp_result = memcmp_non_contiguous(
                s1_addr,
                s2_addr,
                n,
                &syscall_context.accounts_metadata,
                memory_mapping,
                invoke_context.get_check_aligned(),
            )?;
        } else {
            let s1 = translate_slice::<u8>(
                memory_mapping,
                s1_addr,
                n,
                invoke_context.get_check_aligned(),
            )?;
            let s2 = translate_slice::<u8>(
                memory_mapping,
                s2_addr,
                n,
                invoke_context.get_check_aligned(),
            )?;
            let cmp_result = translate_type_mut::<i32>(
                memory_mapping,
                cmp_result_addr,
                invoke_context.get_check_aligned(),
            )?;

            debug_assert_eq!(s1.len(), n as usize);
            debug_assert_eq!(s2.len(), n as usize);
            // Safety:
            // memcmp is marked unsafe since it assumes that the inputs are at least
            // `n` bytes long. `s1` and `s2` are guaranteed to be exactly `n` bytes
            // long because `translate_slice` would have failed otherwise.
            *cmp_result = unsafe { memcmp(s1, s2, n as usize) };
        }

        Ok(0)
    }
);

declare_builtin_function!(
    /// memset
    SyscallMemset,
    fn rust(
        invoke_context: &mut InvokeContext,
        dst_addr: u64,
        c: u64,
        n: u64,
        _arg4: u64,
        _arg5: u64,
        memory_mapping: &mut MemoryMapping,
    ) -> Result<u64, Error> {
        mem_op_consume(invoke_context, n)?;

        if invoke_context
            .get_feature_set()
            .is_active(&solana_feature_set::bpf_account_data_direct_mapping::id())
        {
            let syscall_context = invoke_context.get_syscall_context()?;

            memset_non_contiguous(
                dst_addr,
                c as u8,
                n,
                &syscall_context.accounts_metadata,
                memory_mapping,
                invoke_context.get_check_aligned(),
            )
        } else {
            let s = translate_slice_mut::<u8>(
                memory_mapping,
                dst_addr,
                n,
                invoke_context.get_check_aligned(),
            )?;
            s.fill(c as u8);
            Ok(0)
        }
    }
);

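// Shared implementation behind SyscallMemcpy and SyscallMemmove. With the
// bpf_account_data_direct_mapping feature active, account data may live in
// non-contiguous host regions, so the copy is performed chunk by chunk;
// otherwise a single translated ptr::copy suffices.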
fn memmove(
    invoke_context: &mut InvokeContext,
    dst_addr: u64,
    src_addr: u64,
    n: u64,
    memory_mapping: &MemoryMapping,
) -> Result<u64, Error> {
    if invoke_context
        .get_feature_set()
        .is_active(&solana_feature_set::bpf_account_data_direct_mapping::id())
    {
        let syscall_context = invoke_context.get_syscall_context()?;

        memmove_non_contiguous(
            dst_addr,
            src_addr,
            n,
            &syscall_context.accounts_metadata,
            memory_mapping,
            invoke_context.get_check_aligned(),
        )
    } else {
        let dst_ptr = translate_slice_mut::<u8>(
            memory_mapping,
            dst_addr,
            n,
            invoke_context.get_check_aligned(),
        )?
        .as_mut_ptr();
        let src_ptr = translate_slice::<u8>(
            memory_mapping,
            src_addr,
            n,
            invoke_context.get_check_aligned(),
        )?
        .as_ptr();

        unsafe { std::ptr::copy(src_ptr, dst_ptr, n as usize) };
        Ok(0)
    }
}

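// memmove over potentially non-contiguous memory. Copies chunk by chunk; when the
// destination overlaps the source from above (dst_addr - src_addr < n), the chunks
// are iterated in reverse so bytes are not overwritten before they are read.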
fn memmove_non_contiguous(
    dst_addr: u64,
    src_addr: u64,
    n: u64,
    accounts: &[SerializedAccountMetadata],
    memory_mapping: &MemoryMapping,
    resize_area: bool,
) -> Result<u64, Error> {
    let reverse = dst_addr.wrapping_sub(src_addr) < n;
    iter_memory_pair_chunks(
        AccessType::Load,
        src_addr,
        AccessType::Store,
        dst_addr,
        n,
        accounts,
        memory_mapping,
        reverse,
        resize_area,
        |src_host_addr, dst_host_addr, chunk_len| {
            unsafe { std::ptr::copy(src_host_addr, dst_host_addr as *mut u8, chunk_len) };
            Ok(0)
        },
    )
}

// Marked unsafe since it assumes that the slices are at least `n` bytes long.
unsafe fn memcmp(s1: &[u8], s2: &[u8], n: usize) -> i32 {
    for i in 0..n {
        let a = *s1.get_unchecked(i);
        let b = *s2.get_unchecked(i);
        if a != b {
            return (a as i32).saturating_sub(b as i32);
        };
    }

    0
}

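// memcmp over potentially non-contiguous memory. Compares chunk by chunk; a mismatch
// is surfaced from the callback as MemcmpError::Diff so iteration can stop early, and
// is then unwrapped back into the final i32 result here.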
fn memcmp_non_contiguous(
    src_addr: u64,
    dst_addr: u64,
    n: u64,
    accounts: &[SerializedAccountMetadata],
    memory_mapping: &MemoryMapping,
    resize_area: bool,
) -> Result<i32, Error> {
    let memcmp_chunk = |s1_addr, s2_addr, chunk_len| {
        let res = unsafe {
            let s1 = slice::from_raw_parts(s1_addr, chunk_len);
            let s2 = slice::from_raw_parts(s2_addr, chunk_len);
            // Safety:
            // memcmp is marked unsafe since it assumes that s1 and s2 are exactly chunk_len
            // long. The whole point of iter_memory_pair_chunks is to find same length chunks
            // across two memory regions.
            memcmp(s1, s2, chunk_len)
        };
        if res != 0 {
            return Err(MemcmpError::Diff(res).into());
        }
        Ok(0)
    };
    match iter_memory_pair_chunks(
        AccessType::Load,
        src_addr,
        AccessType::Load,
        dst_addr,
        n,
        accounts,
        memory_mapping,
        false,
        resize_area,
        memcmp_chunk,
    ) {
        Ok(res) => Ok(res),
        Err(error) => match error.downcast_ref() {
            Some(MemcmpError::Diff(diff)) => Ok(*diff),
            _ => Err(error),
        },
    }
}

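// Sentinel error used by memcmp_non_contiguous to abort chunk iteration as soon as a
// difference is found; it is downcast above and converted back into a plain i32.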
#[derive(Debug)]
enum MemcmpError {
    Diff(i32),
}

impl std::fmt::Display for MemcmpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MemcmpError::Diff(diff) => write!(f, "memcmp diff: {diff}"),
        }
    }
}

impl std::error::Error for MemcmpError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            MemcmpError::Diff(_) => None,
        }
    }
}

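// memset over potentially non-contiguous memory: walks the destination chunk by chunk
// and fills each translated host slice with `c`.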
fn memset_non_contiguous(
    dst_addr: u64,
    c: u8,
    n: u64,
    accounts: &[SerializedAccountMetadata],
    memory_mapping: &MemoryMapping,
    check_aligned: bool,
) -> Result<u64, Error> {
    let dst_chunk_iter = MemoryChunkIterator::new(
        memory_mapping,
        accounts,
        AccessType::Store,
        dst_addr,
        n,
        check_aligned,
    )?;
    for item in dst_chunk_iter {
        let (dst_region, dst_vm_addr, dst_len) = item?;
        let dst_host_addr = Result::from(dst_region.vm_to_host(dst_vm_addr, dst_len as u64))?;
        unsafe { slice::from_raw_parts_mut(dst_host_addr as *mut u8, dst_len).fill(c) }
    }

    Ok(0)
}

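// Walks two vm address ranges of `n_bytes` in lockstep, splitting them into pairs of
// same-length chunks that each fall within a single memory region, and invokes `fun`
// with the translated host addresses of each pair. With `reverse` set, chunks are
// visited from the end (used by memmove for overlapping copies).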
#[allow(clippy::too_many_arguments)]
fn iter_memory_pair_chunks<T, F>(
    src_access: AccessType,
    src_addr: u64,
    dst_access: AccessType,
    dst_addr: u64,
    n_bytes: u64,
    accounts: &[SerializedAccountMetadata],
    memory_mapping: &MemoryMapping,
    reverse: bool,
    resize_area: bool,
    mut fun: F,
) -> Result<T, Error>
where
    T: Default,
    F: FnMut(*const u8, *const u8, usize) -> Result<T, Error>,
{
    let mut src_chunk_iter = MemoryChunkIterator::new(
        memory_mapping,
        accounts,
        src_access,
        src_addr,
        n_bytes,
        resize_area,
    )
    .map_err(EbpfError::from)?;
    let mut dst_chunk_iter = MemoryChunkIterator::new(
        memory_mapping,
        accounts,
        dst_access,
        dst_addr,
        n_bytes,
        resize_area,
    )
    .map_err(EbpfError::from)?;

    let mut src_chunk = None;
    let mut dst_chunk = None;

    macro_rules! memory_chunk {
        ($chunk_iter:ident, $chunk:ident) => {
            if let Some($chunk) = &mut $chunk {
                // Keep processing the current chunk
                $chunk
            } else {
                // This is either the first call or we've processed all the bytes in the current
                // chunk. Move to the next one.
                let chunk = match if reverse {
                    $chunk_iter.next_back()
                } else {
                    $chunk_iter.next()
                } {
                    Some(item) => item?,
                    None => break,
                };
                $chunk.insert(chunk)
            }
        };
    }

    loop {
        let (src_region, src_chunk_addr, src_remaining) = memory_chunk!(src_chunk_iter, src_chunk);
        let (dst_region, dst_chunk_addr, dst_remaining) = memory_chunk!(dst_chunk_iter, dst_chunk);

        // We always process same-length pairs
        let chunk_len = *src_remaining.min(dst_remaining);

        let (src_host_addr, dst_host_addr) = {
            let (src_addr, dst_addr) = if reverse {
                // When scanning backwards, not only do we want to scan regions from the end,
                // we also want to process the memory within each region backwards.
                (
                    src_chunk_addr
                        .saturating_add(*src_remaining as u64)
                        .saturating_sub(chunk_len as u64),
                    dst_chunk_addr
                        .saturating_add(*dst_remaining as u64)
                        .saturating_sub(chunk_len as u64),
                )
            } else {
                (*src_chunk_addr, *dst_chunk_addr)
            };

            (
                Result::from(src_region.vm_to_host(src_addr, chunk_len as u64))?,
                Result::from(dst_region.vm_to_host(dst_addr, chunk_len as u64))?,
            )
        };

        fun(
            src_host_addr as *const u8,
            dst_host_addr as *const u8,
            chunk_len,
        )?;

        // Update how many bytes we have left to scan in each chunk
        *src_remaining = src_remaining.saturating_sub(chunk_len);
        *dst_remaining = dst_remaining.saturating_sub(chunk_len);

        if !reverse {
            // We've scanned `chunk_len` bytes so we move the vm address forward. In reverse
            // mode we don't do this since we make progress by decreasing src_len and
            // dst_len.
            *src_chunk_addr = src_chunk_addr.saturating_add(chunk_len as u64);
            *dst_chunk_addr = dst_chunk_addr.saturating_add(chunk_len as u64);
        }

        if *src_remaining == 0 {
            src_chunk = None;
        }

        if *dst_remaining == 0 {
            dst_chunk = None;
        }
    }

    Ok(T::default())
}

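// Iterator that splits the vm address range [vm_addr, vm_addr + len) into chunks that
// each lie within a single MemoryRegion. It also tracks whether the range falls in
// account data and returns SyscallError::InvalidLength if a single operation would
// span both account and non-account regions.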
struct MemoryChunkIterator<'a> {
    memory_mapping: &'a MemoryMapping<'a>,
    accounts: &'a [SerializedAccountMetadata],
    access_type: AccessType,
    initial_vm_addr: u64,
    vm_addr_start: u64,
    // exclusive end index (start + len, so one past the last valid address)
    vm_addr_end: u64,
    len: u64,
    account_index: Option<usize>,
    is_account: Option<bool>,
    resize_area: bool,
}

impl<'a> MemoryChunkIterator<'a> {
    fn new(
        memory_mapping: &'a MemoryMapping,
        accounts: &'a [SerializedAccountMetadata],
        access_type: AccessType,
        vm_addr: u64,
        len: u64,
        resize_area: bool,
    ) -> Result<MemoryChunkIterator<'a>, EbpfError> {
        let vm_addr_end = vm_addr.checked_add(len).ok_or(EbpfError::AccessViolation(
            access_type,
            vm_addr,
            len,
            "unknown",
        ))?;

        Ok(MemoryChunkIterator {
            memory_mapping,
            accounts,
            access_type,
            initial_vm_addr: vm_addr,
            len,
            vm_addr_start: vm_addr,
            vm_addr_end,
            account_index: None,
            is_account: None,
            resize_area,
        })
    }

    fn region(&mut self, vm_addr: u64) -> Result<&'a MemoryRegion, Error> {
        match self.memory_mapping.region(self.access_type, vm_addr) {
            Ok(region) => Ok(region),
            Err(error) => match error {
                EbpfError::AccessViolation(access_type, _vm_addr, _len, name) => Err(Box::new(
                    EbpfError::AccessViolation(access_type, self.initial_vm_addr, self.len, name),
                )),
                EbpfError::StackAccessViolation(access_type, _vm_addr, _len, frame) => {
                    Err(Box::new(EbpfError::StackAccessViolation(
                        access_type,
                        self.initial_vm_addr,
                        self.len,
                        frame,
                    )))
                }
                _ => Err(error.into()),
            },
        }
    }
}

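// Forward iteration: yields (region, vm_addr, chunk_len) tuples from vm_addr_start
// towards vm_addr_end, walking the accounts in ascending address order.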
impl<'a> Iterator for MemoryChunkIterator<'a> {
    type Item = Result<(&'a MemoryRegion, u64, usize), Error>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.vm_addr_start == self.vm_addr_end {
            return None;
        }

        let region = match self.region(self.vm_addr_start) {
            Ok(region) => region,
            Err(e) => {
                self.vm_addr_start = self.vm_addr_end;
                return Some(Err(e));
            }
        };

        let region_is_account;

        let mut account_index = self.account_index.unwrap_or_default();
        self.account_index = Some(account_index);

        loop {
            if let Some(account) = self.accounts.get(account_index) {
                let account_addr = account.vm_data_addr;
                let resize_addr = account_addr.saturating_add(account.original_data_len as u64);

                if resize_addr < region.vm_addr {
                    // region is after this account, move on to the next one
                    account_index = account_index.saturating_add(1);
                    self.account_index = Some(account_index);
                } else {
                    region_is_account = region.vm_addr == account_addr
                        // unaligned programs do not have a resize area
                        || (self.resize_area && region.vm_addr == resize_addr);
                    break;
                }
            } else {
                // address is after all the accounts
                region_is_account = false;
                break;
            }
        }

        if let Some(is_account) = self.is_account {
            if is_account != region_is_account {
                return Some(Err(SyscallError::InvalidLength.into()));
            }
        } else {
            self.is_account = Some(region_is_account);
        }

        let vm_addr = self.vm_addr_start;

        let chunk_len = if region.vm_addr_end <= self.vm_addr_end {
            // consume the whole region
            let len = region.vm_addr_end.saturating_sub(self.vm_addr_start);
            self.vm_addr_start = region.vm_addr_end;
            len
        } else {
            // consume part of the region
            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
            self.vm_addr_start = self.vm_addr_end;
            len
        };

        Some(Ok((region, vm_addr, chunk_len as usize)))
    }
}

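// Backward iteration: yields chunks from vm_addr_end towards vm_addr_start, walking
// the accounts in descending address order. Used when memmove must copy in reverse.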
impl DoubleEndedIterator for MemoryChunkIterator<'_> {
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.vm_addr_start == self.vm_addr_end {
            return None;
        }

        let region = match self.region(self.vm_addr_end.saturating_sub(1)) {
            Ok(region) => region,
            Err(e) => {
                self.vm_addr_start = self.vm_addr_end;
                return Some(Err(e));
            }
        };

        let region_is_account;

        let mut account_index = self
            .account_index
            .unwrap_or_else(|| self.accounts.len().saturating_sub(1));
        self.account_index = Some(account_index);

        loop {
            let Some(account) = self.accounts.get(account_index) else {
                // address is after all the accounts
                region_is_account = false;
                break;
            };

            let account_addr = account.vm_data_addr;
            let resize_addr = account_addr.saturating_add(account.original_data_len as u64);

            if account_index > 0 && account_addr > region.vm_addr {
                account_index = account_index.saturating_sub(1);

                self.account_index = Some(account_index);
            } else {
                region_is_account = region.vm_addr == account_addr
                    // unaligned programs do not have a resize area
                    || (self.resize_area && region.vm_addr == resize_addr);
                break;
            }
        }

        if let Some(is_account) = self.is_account {
            if is_account != region_is_account {
                return Some(Err(SyscallError::InvalidLength.into()));
            }
        } else {
            self.is_account = Some(region_is_account);
        }

        let chunk_len = if region.vm_addr >= self.vm_addr_start {
            // consume the whole region
            let len = self.vm_addr_end.saturating_sub(region.vm_addr);
            self.vm_addr_end = region.vm_addr;
            len
        } else {
            // consume part of the region
            let len = self.vm_addr_end.saturating_sub(self.vm_addr_start);
            self.vm_addr_end = self.vm_addr_start;
            len
        };

        Some(Ok((region, self.vm_addr_end, chunk_len as usize)))
    }
}

#[cfg(test)]
#[allow(clippy::indexing_slicing)]
#[allow(clippy::arithmetic_side_effects)]
mod tests {
    use {
        super::*,
        assert_matches::assert_matches,
        solana_sbpf::{ebpf::MM_RODATA_START, program::SBPFVersion},
        test_case::test_case,
    };

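    // Collapse an iterator of chunk results into (vm_addr, len) pairs for easy
    // assertions, dropping the region references.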
    fn to_chunk_vec<'a>(
        iter: impl Iterator<Item = Result<(&'a MemoryRegion, u64, usize), Error>>,
    ) -> Vec<(u64, usize)> {
        iter.flat_map(|res| res.map(|(_, vm_addr, len)| (vm_addr, len)))
            .collect::<Vec<_>>()
    }

    #[test]
    #[should_panic(expected = "AccessViolation")]
    fn test_memory_chunk_iterator_no_regions() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let memory_mapping = MemoryMapping::new(vec![], &config, SBPFVersion::V3).unwrap();

        let mut src_chunk_iter =
            MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, 0, 1, true).unwrap();
        src_chunk_iter.next().unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "AccessViolation")]
    fn test_memory_chunk_iterator_new_out_of_bounds_upper() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let memory_mapping = MemoryMapping::new(vec![], &config, SBPFVersion::V3).unwrap();

        let mut src_chunk_iter =
            MemoryChunkIterator::new(&memory_mapping, &[], AccessType::Load, u64::MAX, 1, true)
                .unwrap();
        src_chunk_iter.next().unwrap().unwrap();
    }

    #[test]
    fn test_memory_chunk_iterator_out_of_bounds() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0xFF; 42];
        let memory_mapping = MemoryMapping::new(
            vec![MemoryRegion::new_readonly(&mem1, MM_RODATA_START)],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        // check oob at the lower bound on the first next()
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START - 1,
            42,
            true,
        )
        .unwrap();
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 42, "unknown") if *addr == MM_RODATA_START - 1
        );

        // check oob at the upper bound. Since the memory mapping isn't empty,
        // this always happens on the second next().
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START,
            43,
            true,
        )
        .unwrap();
        assert!(src_chunk_iter.next().unwrap().is_ok());
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_RODATA_START
        );

        // check oob at the upper bound on the first next_back()
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START,
            43,
            true,
        )
        .unwrap()
        .rev();
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "program") if *addr == MM_RODATA_START
        );

        // check oob at the upper bound on the 2nd next_back()
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START - 1,
            43,
            true,
        )
        .unwrap()
        .rev();
        assert!(src_chunk_iter.next().unwrap().is_ok());
        assert_matches!(
            src_chunk_iter.next().unwrap().unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 43, "unknown") if *addr == MM_RODATA_START - 1
        );
    }

    #[test]
    fn test_memory_chunk_iterator_one() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0xFF; 42];
        let memory_mapping = MemoryMapping::new(
            vec![MemoryRegion::new_readonly(&mem1, MM_RODATA_START)],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        // check lower bound
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START - 1,
            1,
            true,
        )
        .unwrap();
        assert!(src_chunk_iter.next().unwrap().is_err());

        // check upper bound
        let mut src_chunk_iter = MemoryChunkIterator::new(
            &memory_mapping,
            &[],
            AccessType::Load,
            MM_RODATA_START + 42,
            1,
            true,
        )
        .unwrap();
        assert!(src_chunk_iter.next().unwrap().is_err());

        for (vm_addr, len) in [
            (MM_RODATA_START, 0),
            (MM_RODATA_START + 42, 0),
            (MM_RODATA_START, 1),
            (MM_RODATA_START, 42),
            (MM_RODATA_START + 41, 1),
        ] {
            for rev in [true, false] {
                let iter = MemoryChunkIterator::new(
                    &memory_mapping,
                    &[],
                    AccessType::Load,
                    vm_addr,
                    len,
                    true,
                )
                .unwrap();
                let res = if rev {
                    to_chunk_vec(iter.rev())
                } else {
                    to_chunk_vec(iter)
                };
                if len == 0 {
                    assert_eq!(res, &[]);
                } else {
                    assert_eq!(res, &[(vm_addr, len as usize)]);
                }
            }
        }
    }

    #[test]
    fn test_memory_chunk_iterator_two() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0x11; 8];
        let mem2 = vec![0x22; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        for (vm_addr, len, mut expected) in [
            (MM_RODATA_START, 8, vec![(MM_RODATA_START, 8)]),
            (
                MM_RODATA_START + 7,
                2,
                vec![(MM_RODATA_START + 7, 1), (MM_RODATA_START + 8, 1)],
            ),
            (MM_RODATA_START + 8, 4, vec![(MM_RODATA_START + 8, 4)]),
        ] {
            for rev in [false, true] {
                let iter = MemoryChunkIterator::new(
                    &memory_mapping,
                    &[],
                    AccessType::Load,
                    vm_addr,
                    len,
                    true,
                )
                .unwrap();
                let res = if rev {
                    expected.reverse();
                    to_chunk_vec(iter.rev())
                } else {
                    to_chunk_vec(iter)
                };

                assert_eq!(res, expected);
            }
        }
    }

    #[test]
    fn test_iter_memory_pair_chunks_short() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0x11; 8];
        let mem2 = vec![0x22; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        // dst is shorter than src
        assert_matches!(
            iter_memory_pair_chunks(
                AccessType::Load,
                MM_RODATA_START,
                AccessType::Load,
                MM_RODATA_START + 8,
                8,
                &[],
                &memory_mapping,
                false,
                true,
                |_src, _dst, _len| Ok::<_, Error>(0),
            ).unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 8, "program") if *addr == MM_RODATA_START + 8
        );

        // src is shorter than dst
        assert_matches!(
            iter_memory_pair_chunks(
                AccessType::Load,
                MM_RODATA_START + 10,
                AccessType::Load,
                MM_RODATA_START + 2,
                3,
                &[],
                &memory_mapping,
                false,
                true,
                |_src, _dst, _len| Ok::<_, Error>(0),
            ).unwrap_err().downcast_ref().unwrap(),
            EbpfError::AccessViolation(AccessType::Load, addr, 3, "program") if *addr == MM_RODATA_START + 10
        );
    }

    #[test]
    #[should_panic(expected = "AccessViolation(Store, 4294967296, 4")]
    fn test_memmove_non_contiguous_readonly() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0x11; 8];
        let mem2 = vec![0x22; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        memmove_non_contiguous(
            MM_RODATA_START,
            MM_RODATA_START + 8,
            4,
            &[],
            &memory_mapping,
            true,
        )
        .unwrap();
    }

    #[test_case(&[], (0, 0, 0); "no regions")]
    #[test_case(&[10], (1, 10, 0); "single region 0 len")]
    #[test_case(&[10], (0, 5, 5); "single region no overlap")]
    #[test_case(&[10], (0, 0, 10) ; "single region complete overlap")]
    #[test_case(&[10], (2, 0, 5); "single region partial overlap start")]
    #[test_case(&[10], (0, 1, 6); "single region partial overlap middle")]
    #[test_case(&[10], (2, 5, 5); "single region partial overlap end")]
    #[test_case(&[3, 5], (0, 5, 2) ; "two regions no overlap, single source region")]
    #[test_case(&[4, 7], (0, 5, 5) ; "two regions no overlap, multiple source regions")]
    #[test_case(&[3, 8], (0, 0, 11) ; "two regions complete overlap")]
    #[test_case(&[2, 9], (3, 0, 5) ; "two regions partial overlap start")]
    #[test_case(&[3, 9], (1, 2, 5) ; "two regions partial overlap middle")]
    #[test_case(&[7, 3], (2, 6, 4) ; "two regions partial overlap end")]
    #[test_case(&[2, 6, 3, 4], (0, 10, 2) ; "many regions no overlap, single source region")]
    #[test_case(&[2, 1, 2, 5, 6], (2, 10, 4) ; "many regions no overlap, multiple source regions")]
    #[test_case(&[8, 1, 3, 6], (0, 0, 18) ; "many regions complete overlap")]
    #[test_case(&[7, 3, 1, 4, 5], (5, 0, 8) ; "many regions overlap start")]
    #[test_case(&[1, 5, 2, 9, 3], (5, 4, 8) ; "many regions overlap middle")]
    #[test_case(&[3, 9, 1, 1, 2, 1], (2, 9, 8) ; "many regions overlap end")]
    fn test_memmove_non_contiguous(
        regions: &[usize],
        (src_offset, dst_offset, len): (usize, usize, usize),
    ) {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let (mem, memory_mapping) = build_memory_mapping(regions, &config);

        // flatten the memory so we can memmove it with ptr::copy
        let mut expected_memory = flatten_memory(&mem);
        unsafe {
            std::ptr::copy(
                expected_memory.as_ptr().add(src_offset),
                expected_memory.as_mut_ptr().add(dst_offset),
                len,
            )
        };

        // do our memmove
        memmove_non_contiguous(
            MM_RODATA_START + dst_offset as u64,
            MM_RODATA_START + src_offset as u64,
            len as u64,
            &[],
            &memory_mapping,
            true,
        )
        .unwrap();

        // flatten memory post our memmove
        let memory = flatten_memory(&mem);

        // compare libc's memmove with ours
        assert_eq!(expected_memory, memory);
    }

    #[test]
    #[should_panic(expected = "AccessViolation(Store, 4294967296, 9")]
    fn test_memset_non_contiguous_readonly() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mut mem1 = vec![0x11; 8];
        let mem2 = vec![0x22; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_writable(&mut mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 8),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        assert_eq!(
            memset_non_contiguous(MM_RODATA_START, 0x33, 9, &[], &memory_mapping, true).unwrap(),
            0
        );
    }

    #[test]
    fn test_memset_non_contiguous() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = vec![0x11; 1];
        let mut mem2 = vec![0x22; 2];
        let mut mem3 = vec![0x33; 3];
        let mut mem4 = vec![0x44; 4];
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_writable(&mut mem2, MM_RODATA_START + 1),
                MemoryRegion::new_writable(&mut mem3, MM_RODATA_START + 3),
                MemoryRegion::new_writable(&mut mem4, MM_RODATA_START + 6),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        assert_eq!(
            memset_non_contiguous(MM_RODATA_START + 1, 0x55, 7, &[], &memory_mapping, true)
                .unwrap(),
            0
        );
        assert_eq!(&mem1, &[0x11]);
        assert_eq!(&mem2, &[0x55, 0x55]);
        assert_eq!(&mem3, &[0x55, 0x55, 0x55]);
        assert_eq!(&mem4, &[0x55, 0x55, 0x44, 0x44]);
    }

    #[test]
    fn test_memcmp_non_contiguous() {
        let config = Config {
            aligned_memory_mapping: false,
            ..Config::default()
        };
        let mem1 = b"foo".to_vec();
        let mem2 = b"barbad".to_vec();
        let mem3 = b"foobarbad".to_vec();
        let memory_mapping = MemoryMapping::new(
            vec![
                MemoryRegion::new_readonly(&mem1, MM_RODATA_START),
                MemoryRegion::new_readonly(&mem2, MM_RODATA_START + 3),
                MemoryRegion::new_readonly(&mem3, MM_RODATA_START + 9),
            ],
            &config,
            SBPFVersion::V3,
        )
        .unwrap();

        // non contiguous src
        assert_eq!(
            memcmp_non_contiguous(
                MM_RODATA_START,
                MM_RODATA_START + 9,
                9,
                &[],
                &memory_mapping,
                true
            )
            .unwrap(),
            0
        );

        // non contiguous dst
        assert_eq!(
            memcmp_non_contiguous(
                MM_RODATA_START + 10,
                MM_RODATA_START + 1,
                8,
                &[],
                &memory_mapping,
                true
            )
            .unwrap(),
            0
        );

        // diff
        assert_eq!(
            memcmp_non_contiguous(
                MM_RODATA_START + 1,
                MM_RODATA_START + 11,
                5,
                &[],
                &memory_mapping,
                true
            )
            .unwrap(),
            unsafe { memcmp(b"oobar", b"obarb", 5) }
        );
    }

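    // Build one writable region per requested length, laid out back to back starting
    // at MM_RODATA_START with deterministic contents, and return both the backing
    // buffers and the resulting mapping.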
    fn build_memory_mapping<'a>(
        regions: &[usize],
        config: &'a Config,
    ) -> (Vec<Vec<u8>>, MemoryMapping<'a>) {
        let mut regs = vec![];
        let mut mem = Vec::new();
        let mut offset = 0;
        for (i, region_len) in regions.iter().enumerate() {
            mem.push(
                (0..*region_len)
                    .map(|x| (i * 10 + x) as u8)
                    .collect::<Vec<_>>(),
            );
            regs.push(MemoryRegion::new_writable(
                &mut mem[i],
                MM_RODATA_START + offset as u64,
            ));
            offset += *region_len;
        }

        let memory_mapping = MemoryMapping::new(regs, config, SBPFVersion::V3).unwrap();

        (mem, memory_mapping)
    }

    fn flatten_memory(mem: &[Vec<u8>]) -> Vec<u8> {
        mem.iter().flatten().copied().collect()
    }
}