// tasm_lib/list/higher_order/zip.rs
use std::collections::HashMap;
use itertools::Itertools;
use rand::prelude::*;
use triton_vm::prelude::*;
use crate::data_type::DataType;
use crate::list::new::New;
use crate::list::LIST_METADATA_SIZE;
use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list;
use crate::traits::basic_snippet::BasicSnippet;
use crate::traits::function::*;
use crate::*;
/// Higher-order list snippet that pairs up the elements of two lists of equal length.
///
/// Given `*left_list` and `*right_list`, the snippet produces a new list whose `i`-th
/// element is the tuple `(left[i], right[i])`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Zip {
    /// Element type of the first (left) input list.
    pub left_type: DataType,
    /// Element type of the second (right) input list.
    pub right_type: DataType,
}

impl Zip {
    /// Creates a `Zip` snippet for lists with the given left and right element types.
    pub fn new(left_type: DataType, right_type: DataType) -> Self {
        Zip { left_type, right_type }
    }
}
impl BasicSnippet for Zip {
    /// Two list pointers: the left list on the stack below the right list.
    fn inputs(&self) -> Vec<(DataType, String)> {
        let list = |data_type| DataType::List(Box::new(data_type));
        let left_list = (list(self.left_type.clone()), "*left_list".to_string());
        let right_list = (list(self.right_type.clone()), "*right_list".to_string());
        vec![left_list, right_list]
    }

    /// A single pointer to a list of `(left_type, right_type)` tuples.
    fn outputs(&self) -> Vec<(DataType, String)> {
        let list = |data_type| DataType::List(Box::new(data_type));
        let tuple_type = DataType::Tuple(vec![self.left_type.clone(), self.right_type.clone()]);
        let output_list = (list(tuple_type), "*output_list".to_string());
        vec![output_list]
    }

    /// Label of the snippet, parameterized by both element types.
    fn entrypoint(&self) -> String {
        format!(
            "tasmlib_list_higher_order_u32_zip_{}_with_{}",
            self.left_type.label_friendly_name(),
            self.right_type.label_friendly_name()
        )
    }

    /// Generates the zip routine.
    ///
    /// The prologue reads and asserts equal lengths, allocates the output list, stores the
    /// length, and computes pointers to the last element of all three lists. The loop then
    /// walks the lists back to front, copying one `(left, right)` pair per iteration.
    ///
    /// # Panics
    ///
    /// Panics if the combined element size is so large that the loop would need to access
    /// an op-stack position deeper than the VM allows.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let output_type = DataType::Tuple(vec![self.left_type.clone(), self.right_type.clone()]);
        let new_output_list = library.import(Box::new(New::new(output_type.clone())));

        let entrypoint = self.entrypoint();
        let main_loop_label = format!("{entrypoint}_loop");

        let right_size = self.right_type.stack_size();
        let left_size = self.left_type.stack_size();
        let read_left_element = self.left_type.read_value_from_memory_leave_pointer();
        let read_right_element = self.right_type.read_value_from_memory_leave_pointer();
        let write_output_element = output_type.write_value_to_memory_leave_pointer();

        let left_size_plus_one = left_size + 1;
        let left_size_plus_three = left_size + 3;
        let sum_of_size = left_size + right_size;
        let sum_of_size_plus_two = sum_of_size + 2;

        // The deepest op-stack positions touched by the main loop are `swap {left_size + 3}`
        // and `swap {sum_of_size + 2}`. The op stack exposes positions st0 through
        // st(NUM_OP_STACK_REGISTERS - 1), so these arguments must be strictly below
        // `NUM_OP_STACK_REGISTERS`. A `<=` bound would let e.g. `swap 16` through, which
        // would only be rejected later when the generated instruction is parsed.
        let deepest_access = left_size_plus_three.max(sum_of_size_plus_two);
        assert!(
            deepest_access < NUM_OP_STACK_REGISTERS,
            "zip only works for an output element size less than or equal to the available op-stack words"
        );

        let minus_two_times_sum_of_size = -(2 * sum_of_size as i32);

        // Multiply the top of the stack by a compile-time constant, skipping trivial cases.
        let mul_with_size = |n| match n {
            0 => triton_asm!(pop 1 push 0),
            1 => triton_asm!(),
            n => triton_asm!(
                push {n}
                mul
            ),
        };

