// tasm_lib/verifier/fri/derive_from_stark.rs
use triton_vm::prelude::LabelledInstruction;
use triton_vm::prelude::*;
use triton_vm::stark::Stark;
use crate::arithmetic::bfe::primitive_root_of_unity::PrimitiveRootOfUnity;
use crate::arithmetic::u32::next_power_of_two::NextPowerOfTwo;
use crate::data_type::DataType;
use crate::library::Library;
use crate::memory::dyn_malloc::DynMalloc;
use crate::traits::basic_snippet::BasicSnippet;
use crate::verifier::fri::verify::fri_verify_type;
/// Snippet that derives FRI parameters from a fixed [`Stark`].
///
/// The generated code consumes a `u32` padded height from the stack and leaves
/// a pointer to a struct of type `fri_verify_type()` in its place.
pub struct DeriveFriFromStark {
/// STARK parameters whose values are emitted as constants (`push`
/// instructions) in the generated code.
pub stark: Stark,
}
impl DeriveFriFromStark {
    /// Builds the instruction sequence that turns a padded height into the five
    /// values later written to memory by [`BasicSnippet::code`].
    ///
    /// Stack effect of the returned code, assuming `PrimitiveRootOfUnity`
    /// replaces the two `u32` limbs of its argument with a single element:
    ///
    /// ```text
    /// BEFORE: _ padded_height
    /// AFTER:  _ expansion_factor num_collinearity_checks fri_domain_length domain_offset domain_generator
    /// ```
    ///
    /// As a side effect, the `NextPowerOfTwo` and `PrimitiveRootOfUnity`
    /// snippets are imported into `library`, in that order.
    fn derive_fri_field_values(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let round_up_to_power_of_two = library.import(Box::new(NextPowerOfTwo));
        let root_of_unity = library.import(Box::new(PrimitiveRootOfUnity));

        // All STARK parameters are compile-time constants of the generated code.
        let stark = &self.stark;
        let trace_randomizers = stark.num_trace_randomizers;
        let expansion_factor = stark.fri_expansion_factor;
        let collinearity_checks = stark.num_collinearity_checks;
        let offset = BFieldElement::generator();

        triton_asm!(
            // _ padded_height

            push {trace_randomizers}
            add
            call {round_up_to_power_of_two}
            // _ interpolant_codeword_length

            push {expansion_factor}
            mul
            // _ fri_domain_length

            push {collinearity_checks}
            push {expansion_factor}
            swap 2
            // _ expansion_factor num_collinearity_checks fri_domain_length

            push {offset}
            dup 1
            split
            // _ expansion_factor num_collinearity_checks fri_domain_length domain_offset hi lo

            call {root_of_unity}
            // _ expansion_factor num_collinearity_checks fri_domain_length domain_offset domain_generator
        )
    }
}
impl BasicSnippet for DeriveFriFromStark {
    /// The snippet expects a single `u32`, the padded height, on top of the stack.
    fn inputs(&self) -> Vec<(DataType, String)> {
        let padded_height = (DataType::U32, String::from("padded_height"));

        vec![padded_height]
    }

    /// The snippet leaves a single struct reference of type `fri_verify_type()`.
    fn outputs(&self) -> Vec<(DataType, String)> {
        let fri_verify_reference = DataType::StructRef(fri_verify_type());

        vec![(fri_verify_reference, String::from("*fri_verify"))]
    }

    /// Label under which the generated code is placed.
    fn entrypoint(&self) -> String {
        String::from("tasmlib_verifier_fri_derive_from_stark")
    }

    /// Generates the snippet body: derive the five field values, request memory
    /// from `DynMalloc`, write the values there, and return the pointer.
    ///
    /// Imports into `library` happen in a fixed order: first everything needed
    /// by `derive_fri_field_values`, then `DynMalloc`.
    fn code(&self, library: &mut crate::library::Library) -> Vec<LabelledInstruction> {
        let field_values = self.derive_fri_field_values(library);
        let allocate = library.import(Box::new(DynMalloc));
        let label = self.entrypoint();

        triton_asm!(
            // BEFORE: _ padded_height
            // AFTER:  _ *fri_verify
            {label}:
                {&field_values}
                // _ [five derived values]

                call {allocate}
                // _ [five derived values] *fri_verify

                write_mem 5
                // Five words were written; the remaining pointer is shifted by 5
                // (per Triton VM's `write_mem`), which the next two instructions undo.

                push -5
                add
                // _ *fri_verify

                return
        )
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use proptest_arbitrary_interop::arb;
    use rand::prelude::*;
    use test_strategy::proptest;

    use super::*;
    use crate::memory::encode_to_memory;
    use crate::rust_shadowing_helper_functions;
    use crate::snippet_bencher::BenchmarkCase;
    use crate::traits::function::Function;
    use crate::traits::function::FunctionInitialState;
    use crate::traits::function::ShadowedFunction;
    use crate::traits::rust_shadow::RustShadow;
    use crate::verifier::fri::verify::FriVerify;

    /// Compares the snippet against its Rust shadow for the default STARK parameters.
    #[test]
    fn fri_param_derivation_default_stark_pbt() {
        let snippet = DeriveFriFromStark {
            stark: Stark::default(),
        };
        ShadowedFunction::new(snippet).test();
    }

    /// Compares the snippet against its Rust shadow for arbitrary STARK parameters.
    #[proptest(cases = 10)]
    fn fri_param_derivation_pbt_pbt(#[strategy(arb())] stark: Stark) {
        let snippet = DeriveFriFromStark { stark };
        ShadowedFunction::new(snippet).test();
    }

    impl Function for DeriveFriFromStark {
        /// Reference implementation: pops the padded height, derives the FRI
        /// parameters via `Stark::fri`, converts them to `FriVerify`, encodes
        /// that into freshly allocated memory, and pushes the pointer.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let top_of_stack = stack.pop().unwrap();
            let height_as_u32: u32 = top_of_stack.try_into().unwrap();
            let height: usize = height_as_u32.try_into().unwrap();

            let reference_fri = self.stark.fri(height).unwrap();
            let expected: FriVerify = reference_fri.into();

            let pointer = rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
            encode_to_memory(memory, pointer, &expected);
            stack.push(pointer);
        }

        /// Produces an initial state whose stack ends with a power-of-two padded
        /// height: fixed for the benchmark cases, seeded-random otherwise.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let padded_height: u32 = match bench_case {
                Some(BenchmarkCase::CommonCase) => 1 << 21,
                Some(BenchmarkCase::WorstCase) => 1 << 23,
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let log2_height: u32 = rng.gen_range(8..=25);
                    let mut padded_height = 2u32.pow(log2_height);

                    // Halve the candidate until `Stark::fri` accepts twice its value.
                    loop {
                        let doubled_height = padded_height as usize * 2;
                        if self.stark.fri(doubled_height).is_ok() {
                            break;
                        }
                        padded_height /= 2;
                    }

                    assert!(padded_height >= 2u32.pow(8));
                    padded_height
                }
            };

            let mut stack = self.init_stack_for_isolated_run();
            stack.push(padded_height.into());

            FunctionInitialState {
                stack,
                memory: HashMap::default(),
            }
        }
    }
}