use triton_vm::prelude::*;
use crate::list::length::Length;
use crate::prelude::*;
/// Snippet that reads the element at a given index from a list in memory.
///
/// The element type must have a static, non-zero length; this is enforced by
/// [`Get::new`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Get {
    /// The type of every element stored in the list.
    element_type: DataType,
}
impl Get {
    /// Error ID of the assertion that the requested index is strictly smaller
    /// than the list's length.
    pub const INDEX_OUT_OF_BOUNDS_ERROR_ID: i128 = 380;

    /// Error ID of the assertion that the offset of the requested element's
    /// last word, measured from the list pointer, fits into 32 bits.
    pub const MEM_PAGE_ACCESS_VIOLATION_ERROR_ID: i128 = 381;

    /// Creates a snippet that reads elements of type `element_type` from a list.
    ///
    /// # Panics
    ///
    /// Panics if `element_type` has no static length, or if its static length
    /// is zero.
    pub fn new(element_type: DataType) -> Self {
        Self::assert_element_type_is_supported(&element_type);
        Self { element_type }
    }

    /// Panics unless `element_type` has a static, non-zero length.
    pub(crate) fn assert_element_type_is_supported(element_type: &DataType) {
        let static_len = element_type
            .static_length()
            .expect("element should have static length");
        assert_ne!(0, static_len, "element must not be zero-sized");
    }
}
impl BasicSnippet for Get {
    fn inputs(&self) -> Vec<(DataType, String)> {
        let element = Box::new(self.element_type.clone());
        let list_pointer = (DataType::List(element), String::from("*list"));
        let index = (DataType::U32, String::from("index"));
        vec![list_pointer, index]
    }

    fn outputs(&self) -> Vec<(DataType, String)> {
        let element = (self.element_type.clone(), String::from("element"));
        vec![element]
    }

    fn entrypoint(&self) -> String {
        format!(
            "tasmlib_list_get_element___{}",
            self.element_type.label_friendly_name()
        )
    }

    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let list_length = library.import(Box::new(Length));

        // Multiplying by an element size of 1 is a no-op, so it is left out.
        let element_size = self.element_type.stack_size();
        let mul_with_element_size = if element_size == 1 {
            triton_asm!()
        } else {
            triton_asm!(push {element_size} mul)
        };

        let entrypoint = self.entrypoint();

        triton_asm!(
            // BEFORE: _ *list index
            // AFTER:  _ [element]
            {entrypoint}:
                // check that the index is within the list's length
                dup 1
                call {list_length}
                // _ *list index list_len
                dup 1
                lt
                assert error_id {Self::INDEX_OUT_OF_BOUNDS_ERROR_ID}
                // _ *list index

                // offset of the element's last word: (index + 1) * element_size
                addi 1
                {&mul_with_element_size}
                // _ *list offset

                // check that the high 32 bits of the offset are zero
                split
                pick 1
                push 0
                eq
                assert error_id {Self::MEM_PAGE_ACCESS_VIOLATION_ERROR_ID}
                // _ *list offset

                add
                // _ *element_last_word
                {&self.element_type.read_value_from_memory_pop_pointer()}
                // _ [element]
                return
        )
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use triton_vm::error::OpStackError::FailedU32Conversion;

    use super::*;
    use crate::rust_shadowing_helper_functions::list::insert_random_list;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::test_helpers::negative_test;
    use crate::test_prelude::*;
    use crate::U32_TO_USIZE_ERR;

    impl Get {
        /// Builds an initial state for this snippet: a random list of
        /// `list_length` elements stored at `list_pointer`, and the stack
        /// `_ *list index` on top of the isolated-run stack.
        fn set_up_initial_state(
            &self,
            list_length: usize,
            index: usize,
            list_pointer: BFieldElement,
        ) -> AccessorInitialState {
            let mut memory = HashMap::default();
            insert_random_list(&self.element_type, list_pointer, list_length, &mut memory);

            let mut stack = self.init_stack_for_isolated_run();
            stack.push(list_pointer);
            stack.push(bfe!(index));

            AccessorInitialState { stack, memory }
        }

