// tasm_lib/traits/deprecated_snippet.rs
use std::collections::HashMap;
use anyhow::Result;
use triton_vm::isa::parser::to_labelled_instructions;
use triton_vm::isa::parser::tokenize;
use triton_vm::prelude::*;
use super::basic_snippet::BasicSnippet;
use crate::data_type::DataType;
use crate::execute_bench_deprecated;
use crate::execute_test;
use crate::execute_with_terminal_state;
use crate::library::Library;
use crate::snippet_bencher::BenchmarkResult;
use crate::InitVmState;
/// Legacy snippet interface: a piece of Triton assembly described by separate
/// lists of field names and data types, plus provided helpers that link the
/// snippet into a standalone program and run it for tests and benchmarks.
///
/// Every sized `DeprecatedSnippet` is also a [`BasicSnippet`] through the
/// blanket implementation in this file.
pub trait DeprecatedSnippet {
    /// The label that [`link_for_isolated_run`](Self::link_for_isolated_run)
    /// `call`s to enter the snippet.
    fn entrypoint_name(&self) -> String;

    /// Names of the snippet's inputs.
    ///
    /// The `*_from_state_*` runners use the *number* of names as the number of
    /// stack words occupied by the inputs.
    fn input_field_names(&self) -> Vec<String>;

    /// Data types of the inputs. Their summed `stack_size()` enters the
    /// stack-diff consistency check performed by the runners.
    fn input_types(&self) -> Vec<DataType>;

    /// Names of the snippet's outputs.
    ///
    /// The `*_from_state_*` runners use the *number* of names as the number of
    /// stack words occupied by the outputs.
    fn output_field_names(&self) -> Vec<String>;

    /// Data types of the outputs. Their summed `stack_size()` enters the
    /// stack-diff consistency check performed by the runners.
    fn output_types(&self) -> Vec<DataType>;

    /// Declared net change of the stack height. The runners require this to
    /// equal the stack size of the outputs minus the stack size of the inputs.
    fn stack_diff(&self) -> isize;

    /// The snippet's code as text.
    ///
    /// `library` is passed mutably; whatever it afterwards reports through
    /// `all_imports()` is linked in by
    /// [`link_for_isolated_run`](Self::link_for_isolated_run).
    fn function_code(&self, library: &mut Library) -> String;

    /// Descriptions of the conditions under which the snippet crashes.
    /// NOTE(review): not consumed anywhere in this file; presumably
    /// human-readable documentation only.
    fn crash_conditions(&self) -> Vec<String>;

    /// Initial VM states to exercise the snippet with.
    fn gen_input_states(&self) -> Vec<InitVmState>;

    /// NOTE(review): not consumed in this file; presumably the representative
    /// input used for benchmarking. Confirm against the bench helpers.
    fn common_case_input_state(&self) -> InitVmState;

    /// NOTE(review): not consumed in this file; presumably the most expensive
    /// input used for benchmarking. Confirm against the bench helpers.
    fn worst_case_input_state(&self) -> InitVmState;

    /// Tokenizes the text returned by [`function_code`](Self::function_code)
    /// and converts it into labelled instructions.
    ///
    /// # Panics
    ///
    /// Panics if the code cannot be tokenized.
    fn function_code_as_instructions(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let f_body = self.function_code(library);
        let (_, tokens) = tokenize(&f_body).unwrap();
        to_labelled_instructions(&tokens)
    }

    /// Rust counterpart of the snippet. It receives mutable access to `stack`
    /// and `memory`, and owned copies of the public and secret input.
    fn rust_shadowing(
        &self,
        stack: &mut Vec<BFieldElement>,
        std_in: Vec<BFieldElement>,
        secret_in: Vec<BFieldElement>,
        memory: &mut HashMap<BFieldElement, BFieldElement>,
    );

