tasm_lib/verifier/fri/
collinearity_check_x.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
use triton_vm::prelude::*;

use crate::data_type::DataType;
use crate::field;
use crate::traits::basic_snippet::BasicSnippet;
use crate::verifier::fri::verify::FriVerify;

/// Compute `domain[index]^(1 << round)`, i.e.
/// `(domain_offset * domain_generator^index)^(2^round)`, where the domain offset and
/// generator are read from a `FriVerify` object in memory.
///
/// Stack signature: `_ *fri_verify index round` -> `_ x2 x1 x0`. The result is a
/// base-field value returned as an extension-field element whose upper two
/// coefficients (`x2`, `x1`) are zero.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct GetCollinearityCheckX;

impl BasicSnippet for GetCollinearityCheckX {
    fn inputs(&self) -> Vec<(DataType, String)> {
        vec![
            (DataType::VoidPointer, "*fri_verify".to_string()),
            (DataType::U32, "index".to_string()),
            (DataType::U32, "round".to_string()),
        ]
    }

    /// The single output is `(domain_offset * domain_generator^index)^(2^round)`,
    /// returned as an extension-field element whose two upper coefficients are zero.
    fn outputs(&self) -> Vec<(DataType, String)> {
        // Label the value for what it is: the x-coordinate used by the collinearity check.
        vec![(DataType::Xfe, "collinearity_check_x".to_string())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_verifier_collinearity_check_x".to_string()
    }

    fn code(&self, _library: &mut crate::library::Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        // Instruction sequences that turn `*fri_verify` (on top of the stack) into a
        // pointer to the respective field.
        let domain_offset = field!(FriVerify::domain_offset);
        let domain_generator = field!(FriVerify::domain_generator);

        // `pow` raises st0 to the power st1: `_ exponent base` -> `_ base^exponent`.
        // NOTE(review): the second `pow` uses 2^round as its exponent, so this presumably
        // requires round < 32 for the exponent to be a valid u32 — confirm against callers.
        triton_asm! {
            // BEFORE: _ *fri_verify index round
            // AFTER:  _ x2 x1 x0
            //         with x0 = (domain_offset * domain_generator^index)^(2^round), x1 = x2 = 0
            {entrypoint}:
                // domain_generator^index
                dup 2               // _ *fri_verify index round *fri_verify
                {&domain_generator} // _ *fri_verify index round *domain_generator
                read_mem 1 pop 1    // _ *fri_verify index round domain_generator
                dup 2               // _ *fri_verify index round domain_generator index
                swap 1 pow          // _ *fri_verify index round domain_generator^index

                // domain[index] = domain_offset * domain_generator^index
                dup 3               // _ *fri_verify index round domain_generator^index *fri_verify
                {&domain_offset}    // _ *fri_verify index round domain_generator^index *domain_offset
                read_mem 1 pop 1    // _ *fri_verify index round domain_generator^index domain_offset
                mul                 // _ *fri_verify index round domain[index]

                // domain[index]^(2^round)
                dup 1 push 2 pow    // _ *fri_verify index round domain[index] 2^round
                swap 1 pow          // _ *fri_verify index round domain[index]^(2^round)

                // drop the arguments, then lift the base-field result into an Xfe
                swap 3 pop 3        // _ x0
                push 0 push 0 swap 2
                                    // _ 0 0 x0
                return
        }
    }
}

#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use num_traits::Zero;
    use rand::prelude::*;
    use triton_vm::twenty_first::prelude::BFieldElement;

    use super::*;
    use crate::empty_stack;
    use crate::memory::encode_to_memory;
    use crate::snippet_bencher::BenchmarkCase;
    use crate::structure::tasm_object::TasmObject;
    use crate::traits::function::Function;
    use crate::traits::function::FunctionInitialState;
    use crate::traits::function::ShadowedFunction;
    use crate::traits::rust_shadow::RustShadow;

    impl Function for GetCollinearityCheckX {
        /// Reference implementation: decodes the `FriVerify` object from memory and
        /// delegates to `FriVerify::get_collinearity_check_x`.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            // Arguments come off the stack in the reverse order they were pushed.
            let round = stack.pop().unwrap().value() as usize;
            let index = stack.pop().unwrap().value() as u32;
            let fri_verify_address = stack.pop().unwrap();

            let fri_verify = FriVerify::decode_from_memory(memory, fri_verify_address).unwrap();
            let x = fri_verify.get_collinearity_check_x(index, round);

            // Highest coefficient is pushed first, so coefficient 0 ends up on top.
            stack.extend(x.coefficients.iter().rev().copied());
        }

        /// Builds a `FriVerify` object in memory and the matching argument stack.
        /// Benchmarks use fixed parameters; otherwise `round` is random in `0..10` and the
        /// domain length is `2^(round + k)` for a random `k` in `0..5`.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng: StdRng = SeedableRng::from_seed(seed);

            let (round, fri_domain_length) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (10, 1 << 20),
                Some(BenchmarkCase::WorstCase) => (20, 1 << 25),
                None => {
                    let round = rng.gen_range(0..10);
                    let log_domain_length = rng.gen_range(0..5) + round;
                    (round, 1 << log_domain_length)
                }
            };
            let index = rng.gen_range(0..fri_domain_length);
            let fri_verify = FriVerify::new(rng.gen(), fri_domain_length, 4, 40);

            let fri_verify_address = BFieldElement::zero();
            let mut memory = HashMap::<BFieldElement, BFieldElement>::new();
            encode_to_memory(&mut memory, fri_verify_address, &fri_verify);

            let mut stack = empty_stack();
            stack.extend([
                fri_verify_address,
                BFieldElement::new(index as u64),
                BFieldElement::new(round as u64),
            ]);

            FunctionInitialState { stack, memory }
        }
    }

    #[test]
    fn test() {
        let snippet = ShadowedFunction::new(GetCollinearityCheckX);
        snippet.test();
    }
}

#[cfg(test)]
mod bench {
    use super::GetCollinearityCheckX;
    use crate::traits::function::ShadowedFunction;
    use crate::traits::rust_shadow::RustShadow;

    /// Runs the benchmark harness for this snippet.
    #[test]
    fn bench() {
        let snippet = ShadowedFunction::new(GetCollinearityCheckX);
        snippet.bench();
    }
}