use std::collections::HashMap;
use itertools::Itertools;
use tasm_lib::list::LIST_METADATA_SIZE;
use triton_vm::prelude::*;
use crate::data_type::DataType;
use crate::empty_stack;
use crate::list::new::New;
use crate::list::set_length::SetLength;
use crate::rust_shadowing_helper_functions;
use crate::traits::deprecated_snippet::DeprecatedSnippet;
use crate::InitVmState;
/// Snippet that materialises the half-open integer range `minimum..supremum`
/// as a list of `u32`s.
///
/// The type carries no state; all inputs are passed on the VM stack.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Range;
impl Range {
    /// Builds an initial VM state whose stack is the empty stack followed by
    /// `minimum` and then `supremum`, i.e. `supremum` ends up on top.
    fn init_state(minimum: u32, supremum: u32) -> InitVmState {
        let mut stack = empty_stack();
        stack.extend([minimum, supremum].map(|bound| BFieldElement::new(u64::from(bound))));
        InitVmState::with_stack(stack)
    }
}
impl DeprecatedSnippet for Range {
    /// Label of the snippet's entry point in the generated assembly.
    fn entrypoint_name(&self) -> String {
        String::from("tasmlib_list_range")
    }

    /// Names of the stack inputs, deepest first: `minimum`, then `supremum`.
    fn input_field_names(&self) -> Vec<String>
    where
        Self: Sized,
    {
        ["minimum", "supremum"].map(String::from).to_vec()
    }

    /// Both bounds are `u32`s.
    fn input_types(&self) -> Vec<DataType> {
        vec![DataType::U32, DataType::U32]
    }

    /// Name of the single output: a pointer to the freshly created list.
    fn output_field_names(&self) -> Vec<String>
    where
        Self: Sized,
    {
        vec![String::from("*list")]
    }

    /// The output is a list of `u32`s.
    fn output_types(&self) -> Vec<DataType> {
        let element_type = Box::new(DataType::U32);
        vec![DataType::List(element_type)]
    }

    /// Two words (`minimum`, `supremum`) are consumed and one (`*list`) is
    /// produced, so the stack shrinks by one.
    fn stack_diff(&self) -> isize
    where
        Self: Sized,
    {
        -1
    }

    /// Emits the assembly for the snippet.
    ///
    /// The code asserts `minimum <= supremum`, creates a new list, sets its
    /// length to `supremum - minimum`, and then fills the list from the back
    /// so that element `i` holds `minimum + i`.
    ///
    /// The stack annotations below assume that `New` pushes the list pointer
    /// and that `SetLength` consumes the length and keeps the pointer; this
    /// matches the snippet's overall stack diff but is defined outside this
    /// file.
    fn function_code(&self, library: &mut crate::library::Library) -> String {
        let data_type = DataType::U32;
        let new_list = library.import(Box::new(New::new(data_type.clone())));
        let set_length = library.import(Box::new(SetLength::new(data_type)));

        let entrypoint = self.entrypoint_name();
        let inner_loop = format!("{entrypoint}_loop");

        triton_asm!(
            // BEFORE: _ minimum supremum
            // AFTER:  _ *list
            {entrypoint}:
                hint supremum = stack[0]
                hint minimum = stack[1]

                // Check `minimum <= supremum`, phrased as `minimum < supremum + 1`.
                // Note that `supremum + 1` is no longer a u32 for `supremum == u32::MAX`.
                dup 0 push 1 add        // _ minimum supremum (supremum + 1)
                dup 2                   // _ minimum supremum (supremum + 1) minimum
                lt                      // _ minimum supremum (minimum < supremum + 1)
                assert                  // _ minimum supremum

                // length = supremum - minimum
                dup 0 dup 2             // _ minimum supremum supremum minimum
                push -1 mul add         // _ minimum supremum length

                call {new_list}         // _ minimum supremum length *list
                dup 1                   // _ minimum supremum length *list length
                call {set_length}       // _ minimum supremum length *list
                call {inner_loop}       // _ minimum supremum 0 *list

                swap 3                  // _ *list supremum 0 minimum
                pop 3                   // _ *list
                return

            // INVARIANT: _ minimum supremum remaining *list
            // Writes element `remaining - 1` on every iteration until `remaining` is 0.
            {inner_loop}:
                dup 1 push 0 eq         // _ minimum supremum remaining *list (remaining == 0)
                skiz return             // _ minimum supremum remaining *list

                swap 1 push -1 add      // _ minimum supremum *list index
                dup 3 dup 1 add         // _ minimum supremum *list index (minimum + index)
                dup 2                   // _ minimum supremum *list index value *list
                push {LIST_METADATA_SIZE}
                hint list_metadata_size = stack[0]
                add                     // _ minimum supremum *list index value (*list + metadata)
                dup 2 add               // _ minimum supremum *list index value element_address
                write_mem 1             // _ minimum supremum *list index (element_address + 1)
                pop 1                   // _ minimum supremum *list index
                swap 1                  // _ minimum supremum index *list
                recurse
        )
        .iter()
        .join("\n")
    }