    /// Builds a self-contained program: a call to the snippet's entrypoint
    /// followed by `halt`, then the snippet's own code, then all code the
    /// snippet imported into a fresh [`Library`].
    fn link_for_isolated_run(&self) -> Vec<LabelledInstruction> {
        let mut snippet_state = Library::new();
        let entrypoint = self.entrypoint_name();
        let mut function_body = self.function_code(&mut snippet_state);
        function_body.push('\n');
        let library_code = snippet_state.all_imports();

        let code = triton_asm!(
            call {entrypoint}
            halt
            {function_body}
            {&library_code}
        );

        // The program itself is discarded; it is only constructed here.
        // NOTE(review): presumably this is a sanity check that panics if the
        // linked code is malformed. Confirm against `Program::new`.
        let _program = Program::new(&code);

        code
    }

    /// Links the snippet for an isolated run and executes it.
    ///
    /// The program is executed through `execute_with_terminal_state` to obtain
    /// the returned final VM state, and then passed to `execute_test`, which
    /// is given mutable access to `stack`.
    /// NOTE(review): `execute_test` presumably performs its own checks and
    /// leaves the post-execution stack in `stack`. Confirm.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `execute_with_terminal_state`; in that
    /// case `execute_test` is not invoked.
    ///
    /// # Panics
    ///
    /// Panics if the declared stack diff disagrees with the input and output
    /// types.
    fn link_and_run_tasm_for_test(
        &self,
        stack: &mut Vec<BFieldElement>,
        std_in: Vec<BFieldElement>,
        nondeterminism: NonDeterminism,
    ) -> Result<VMState> {
        assert_declared_stack_diff(self);

        let code = self.link_for_isolated_run();
        let program = Program::new(&code);
        let final_state =
            execute_with_terminal_state(program, &std_in, stack, &nondeterminism, None)?;

        execute_test(
            &code,
            stack,
            Self::stack_diff(self),
            std_in,
            nondeterminism,
            None,
        );

        Ok(final_state)
    }

    /// Links the snippet for an isolated run and hands it to
    /// `execute_bench_deprecated`, returning that function's result.
    ///
    /// # Panics
    ///
    /// Panics if the declared stack diff disagrees with the input and output
    /// types.
    fn link_and_run_tasm_for_bench(
        &self,
        stack: &mut Vec<BFieldElement>,
        std_in: Vec<BFieldElement>,
        nondeterminism: NonDeterminism,
    ) -> Result<BenchmarkResult> {
        assert_declared_stack_diff(self);

        let code = self.link_for_isolated_run();
        execute_bench_deprecated(&code, stack, Self::stack_diff(self), std_in, nondeterminism)
    }

    /// Runs the snippet from `execution_state` via
    /// [`link_and_run_tasm_for_test`](Self::link_and_run_tasm_for_test) and
    /// checks that the stack below the snippet's arguments is unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the run fails, or if the stack words between the lowest
    /// `Digest::LEN` words and the snippet's inputs/outputs differ before and
    /// after the run.
    fn link_and_run_tasm_from_state_for_test(&self, execution_state: &mut InitVmState) -> VMState {
        let stack_prior = execution_state.stack.clone();

        // Surface a failed run before inspecting the stack. After a failure
        // the stack comparison is meaningless and would only hide the error.
        let final_state = self
            .link_and_run_tasm_for_test(
                &mut execution_state.stack,
                execution_state.public_input.clone(),
                execution_state.nondeterminism.to_owned(),
            )
            .unwrap();
        let stack_after = &execution_state.stack;

        // The lowest `Digest::LEN` words are excluded from the comparison.
        // NOTE(review): presumably that is where the program digest lives.
        assert_eq!(
            stack_prior[Digest::LEN..(stack_prior.len() - Self::input_field_names(self).len())],
            stack_after[Digest::LEN..(stack_after.len() - Self::output_field_names(self).len())]
        );

        final_state
    }

