snarkvm_fields/traits/
poseidon_grain_lfsr.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#![allow(dead_code)]
17
18use crate::{FieldParameters, PrimeField};
19use snarkvm_utilities::{FromBits, vec::Vec};
20
21use anyhow::{Result, bail};
22
/// Linear-feedback shift register used to derive Poseidon parameters.
///
/// The register holds 80 bits in a ring buffer. It is seeded from the
/// permutation's description by [`PoseidonGrainLFSR::new`] and then advanced
/// one bit at a time to produce field elements.
#[derive(Clone, Debug)]
pub struct PoseidonGrainLFSR {
    /// Number of bits drawn per candidate field element. The sampling methods
    /// reject any field whose modulus bit size differs from this value.
    pub field_size_in_bits: u64,
    /// The 80 register bits, stored as a ring buffer indexed relative to `head`.
    pub state: [bool; 80],
    /// Index in `state` of the slot that the next step overwrites; it advances
    /// by one (modulo 80) on every step.
    pub head: usize,
}
28
29impl PoseidonGrainLFSR {
30    pub fn new(
31        is_sbox_an_inverse: bool,
32        field_size_in_bits: u64,
33        state_len: u64,
34        num_full_rounds: u64,
35        num_partial_rounds: u64,
36    ) -> Self {
37        let mut state = [false; 80];
38
39        // b0, b1 describes the field
40        state[1] = true;
41
42        // b2, ..., b5 describes the S-BOX
43        state[5] = is_sbox_an_inverse;
44
45        // b6, ..., b17 are the binary representation of n (prime_num_bits)
46        {
47            let mut cur = field_size_in_bits;
48            for i in (6..=17).rev() {
49                state[i] = cur & 1 == 1;
50                cur >>= 1;
51            }
52        }
53
54        // b18, ..., b29 are the binary representation of t (state_len, rate + capacity)
55        {
56            let mut cur = state_len;
57            for i in (18..=29).rev() {
58                state[i] = cur & 1 == 1;
59                cur >>= 1;
60            }
61        }
62
63        // b30, ..., b39 are the binary representation of R_F (the number of full rounds)
64        {
65            let mut cur = num_full_rounds;
66            for i in (30..=39).rev() {
67                state[i] = cur & 1 == 1;
68                cur >>= 1;
69            }
70        }
71
72        // b40, ..., b49 are the binary representation of R_P (the number of partial rounds)
73        {
74            let mut cur = num_partial_rounds;
75            for i in (40..=49).rev() {
76                state[i] = cur & 1 == 1;
77                cur >>= 1;
78            }
79        }
80
81        // b50, ..., b79 are set to 1
82        state[50..=79].copy_from_slice(&[true; 30]);
83
84        // Initialize.
85        let mut res = Self { field_size_in_bits, state, head: 0 };
86        for _ in 0..160 {
87            res.next_bit();
88        }
89        res
90    }
91
92    pub fn get_field_elements_rejection_sampling<F: PrimeField>(&mut self, num_elements: usize) -> Result<Vec<F>> {
93        // Ensure the number of bits matches the modulus.
94        if self.field_size_in_bits != F::Parameters::MODULUS_BITS as u64 {
95            bail!("The number of bits in the field must match the modulus");
96        }
97
98        let mut output = Vec::with_capacity(num_elements);
99        let mut bits = Vec::with_capacity(self.field_size_in_bits as usize);
100        for _ in 0..num_elements {
101            // Perform rejection sampling.
102            loop {
103                // Obtain `n` bits and make it most-significant-bit first.
104                bits.extend(self.get_bits(self.field_size_in_bits as usize));
105                bits.reverse();
106                // Construct the number.
107                let bigint = F::BigInteger::from_bits_le(&bits)?;
108                bits.clear();
109                // Ensure the number is in the field.
110                if let Some(element) = F::from_bigint(bigint) {
111                    output.push(element);
112                    break;
113                }
114            }
115        }
116        Ok(output)
117    }
118
119    pub fn get_field_elements_mod_p<F: PrimeField>(&mut self, num_elems: usize) -> Result<Vec<F>> {
120        // Ensure the number of bits matches the modulus.
121        let num_bits = self.field_size_in_bits;
122        if num_bits != F::Parameters::MODULUS_BITS as u64 {
123            bail!("The number of bits in the field must match the modulus");
124        }
125
126        // Prepare reusable vectors for the intermediate bits and bytes.
127        let mut bits = Vec::with_capacity(num_bits as usize);
128        let mut bytes = Vec::with_capacity((num_bits as usize + 7) / 8);
129
130        let mut output = Vec::with_capacity(num_elems);
131        for _ in 0..num_elems {
132            // Obtain `n` bits and make it most-significant-bit first.
133            let bits_iter = self.get_bits(num_bits as usize);
134            for bit in bits_iter {
135                bits.push(bit);
136            }
137            bits.reverse();
138
139            for byte in bits
140                .chunks(8)
141                .map(|chunk| {
142                    let mut sum = chunk[0] as u8;
143                    let mut cur = 1;
144                    for i in chunk.iter().skip(1) {
145                        cur *= 2;
146                        sum += cur * (*i as u8);
147                    }
148                    sum
149                })
150                .rev()
151            {
152                bytes.push(byte);
153            }
154
155            output.push(F::from_bytes_be_mod_order(&bytes));
156
157            // Clear the vectors of bits and bytes so they can be reused
158            // in the next iteration.
159            bits.clear();
160            bytes.clear();
161        }
162        Ok(output)
163    }
164}
165
166impl PoseidonGrainLFSR {
167    #[inline]
168    fn get_bits(&mut self, num_bits: usize) -> LFSRIter<'_> {
169        LFSRIter { lfsr: self, num_bits, current_bit: 0 }
170    }
171
172    #[inline]
173    fn next_bit(&mut self) -> bool {
174        let next_bit = self.state[(self.head + 62) % 80]
175            ^ self.state[(self.head + 51) % 80]
176            ^ self.state[(self.head + 38) % 80]
177            ^ self.state[(self.head + 23) % 80]
178            ^ self.state[(self.head + 13) % 80]
179            ^ self.state[self.head];
180        self.state[self.head] = next_bit;
181        self.head += 1;
182        self.head %= 80;
183
184        next_bit
185    }
186}
187
188pub struct LFSRIter<'a> {
189    lfsr: &'a mut PoseidonGrainLFSR,
190    num_bits: usize,
191    current_bit: usize,
192}
193
194impl Iterator for LFSRIter<'_> {
195    type Item = bool;
196
197    fn next(&mut self) -> Option<Self::Item> {
198        if self.current_bit < self.num_bits {
199            // Obtain the first bit
200            let mut new_bit = self.lfsr.next_bit();
201
202            // Loop until the first bit is true
203            while !new_bit {
204                // Discard the second bit
205                let _ = self.lfsr.next_bit();
206                // Obtain another first bit
207                new_bit = self.lfsr.next_bit();
208            }
209            self.current_bit += 1;
210
211            // Obtain the second bit
212            Some(self.lfsr.next_bit())
213        } else {
214            None
215        }
216    }
217}
218
219impl ExactSizeIterator for LFSRIter<'_> {
220    fn len(&self) -> usize {
221        self.num_bits
222    }
223}