        // Loop invariant on entry:
        //   _ *left_list *left_cursor *right_cursor *output_cursor
        // `*left_cursor` / `*right_cursor` point at the last word of the next element to read
        // (reads move downwards); `*output_cursor` points at the first word of the next output
        // element to write (writes move upwards). The loop ends once `*left_cursor` has been
        // walked all the way down to `*left_list`, i.e. every element has been consumed.
        let main_loop = triton_asm!(
            {main_loop_label}:
                dup 3
                dup 3
                eq
                skiz return
                // _ *left_list *left_cursor *right_cursor *output_cursor

                dup 2
                {&read_left_element}
                // _ *left_list *left_cursor *right_cursor *output_cursor [left] *left_next
                swap {left_size_plus_three}
                pop 1
                // _ *left_list *left_next *right_cursor *output_cursor [left]

                dup {left_size_plus_one}
                {&read_right_element}
                // _ *left_list *left_next *right_cursor *output_cursor [left] [right] *right_next
                swap {sum_of_size_plus_two}
                pop 1
                // _ *left_list *left_next *right_next *output_cursor [left] [right]

                dup {sum_of_size}
                {&write_output_element}
                // _ *left_list *left_next *right_next *output_cursor (*output_cursor + size)
                push {minus_two_times_sum_of_size}
                add
                // _ *left_list *left_next *right_next *output_cursor (*output_cursor - size)
                swap 1
                pop 1
                // _ *left_list *left_next *right_next *output_next
                recurse
        );

        triton_asm!(
            // BEFORE: _ *left_list *right_list
            // AFTER:  _ *output_list
            {entrypoint}:
                // read both lengths and require them to be equal
                dup 1 read_mem 1 pop 1
                // _ *left_list *right_list left_len
                dup 1 read_mem 1 pop 1
                // _ *left_list *right_list left_len right_len
                dup 1 eq assert
                // _ *left_list *right_list len

                call {new_output_list}
                // _ *left_list *right_list len *output_list

                // store the length in the output list's metadata word
                dup 1
                swap 1
                write_mem 1
                // _ *left_list *right_list len (*output_list + 1)

                // pointer to the first word of the last output element
                // (field arithmetic, so `len == 0` is fine here)
                dup 1
                push -1
                add
                {&mul_with_size(sum_of_size)}
                add
                // _ *left_list *right_list len *output_last
                swap 2
                // _ *left_list *output_last len *right_list

                // pointer to the last word of the right list
                dup 1
                {&mul_with_size(right_size)}
                add
                // _ *left_list *output_last len *right_last
                swap 1
                // _ *left_list *output_last *right_last len

                // pointer to the last word of the left list
                {&mul_with_size(left_size)}
                dup 3
                add
                // _ *left_list *output_last *right_last *left_last
                swap 2
                // _ *left_list *left_last *right_last *output_last

                call {main_loop_label}
                // _ *left_list *left_list *right_list (*output_list + 1 - size)

                push {sum_of_size - 1}
                add
                // _ *left_list *left_list *right_list *output_list
                swap 3
                pop 3
                // _ *output_list
                return

            {&main_loop}
        )
    }
}
impl Function for Zip {
fn rust_shadow(
&self,
stack: &mut Vec<BFieldElement>,
memory: &mut HashMap<BFieldElement, BFieldElement>,
) {
use rust_shadowing_helper_functions::dyn_malloc;
use rust_shadowing_helper_functions::list;
let right_pointer = stack.pop().unwrap();
let left_pointer = stack.pop().unwrap();
let left_length = list::list_get_length(left_pointer, memory);
let right_length = list::list_get_length(right_pointer, memory);
assert_eq!(left_length, right_length);
let len = left_length;
let output_pointer = dyn_malloc::dynamic_allocator(memory);
list::list_new(output_pointer, memory);
list::list_set_length(output_pointer, len, memory);
for i in 0..len {
let left_item = list::list_get(left_pointer, i, memory, self.left_type.stack_size());
let right_item = list::list_get(right_pointer, i, memory, self.right_type.stack_size());
let pair = right_item.into_iter().chain(left_item).collect_vec();
list::list_set(output_pointer, i, pair, memory);
}
stack.push(output_pointer);
}
fn pseudorandom_initial_state(
&self,
seed: [u8; 32],
_bench_case: Option<snippet_bencher::BenchmarkCase>,
) -> FunctionInitialState {
let mut rng: StdRng = SeedableRng::from_seed(seed);
let list_len = rng.gen_range(0..20);
let execution_state = self.generate_input_state(list_len, list_len);
FunctionInitialState {
stack: execution_state.stack,
memory: execution_state.nondeterminism.ram,
}
}
}
impl Zip {
    /// Builds the initial VM state for the given list lengths.
    ///
    /// A random left list is placed at address 0, immediately followed by a random right
    /// list. The initial stack is the empty stack with both list pointers pushed on top,
    /// left first.
    fn generate_input_state(&self, left_length: usize, right_length: usize) -> InitVmState {
        let left_pointer = BFieldElement::new(0);
        let left_footprint = LIST_METADATA_SIZE + left_length * self.left_type.stack_size();
        let right_pointer = left_pointer + BFieldElement::new(left_footprint as u64);

        let mut memory = HashMap::default();
        let lists = [
            (&self.left_type, left_pointer, left_length),
            (&self.right_type, right_pointer, right_length),
        ];
        for (element_type, pointer, length) in lists {
            untyped_insert_random_list(pointer, length, &mut memory, element_type.stack_size());
        }

        let mut stack = empty_stack();
        stack.extend([left_pointer, right_pointer]);
        InitVmState::with_stack_and_memory(stack, memory)
    }
}
#[cfg(test)]
mod tests {
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use super::*;
    use crate::rust_shadowing_helper_functions::list;
    use crate::structure::tasm_object::MemoryIter;
    use crate::traits::function::ShadowedFunction;
    use crate::traits::rust_shadow::RustShadow;

