// tasm_lib/list/higher_order/all.rs
use itertools::Itertools;
use triton_vm::prelude::*;
use super::inner_function::InnerFunction;
use crate::list::get::Get;
use crate::list::length::Length;
use crate::prelude::*;
/// Higher-order list snippet: applies the predicate `f` to every element of a
/// list and reports whether the predicate held for all of them.
pub struct All {
    /// Predicate applied to each list element. Code generation asserts that
    /// its range is [`DataType::Bool`].
    pub f: InnerFunction,
}

impl All {
    /// Wraps the predicate `f` in an [`All`] snippet.
    pub fn new(f: InnerFunction) -> Self {
        All { f }
    }
}
impl BasicSnippet for All {
    /// One input: a list whose element type is the domain of the predicate.
    fn inputs(&self) -> Vec<(DataType, String)> {
        let list_type = DataType::List(Box::new(self.f.domain()));
        vec![(list_type, String::from("*input_list"))]
    }

    /// One output: a boolean.
    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Bool, String::from("all_true"))]
    }

    /// The label embeds the predicate's own entrypoint, so snippets built
    /// from differently named predicates get different labels.
    fn entrypoint(&self) -> String {
        let predicate_label = self.f.entrypoint();
        format!("tasmlib_list_higher_order_u32_all_{predicate_label}")
    }

    /// Emits the assembly: a counting-down loop that fetches each element,
    /// calls the predicate on it, and multiplies the outcome into a running
    /// result that starts at 1.
    ///
    /// # Panics
    ///
    /// Panics if the predicate's range is not [`DataType::Bool`], and (via
    /// `todo!()`) if the predicate is an `InnerFunction::NoFunctionBody`.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let element_type = self.f.domain();
        let predicate_range = self.f.range();
        assert_eq!(predicate_range, DataType::Bool);

        let list_length_label = library.import(Box::new(Length));
        let list_get_label = library.import(Box::new(Get::new(element_type)));

        // A raw-code predicate is not known to `library`, so its instructions
        // are appended verbatim after the loop. A snippet predicate is
        // registered with `library` instead and contributes no inline body.
        let (predicate_label, inlined_predicate_body) = match &self.f {
            InnerFunction::RawCode(raw_code) => {
                let label = raw_code.entrypoint();
                let body = raw_code
                    .function
                    .iter()
                    .map(|instruction| instruction.to_string())
                    .join("\n");
                (label, body)
            }
            InnerFunction::NoFunctionBody(_) => todo!(),
            InnerFunction::BasicSnippet(snippet) => {
                let snippet_code = snippet.annotated_code(library);
                let label = library.explicit_import(&snippet.entrypoint(), &snippet_code);
                (label, String::new())
            }
        };

        let entrypoint = self.entrypoint();
        let loop_label = format!("{entrypoint}_loop");
        let result_type_hint = format!("hint all_{}: Boolean = stack[0]", self.f.entrypoint());

        // NOTE(review): the stack annotations below assume `Length` maps
        // `_ *list` to `_ len` and `Get` maps `_ *list index` to `_ [element]`.
        triton_asm!(
            // BEFORE: _ *input_list
            // AFTER:  _ all_true
            {entrypoint}:
                hint input_list = stack[0]

                // Running result starts out as 1 (true).
                push 1
                {result_type_hint}
                // _ *input_list res

                swap 1
                dup 0
                call {list_length_label}
                hint list_item: Index = stack[0]
                // _ res *input_list len

                call {loop_label}
                // _ res *input_list 0

                pop 2
                return

            // INVARIANT: _ res *input_list index
            {loop_label}:
                // Stop once the index has counted down to zero.
                dup 0
                push 0
                eq
                skiz
                    return

                // Step to the next (lower) index.
                push -1
                add

                // Fetch the element at that index.
                dup 1
                dup 1
                call {list_get_label}
                // _ res *input_list index [element]

                call {predicate_label}
                // _ res *input_list index b

                // Fold the outcome into the running result.
                dup 3
                mul
                // _ res *input_list index (b * res)
                swap 3
                pop 1
                // _ (b * res) *input_list index

                recurse

            {inlined_predicate_body}
        )
    }
}
#[cfg(test)]
mod tests {
use num::One;
use num::Zero;
use super::*;
use crate::arithmetic;
use crate::empty_stack;
use crate::list::higher_order::inner_function::RawCode;
use crate::list::LIST_METADATA_SIZE;
use crate::rust_shadowing_helper_functions;
use crate::rust_shadowing_helper_functions::list::list_get;
use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list;
use crate::test_helpers::test_rust_equivalence_given_complete_state;
use crate::test_prelude::*;
impl All {
    /// Builds an initial VM state with `list_pointer` on top of the stack and
    /// a list of `list_length` elements written to memory at that address.
    ///
    /// With `random` set, the list is filled by `untyped_insert_random_list`
    /// using the stack size of the predicate's domain as element size.
    /// Otherwise the list holds the field elements `0..list_length`.
    fn generate_input_state(
        &self,
        list_pointer: BFieldElement,
        list_length: usize,
        random: bool,
    ) -> InitVmState {
        let element_size = self.f.domain().stack_size();
        let mut memory = HashMap::default();

        // Address 0 receives the address just past the list's metadata and
        // elements. This must be written before the list itself, because the
        // list may occupy address 0 and then takes precedence.
        // NOTE(review): presumably a free-pointer convention — confirm.
        let words_occupied = LIST_METADATA_SIZE + list_length * element_size;
        let address_past_list = list_pointer + BFieldElement::new(words_occupied as u64);
        memory.insert(BFieldElement::zero(), address_past_list);

        if !random {
            let ascending_elements = (0..list_length as u64)
                .map(BFieldElement::new)
                .collect_vec();
            rust_shadowing_helper_functions::list::list_insert(
                list_pointer,
                ascending_elements,
                &mut memory,
            );
        } else {
            untyped_insert_random_list(list_pointer, list_length, &mut memory, element_size);
        }

        let mut stack = empty_stack();
        stack.push(list_pointer);
        InitVmState::with_stack_and_memory(stack, memory)
    }
}
impl Function for All {
    /// Rust reference implementation: pops the list pointer, applies the
    /// predicate to every element in ascending index order, and pushes 1 if
    /// every outcome was non-zero, else 0.
    fn rust_shadow(
        &self,
        stack: &mut Vec<BFieldElement>,
        memory: &mut HashMap<BFieldElement, BFieldElement>,
    ) {
        let element_size = self.f.domain().stack_size();
        let list_pointer = stack.pop().unwrap();
        let list_length =
            rust_shadowing_helper_functions::list::list_get_length(list_pointer, memory);

        let mut all_true = true;
        for index in 0..list_length {
            // The element is pushed in reverse, so its first word ends up on
            // top of the stack.
            let element = list_get(list_pointer, index, memory, element_size);
            stack.extend(element.into_iter().rev());

            // The predicate runs on every element, even after a failure.
            self.f.apply(stack, memory);
            all_true &= stack.pop().unwrap().value() != 0;
        }

        stack.push(BFieldElement::new(all_true as u64));
    }