    /// Conditions under which executing the snippet crashes the VM.
    fn crash_conditions(&self) -> Vec<String>
    where
        Self: Sized,
    {
        vec![
            "minimum not u32".to_string(),
            "supremum not u32".to_string(),
            "minimum larger than supremum".to_string(),
            // The range check computes `supremum + 1` and feeds it to `lt`.
            "supremum is u32::MAX (supremum + 1 is not a u32)".to_string(),
        ]
    }

    /// Input states used for the equivalence test between assembly and shadow.
    fn gen_input_states(&self) -> Vec<InitVmState>
    where
        Self: Sized,
    {
        // The last pair is an empty range.
        [(0, 1), (0, 10), (5, 15), (12, 12)]
            .iter()
            .map(|&(minimum, supremum)| Self::init_state(minimum, supremum))
            .collect()
    }

    /// Benchmark input for the common case: a range of 45 elements.
    fn common_case_input_state(&self) -> InitVmState
    where
        Self: Sized,
    {
        Self::init_state(0, 45)
    }

    /// Benchmark input for the worst case: a range of 250 elements.
    fn worst_case_input_state(&self) -> InitVmState
    where
        Self: Sized,
    {
        Self::init_state(0, 250)
    }

    /// Rust reference implementation of the snippet.
    ///
    /// Pops `supremum` and `minimum` from `stack`, allocates a list in
    /// `memory`, sets its length to `supremum - minimum`, writes
    /// `minimum..supremum` into it, and pushes the list pointer.
    ///
    /// # Panics
    ///
    /// Panics if the stack holds fewer than two elements, if either bound is
    /// not a `u32`, or if `minimum` is larger than `supremum`.
    fn rust_shadowing(
        &self,
        stack: &mut Vec<BFieldElement>,
        _std_in: Vec<BFieldElement>,
        _secret_in: Vec<BFieldElement>,
        memory: &mut HashMap<BFieldElement, BFieldElement>,
    ) where
        Self: Sized,
    {
        // `supremum` is on top of the stack, `minimum` directly below it.
        let supremum: u32 = stack.pop().unwrap().value().try_into().unwrap();
        let minimum: u32 = stack.pop().unwrap().value().try_into().unwrap();

        // A plain `-` would wrap silently when overflow checks are disabled;
        // fail loudly instead, in every build profile.
        let length = supremum
            .checked_sub(minimum)
            .expect("minimum larger than supremum");
        let length: usize = length.try_into().unwrap();

        let list_pointer = rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
        rust_shadowing_helper_functions::list::list_new(list_pointer, memory);
        rust_shadowing_helper_functions::list::list_set_length(list_pointer, length, memory);

        // Elements are stored right after the list's metadata.
        let first_element = list_pointer + BFieldElement::new(LIST_METADATA_SIZE as u64);
        for (offset, value) in (minimum..supremum).enumerate() {
            memory.insert(
                first_element + BFieldElement::new(offset as u64),
                BFieldElement::new(u64::from(value)),
            );
        }

        stack.push(list_pointer);
    }
}
#[cfg(test)]
mod tests {
    use triton_vm::error::InstructionError;

    use super::*;
    use crate::execute_with_terminal_state;
    use crate::test_helpers::test_rust_equivalence_multiple_deprecated;

    /// Runs the equivalence check between the assembly and its Rust shadow.
    #[test]
    fn new_snippet_test() {
        test_rust_equivalence_multiple_deprecated(&Range, true);
    }

    /// A range whose minimum exceeds its supremum must fail an assertion.
    #[test]
    fn bad_range_test() {
        let inverted_range = Range::init_state(13, 12);
        let code = Range.link_for_isolated_run();
        let program = Program::new(&code);

        let outcome = execute_with_terminal_state(
            program,
            &inverted_range.public_input,
            &inverted_range.stack,
            &NonDeterminism::default(),
            None,
        );

        assert!(matches!(
            outcome.unwrap_err(),
            InstructionError::AssertionFailed(_)
        ));
    }
}
#[cfg(test)]
mod benches {
    use super::*;
    use crate::snippet_bencher::bench_and_write;

    /// Passes the snippet to `bench_and_write`.
    #[test]
    fn benchmark() {
        let snippet = Range;
        bench_and_write(snippet);
    }
}