    #[test]
    fn prop_test_xfe_digest() {
        let snippet = Zip::new(DataType::Xfe, DataType::Digest);
        ShadowedFunction::new(snippet).test();
    }

    #[test]
    fn list_prop_test_more_types() {
        let type_pairs = [
            (DataType::Bfe, DataType::Bfe),
            (DataType::U64, DataType::U32),
            (DataType::Bool, DataType::Digest),
            (DataType::U128, DataType::VoidPointer),
            (DataType::U128, DataType::Digest),
            (DataType::U128, DataType::U128),
            (DataType::Digest, DataType::Digest),
        ];
        for (left_type, right_type) in type_pairs {
            ShadowedFunction::new(Zip::new(left_type, right_type)).test();
        }
    }

    /// The Rust shadow's output list must match the `BFieldCodec` encoding of the
    /// corresponding Rust `Vec<(u32, XFieldElement)>`.
    #[proptest]
    fn zipping_u32s_with_x_field_elements_correspond_to_bfieldcodec(
        left_list: Vec<u32>,
        #[strategy(vec(arb(), #left_list.len()))] right_list: Vec<XFieldElement>,
    ) {
        let left_type = DataType::U32;
        let right_type = DataType::Xfe;
        let left_pointer = BFieldElement::new(0);
        let right_pointer = BFieldElement::new(1 << 60);

        let mut ram = HashMap::default();
        write_list_to_ram(&mut ram, left_pointer, &left_type, &left_list);
        write_list_to_ram(&mut ram, right_pointer, &right_type, &right_list);

        let mut stack = [empty_stack(), vec![left_pointer, right_pointer]].concat();
        Zip::new(left_type, right_type).rust_shadow(&mut stack, &mut ram);

        let expected_encoding = left_list
            .into_iter()
            .zip_eq(right_list)
            .collect_vec()
            .encode();
        let output_list_pointer = stack.pop().unwrap();
        let actual_words = MemoryIter::new(&ram, output_list_pointer)
            .take(expected_encoding.len())
            .collect_vec();
        prop_assert_eq!(expected_encoding, actual_words);
    }

    /// Stores `list` in `ram` as a tasm-lib list at `list_pointer`, pushing one element
    /// at a time.
    fn write_list_to_ram<T: BFieldCodec + Copy>(
        ram: &mut HashMap<BFieldElement, BFieldElement>,
        list_pointer: BFieldElement,
        item_type: &DataType,
        list: &[T],
    ) {
        list::list_new(list_pointer, ram);
        let item_size = item_type.stack_size();
        for item in list {
            list::list_push(list_pointer, item.encode(), ram, item_size);
        }
    }
}
#[cfg(test)]
mod benches {
    use super::*;
    use crate::traits::function::ShadowedFunction;
    use crate::traits::rust_shadow::RustShadow;

    /// Benchmarks zipping a list of extension-field elements with a list of digests.
    #[test]
    fn zip_benchmark() {
        let snippet = Zip::new(DataType::Xfe, DataType::Digest);
        ShadowedFunction::new(snippet).bench();
    }
}