        /// Samples `(list_length, index, list_pointer)` with `index < list_length`.
        ///
        /// Benchmark cases use fixed lengths and indices. Otherwise the length is
        /// drawn from `1..=100` and the index from `0..list_length`. The list
        /// pointer is always random.
        pub fn random_len_idx_ptr(
            bench_case: Option<BenchmarkCase>,
            rng: &mut impl rand::Rng,
        ) -> (usize, usize, BFieldElement) {
            let (index, list_length) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (16, 32),
                Some(BenchmarkCase::WorstCase) => (63, 64),
                None => {
                    let list_length = rng.random_range(1..=100);
                    (rng.random_range(0..list_length), list_length)
                }
            };
            let list_pointer = rng.random();

            (list_length, index, list_pointer)
        }
    }

    impl Accessor for Get {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &HashMap<BFieldElement, BFieldElement>,
        ) {
            // Mirrors the snippet's stack effect: `_ *list index` -> `_ [element]`.
            let index: u32 = stack.pop().unwrap().try_into().unwrap();
            let list_pointer = stack.pop().unwrap();

            let index: usize = index.try_into().expect(U32_TO_USIZE_ERR);
            let element_length = self.element_type.static_length().unwrap();
            let element = list_get(list_pointer, index, memory, element_length);

            // Pushed in reverse so that the element's first word ends up on top.
            stack.extend(element.into_iter().rev());
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> AccessorInitialState {
            let (list_length, index, list_pointer) =
                Self::random_len_idx_ptr(bench_case, &mut StdRng::from_seed(seed));

            self.set_up_initial_state(list_length, index, list_pointer)
        }
    }

    // Compares the snippet against its Rust shadow for elements of several sizes.
    #[test]
    fn rust_shadow() {
        for ty in [DataType::Bfe, DataType::Digest, DataType::I128] {
            ShadowedAccessor::new(Get::new(ty)).test();
        }
    }

    // Any index in `list_length..2^32` must trip the bounds-check assertion.
    #[proptest]
    fn out_of_bounds_access_crashes_vm(
        #[strategy(0_usize..=1_000)] list_length: usize,
        #[strategy(#list_length..1 << 32)] index: usize,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        let get = Get::new(DataType::Bfe);
        let initial_state = get.set_up_initial_state(list_length, index, list_pointer);
        test_assertion_failure(
            &ShadowedAccessor::new(get),
            initial_state.into(),
            &[Get::INDEX_OUT_OF_BOUNDS_ERROR_ID],
        );
    }

    // Indices that are not valid `u32`s must make the VM fail the u32 conversion.
    //
    // The range is bounded above by the field modulus `BFieldElement::P`.
    // `bfe!(index)` maps any `index >= P` to its canonical representative
    // `index - P`, which is below 2^32. Such an index would pass the u32
    // conversion and hit the index-out-of-bounds assertion instead of the
    // expected `FailedU32Conversion`.
    #[proptest]
    fn too_large_indices_crash_vm(
        #[strategy(1_usize << 32..BFieldElement::P as usize)] index: usize,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        let list_length = 0;
        let get = Get::new(DataType::Bfe);
        let initial_state = get.set_up_initial_state(list_length, index, list_pointer);
        let expected_error = InstructionError::OpStackError(FailedU32Conversion(bfe!(index)));
        negative_test(
            &ShadowedAccessor::new(get),
            initial_state.into(),
            &[expected_error],
        );
    }

    // With 2^10-word elements and `index >= 2^22 - 1`, the last word's offset
    // `(index + 1) * 2^10` is at least 2^32. That trips the 32-bit offset assertion.
    #[proptest(cases = 100)]
    fn too_large_lists_crash_vm(
        #[strategy(1_u64 << 22..1 << 32)] list_length: u64,
        #[strategy((1 << 22) - 1..#list_length)] index: u64,
        #[strategy(arb())] list_pointer: BFieldElement,
    ) {
        // Only the length word is written; the elements are never read because
        // the assertion fires first.
        let mut memory = HashMap::default();
        memory.insert(list_pointer, bfe!(list_length));

        let tuple_ty = DataType::Tuple(vec![DataType::Bfe; 1 << 10]);
        let get = Get::new(tuple_ty);

        let mut stack = get.init_stack_for_isolated_run();
        stack.push(list_pointer);
        stack.push(bfe!(index));
        let initial_state = AccessorInitialState { stack, memory };

        test_assertion_failure(
            &ShadowedAccessor::new(get),
            initial_state.into(),
            &[Get::MEM_PAGE_ACCESS_VIOLATION_ERROR_ID],
        );
    }
}
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks reading a `Digest` element from a list.
    #[test]
    fn benchmark() {
        let snippet = Get::new(DataType::Digest);
        ShadowedAccessor::new(snippet).bench();
    }
}