// tasm_lib/verifier/fri/derive_from_stark.rs
use triton_vm::prelude::*;
use crate::arithmetic::bfe::primitive_root_of_unity::PrimitiveRootOfUnity;
use crate::arithmetic::u32::next_power_of_two::NextPowerOfTwo;
use crate::prelude::*;
use crate::verifier::fri::verify::fri_verify_type;
/// Snippet that derives the FRI parameters from a [`Stark`] and a padded height.
///
/// Follows Triton VM's own FRI parameter derivation, with one restriction: a
/// FRI domain of length 2^32 is not supported, because the domain length is
/// kept in a single word, i.e., a `u32`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct DeriveFriFromStark {
    /// The STARK parameters the FRI parameters are derived from.
    pub stark: Stark,
}
impl DeriveFriFromStark {
    /// Emits the instructions that compute all five FRI parameters.
    ///
    /// Expects the padded height on top of the stack and replaces it as follows:
    ///
    /// ```text
    /// BEFORE: _ padded_height
    /// AFTER:  _ expansion_factor num_collinearity_checks fri_domain_length domain_offset domain_generator
    /// ```
    ///
    /// The snippets for rounding up to the next power of two and for looking up
    /// a primitive root of unity are imported into `library`.
    fn derive_fri_field_values(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        // Both helper snippets are registered with the library up front.
        let round_up_to_pow2 = library.import(Box::new(NextPowerOfTwo));
        let root_of_unity = library.import(Box::new(PrimitiveRootOfUnity));

        // Compile-time constants that get baked into the generated code.
        let trace_randomizer_count = self.stark.num_trace_randomizers;
        let expansion_factor = self.stark.fri_expansion_factor;
        let collinearity_check_count = self.stark.num_collinearity_checks;
        let domain_offset = BFieldElement::generator();

        triton_asm!(
            // _ padded_height

            // interpolant codeword length:
            // next_pow2(padded_height + num_trace_randomizers)
            push {trace_randomizer_count}
            add
            call {round_up_to_pow2}
            // _ interpolant_codeword_length

            // FRI domain length:
            // interpolant_codeword_length * fri_expansion_factor
            push {expansion_factor}
            mul
            // _ fri_domain_length

            push {collinearity_check_count}
            push {expansion_factor}
            // _ fri_domain_length num_collinearity_checks expansion_factor

            swap 2
            // _ expansion_factor num_collinearity_checks fri_domain_length

            push {domain_offset}
            // _ expansion_factor num_collinearity_checks fri_domain_length domain_offset

            // the generator is looked up from a copy of the domain length,
            // which `split` turns into two words first
            dup 1
            split
            call {root_of_unity}
            // _ expansion_factor num_collinearity_checks fri_domain_length domain_offset domain_generator
        )
    }
}
impl BasicSnippet for DeriveFriFromStark {
    /// The snippet consumes a single `u32`: the padded height.
    fn inputs(&self) -> Vec<(DataType, String)> {
        let padded_height = (DataType::U32, String::from("padded_height"));

        vec![padded_height]
    }

    /// The snippet produces a single pointer to a struct of the type given by
    /// `fri_verify_type()`.
    fn outputs(&self) -> Vec<(DataType, String)> {
        let fri_verify_pointer = DataType::StructRef(fri_verify_type());

        vec![(fri_verify_pointer, String::from("*fri_verify"))]
    }

    fn entrypoint(&self) -> String {
        String::from("tasmlib_verifier_fri_derive_from_stark")
    }

    /// Derives the five FRI parameters from the padded height, stores them in
    /// dynamically allocated memory, and leaves the pointer to them on the stack.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        // The derivation registers its own dependencies before the allocator
        // is imported.
        let compute_fri_fields = self.derive_fri_field_values(library);
        let allocate = library.import(Box::new(DynMalloc));
        let entrypoint = self.entrypoint();

        triton_asm!(
            {entrypoint}:
                // _ padded_height

                {&compute_fri_fields}
                // _ expansion_factor num_collinearity_checks fri_domain_length domain_offset domain_generator

                call {allocate}
                // _ expansion_factor num_collinearity_checks fri_domain_length domain_offset domain_generator *fri_verify

                write_mem 5
                // _ (*fri_verify + 5)

                // step back to the start of the words just written
                push -5
                add
                // _ *fri_verify

                return
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rust_shadowing_helper_functions;
    use crate::test_prelude::*;
    use crate::verifier::fri::verify::FriVerify;

    /// Checks the snippet against its Rust shadow for the default STARK parameters.
    #[test]
    fn fri_param_derivation_default_stark_pbt() {
        let stark = Stark::default();
        ShadowedFunction::new(DeriveFriFromStark { stark }).test();
    }

    /// Checks the snippet against its Rust shadow for arbitrary STARK parameters.
    #[proptest(cases = 10)]
    fn fri_param_derivation_pbt_pbt(#[strategy(arb())] stark: Stark) {
        let snippet = DeriveFriFromStark { stark };
        ShadowedFunction::new(snippet).test();
    }

    impl Function for DeriveFriFromStark {
        /// Reference implementation: derives the FRI parameters through
        /// `Stark::fri`, encodes them at a freshly allocated address, and
        /// pushes that address onto the stack.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let top_of_stack = stack.pop().unwrap();
            let padded_height: u32 = top_of_stack.try_into().unwrap();
            let padded_height: usize = padded_height.try_into().unwrap();

            let expected_fri: FriVerify = self.stark.fri(padded_height).unwrap().into();

            let fri_pointer =
                rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
            encode_to_memory(memory, fri_pointer, &expected_fri);

            stack.push(fri_pointer);
        }

        /// Produces an initial state whose only argument is a padded height:
        /// fixed powers of two for the benchmark cases, a seeded random power
        /// of two otherwise.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let padded_height = match bench_case {
                Some(BenchmarkCase::CommonCase) => 1_u32 << 21,
                Some(BenchmarkCase::WorstCase) => 1_u32 << 23,
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let log2_padded_height: u32 = rng.random_range(8..=25);
                    let mut candidate = 2_u32.pow(log2_padded_height);

                    // Halve the candidate until the FRI derivation succeeds
                    // even for twice the candidate. This keeps the FRI domain
                    // strictly below 2^32: a domain of exactly 2^32 elements
                    // must also be ruled out, because this repo stores the
                    // domain length in a `u32`. FRI domains of that size are
                    // considered infeasible anyway, so excluding them is an
                    // acceptable restriction.
                    loop {
                        let doubled = candidate as usize * 2;
                        if self.stark.fri(doubled).is_ok() {
                            break;
                        }
                        candidate /= 2;
                    }
                    assert!(candidate >= 2_u32.pow(8));

                    candidate
                }
            };

            let mut stack = self.init_stack_for_isolated_run();
            stack.push(padded_height.into());
            let memory = HashMap::default();

            FunctionInitialState { stack, memory }
        }
    }
}