    /// Runs the snippet from `execution_state` via
    /// [`link_and_run_tasm_for_bench`](Self::link_and_run_tasm_for_bench) and
    /// checks that the stack below the snippet's arguments is unchanged.
    ///
    /// # Errors
    ///
    /// Returns the error of the benchmark run, without inspecting the stack.
    ///
    /// # Panics
    ///
    /// Panics if the stack words between the lowest `Digest::LEN` words and
    /// the snippet's inputs/outputs differ before and after a successful run.
    fn link_and_run_tasm_from_state_for_bench(
        &self,
        execution_state: &mut InitVmState,
    ) -> Result<BenchmarkResult> {
        let stack_prior = execution_state.stack.clone();

        // Propagate a failed run first; see the test variant for the reason.
        let benchmark_result = self.link_and_run_tasm_for_bench(
            &mut execution_state.stack,
            execution_state.public_input.clone(),
            execution_state.nondeterminism.to_owned(),
        )?;
        let stack_after = &execution_state.stack;

        assert_eq!(
            stack_prior[Digest::LEN..(stack_prior.len() - Self::input_field_names(self).len())],
            stack_after[Digest::LEN..(stack_after.len() - Self::output_field_names(self).len())]
        );

        Ok(benchmark_result)
    }
}

/// Panics unless the stack diff declared by `snippet` equals the summed stack
/// size of its output types minus the summed stack size of its input types.
fn assert_declared_stack_diff<S: DeprecatedSnippet + ?Sized>(snippet: &S) {
    let expected_length_prior: usize = snippet.input_types().iter().map(|x| x.stack_size()).sum();
    let expected_length_after: usize = snippet.output_types().iter().map(|x| x.stack_size()).sum();
    assert_eq!(
        DeprecatedSnippet::stack_diff(snippet),
        expected_length_after as isize - expected_length_prior as isize,
        "Declared stack diff must match type indicators"
    );
}
/// Adapter that lets every [`DeprecatedSnippet`] be used as a [`BasicSnippet`].
impl<S: DeprecatedSnippet> BasicSnippet for S {
    /// Pairs each input type with a label of the form `input_<name>`.
    ///
    /// The name is looked up at the running sum of the stack sizes of all
    /// preceding input types, so a type spanning several stack words skips
    /// over the names in between.
    ///
    /// # Panics
    ///
    /// Panics if that running index is out of bounds of the name list.
    fn inputs(&self) -> Vec<(DataType, String)> {
        let names = self.input_field_names();
        let mut name_offset = 0;

        self.input_types()
            .into_iter()
            .map(|data_type| {
                let label = format!("input_{}", names[name_offset]);
                name_offset += data_type.stack_size();
                (data_type, label)
            })
            .collect()
    }

    /// Pairs each output type with a label of the form `output_<name>`.
    ///
    /// The name is looked up at the running sum of the stack sizes of all
    /// preceding output types.
    ///
    /// # Panics
    ///
    /// Panics if that running index is out of bounds of the name list.
    fn outputs(&self) -> Vec<(DataType, String)> {
        let names = self.output_field_names();
        let mut name_offset = 0;

        self.output_types()
            .into_iter()
            .map(|data_type| {
                let label = format!("output_{}", names[name_offset]);
                name_offset += data_type.stack_size();
                (data_type, label)
            })
            .collect()
    }

    /// Forwards to [`DeprecatedSnippet::entrypoint_name`].
    fn entrypoint(&self) -> String {
        self.entrypoint_name()
    }

    /// Forwards to [`DeprecatedSnippet::function_code_as_instructions`].
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        self.function_code_as_instructions(library)
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use std::cell::RefCell;
    use std::rc::Rc;
    use std::str::FromStr;

    use itertools::Itertools;

    use super::*;
    use crate::arithmetic;
    use crate::test_helpers::test_rust_equivalence_given_execution_state_deprecated;
    use crate::traits::rust_shadow::RustShadow;
    use crate::VmHasher;