    /// Benchmark cases use a fixed pointer (5) and a deterministic list of 10
    /// (common) or 100 (worst) elements. Without a benchmark case, pointer and
    /// length (1, 2, 4, or 8) are drawn from `seed` and the list is random.
    fn pseudorandom_initial_state(
        &self,
        seed: [u8; 32],
        bench_case: Option<BenchmarkCase>,
    ) -> FunctionInitialState {
        let (list_pointer, list_length, random) = match bench_case {
            Some(BenchmarkCase::CommonCase) => (BFieldElement::new(5), 10, false),
            Some(BenchmarkCase::WorstCase) => (BFieldElement::new(5), 100, false),
            None => {
                let mut rng = StdRng::from_seed(seed);
                let pointer = BFieldElement::new(rng.next_u64() % (1 << 20));
                let length = 1 << (rng.next_u32() as usize % 4);
                (pointer, length, true)
            }
        };

        let execution_state = self.generate_input_state(list_pointer, list_length, random);
        FunctionInitialState {
            stack: execution_state.stack,
            memory: execution_state.nondeterminism.ram,
        }
    }
}
/// Checks assembly against the Rust shadow for a predicate supplied as a
/// `BasicSnippet`.
#[test]
fn rust_shadow() {
    let predicate = InnerFunction::BasicSnippet(Box::new(TestHashXFieldElementLsb));
    let all = All::new(predicate);
    ShadowedFunction::new(all).test();
}
/// Runs `All` with an "element < 2^31" predicate on two lists at the same
/// address: one where every element passes and one where some do not.
#[test]
fn all_lt_test() {
    const TWO_POW_31: u64 = 1u64 << 31;
    let list_pointer = BFieldElement::new(42);

    let predicate = RawCode::new(
        triton_asm!(
            less_than_2_pow_31:
                push 2147483648 // == 2^31
                swap 1
                lt
                return
        ),
        DataType::Bfe,
        DataType::Bool,
    );
    let shadowed_snippet = ShadowedFunction::new(All::new(InnerFunction::RawCode(predicate)));
    let input_stack = [empty_stack(), vec![list_pointer]].concat();

    // Case 1: the elements 0..30 are all below 2^31, so the result is 1.
    let mut memory = HashMap::new();
    let small_elements = (0..30).map(BFieldElement::new).collect_vec();
    rust_shadowing_helper_functions::list::list_insert(
        list_pointer,
        small_elements,
        &mut memory,
    );
    let mut nondeterminism = NonDeterminism::default().with_ram(memory);
    let expected_end_stack_true = [empty_stack(), vec![BFieldElement::one()]].concat();
    test_rust_equivalence_given_complete_state(
        &shadowed_snippet,
        &input_stack,
        &[],
        &nondeterminism,
        &None,
        Some(&expected_end_stack_true),
    );

    // Case 2: write a new list to the same address whose elements run from
    // 2^31 - 20 to 2^31 + 9. Some reach or exceed 2^31, so the result is 0.
    let straddling_elements = (0..30)
        .map(|x| BFieldElement::new(x + TWO_POW_31 - 20))
        .collect_vec();
    rust_shadowing_helper_functions::list::list_insert(
        list_pointer,
        straddling_elements,
        &mut nondeterminism.ram,
    );
    let expected_end_stack_false = [empty_stack(), vec![BFieldElement::zero()]].concat();
    test_rust_equivalence_given_complete_state(
        &shadowed_snippet,
        &input_stack,
        &[],
        &nondeterminism,
        &None,
        Some(&expected_end_stack_false),
    );
}
/// Checks assembly against the Rust shadow for a raw-code predicate that
/// maps a `Bfe` element to a `Bool`.
#[test]
fn test_with_raw_function_lsb_on_bfe() {
    let lsb_predicate = RawCode::new(
        triton_asm!(
            lsb_bfe:
                split
                push 2
                swap 1
                div_mod
                swap 2
                pop 2
                return
        ),
        DataType::Bfe,
        DataType::Bool,
    );
    let all = All::new(InnerFunction::RawCode(lsb_predicate));
    ShadowedFunction::new(all).test();
}
/// Checks assembly against the Rust shadow for a raw-code predicate that
/// compares a `U32` element with 42.
#[test]
fn test_with_raw_function_eq_42() {
    let equals_42 = RawCode::new(
        triton_asm!(
            eq_42:
                push 42
                eq
                return
        ),
        DataType::U32,
        DataType::Bool,
    );
    let all = All::new(InnerFunction::RawCode(equals_42));
    ShadowedFunction::new(all).test();
}
/// Checks assembly against the Rust shadow for a raw-code predicate that
/// maps an `Xfe` element to a `Bool`.
#[test]
fn test_with_raw_function_lsb_on_xfe() {
    let lsb_predicate = RawCode::new(
        triton_asm!(
            lsb_xfe:
                split
                push 2
                swap 1
                div_mod
                swap 4
                pop 4
                return
        ),
        DataType::Xfe,
        DataType::Bool,
    );
    let all = All::new(InnerFunction::RawCode(lsb_predicate));
    ShadowedFunction::new(all).test();
}
/// Test-only predicate from an `Xfe` to a `Bool`. It runs the input through
/// the sponge instructions and keeps a single remainder of a division by 2.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub(super) struct TestHashXFieldElementLsb;

impl BasicSnippet for TestHashXFieldElementLsb {
    fn inputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Xfe, String::from("element"))]
    }

    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Bool, String::from("bool"))]
    }

    fn entrypoint(&self) -> String {
        String::from("test_hash_xfield_element_lsb")
    }

    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();

        // The import is not needed for the result; it makes this snippet
        // depend on another library snippet.
        let unused_import = library.import(Box::new(arithmetic::u32::safe_add::SafeAdd));

        triton_asm!(
            {entrypoint}:
                // Call the imported snippet on two zeros and drop one word.
                // NOTE(review): presumably `SafeAdd` leaves exactly one word,
                // making this sequence stack-neutral — confirm.
                push 0
                push 0
                call {unused_import}
                pop 1

                // Push six zeros and a one, then bring the three words from
                // depth 9 to the top.
                push 0
                push 0
                push 0
                push 0
                push 0
                push 0
                push 1
                pick 9
                pick 9
                pick 9

                sponge_init
                sponge_absorb
                sponge_squeeze

                // Divide the low half of the top word by 2.
                split
                push 2
                place 1
                div_mod

                // Move the remainder below the next 11 words, then pop those.
                place 11
                pop 5
                pop 5
                pop 1
                return
        )
    }
}
}
#[cfg(test)]
mod benches {
    use super::tests::TestHashXFieldElementLsb;
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks [`All`] with the test-only hashing predicate as inner
    /// function.
    #[test]
    fn benchmark() {
        let predicate = InnerFunction::BasicSnippet(Box::new(TestHashXFieldElementLsb));
        let all = All::new(predicate);
        ShadowedFunction::new(all).bench();
    }
}