    /// Wraps a [`DeprecatedSnippet`] so that it can be driven through the
    /// [`RustShadow`] interface.
    pub(crate) struct DeprecatedSnippetWrapper<S: DeprecatedSnippet> {
        pub(crate) deprecated_snippet: S,
    }

    impl<S: DeprecatedSnippet> DeprecatedSnippetWrapper<S> {
        /// Wraps the given snippet.
        pub(crate) fn new(deprecated_snippet: S) -> Self {
            Self { deprecated_snippet }
        }
    }

    impl<S: DeprecatedSnippet + Clone + 'static> RustShadow for DeprecatedSnippetWrapper<S> {
        /// Hands out a fresh clone of the wrapped snippet behind `Rc<RefCell<_>>`.
        fn inner(&self) -> Rc<RefCell<dyn BasicSnippet>> {
            let snippet = self.deprecated_snippet.clone();
            Rc::new(RefCell::new(snippet))
        }

        /// Runs the wrapped snippet's `rust_shadowing` on a copy of `stack`
        /// and writes that copy back afterwards. The individual tokens of
        /// `nondeterminism` serve as secret input; the sponge is ignored.
        /// Always returns an empty vector.
        fn rust_shadow_wrapper(
            &self,
            stdin: &[BFieldElement],
            nondeterminism: &NonDeterminism,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
            _sponge: &mut Option<VmHasher>,
        ) -> Vec<BFieldElement> {
            let mut working_stack = stack.clone();
            let public_input = stdin.to_vec();
            let secret_input = nondeterminism.individual_tokens.clone();

            self.deprecated_snippet.rust_shadowing(
                &mut working_stack,
                public_input,
                secret_input,
                memory,
            );

            *stack = working_stack;
            Vec::new()
        }

        /// Checks the wrapped snippet against every state returned by its
        /// `gen_input_states`.
        fn test(&self) {
            for initial_state in self.deprecated_snippet.gen_input_states() {
                test_rust_equivalence_given_execution_state_deprecated(
                    &self.deprecated_snippet,
                    initial_state,
                );
            }
        }

        /// Benchmarking through the wrapper is not implemented.
        fn bench(&self) {
            todo!()
        }
    }

    /// The instruction form of an example snippet is non-empty and printable.
    #[test]
    fn can_return_code() {
        let mut library = Library::new();
        let instructions =
            arithmetic::u32::safeadd::Safeadd.function_code_as_instructions(&mut library);
        assert!(!instructions.is_empty());

        let listing = instructions
            .iter()
            .map(|instruction| instruction.to_string())
            .join("\n");
        println!("{}", listing);
    }

    /// `variant_name` yields the expected textual form for a few data types.
    #[test]
    fn data_type_string_rep() {
        let digest = DataType::Digest;
        assert_eq!("DataType::Digest", digest.variant_name());

        let tuple = DataType::Tuple(vec![DataType::Xfe, DataType::Bfe, DataType::Digest]);
        assert_eq!(
            "DataType::Tuple(vec![DataType::XFE, DataType::BFE, DataType::Digest])",
            tuple.variant_name()
        );

        let digest_list = DataType::List(Box::new(DataType::Digest));
        assert_eq!(
            "DataType::List(Box::new(DataType::Digest))",
            digest_list.variant_name()
        );
    }

    /// Parsing a data type's label-friendly name gives back the data type.
    #[test]
    fn parse_and_stringify_datatype_test() {
        let parsed_digest = DataType::from_str("digest").unwrap();
        assert_eq!(DataType::Digest, parsed_digest);

        let bfe_label = DataType::Bfe.label_friendly_name();
        assert_eq!(DataType::Bfe, DataType::from_str(&bfe_label).unwrap());

        for original in DataType::big_random_generatable_type_collection() {
            let label = original.label_friendly_name();
            let round_tripped = DataType::from_str(&label).unwrap();
            assert_eq!(original, round_tripped)
        }
    }
}