snarkvm_algorithms/crypto_hash/poseidon.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{AlgebraicSponge, DuplexSpongeMode, nonnative_params::*};
17use snarkvm_fields::{FieldParameters, PoseidonParameters, PrimeField, ToConstraintField};
18use snarkvm_utilities::{BigInteger, FromBits, ToBits};
19
20use smallvec::SmallVec;
21use std::{
22    iter::Peekable,
23    ops::{Index, IndexMut},
24    sync::Arc,
25};
26
/// The Poseidon sponge state, partitioned into a capacity section of `CAPACITY`
/// field elements followed by a rate section of `RATE` field elements.
#[derive(Copy, Clone, Debug)]
pub struct State<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
    // The capacity portion of the state; never directly exposed to absorbed or squeezed data.
    capacity_state: [F; CAPACITY],
    // The rate portion of the state; input is absorbed into, and output squeezed from, these elements.
    rate_state: [F; RATE],
}
32
33impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Default for State<F, RATE, CAPACITY> {
34    fn default() -> Self {
35        Self { capacity_state: [F::zero(); CAPACITY], rate_state: [F::zero(); RATE] }
36    }
37}
38
39impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> State<F, RATE, CAPACITY> {
40    /// Returns an immutable iterator over the state.
41    pub fn iter(&self) -> impl Iterator<Item = &F> + Clone {
42        self.capacity_state.iter().chain(self.rate_state.iter())
43    }
44
45    /// Returns a mutable iterator over the state.
46    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut F> {
47        self.capacity_state.iter_mut().chain(self.rate_state.iter_mut())
48    }
49}
50
51impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Index<usize> for State<F, RATE, CAPACITY> {
52    type Output = F;
53
54    fn index(&self, index: usize) -> &Self::Output {
55        assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
56        if index < CAPACITY { &self.capacity_state[index] } else { &self.rate_state[index - CAPACITY] }
57    }
58}
59
60impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> IndexMut<usize> for State<F, RATE, CAPACITY> {
61    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
62        assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
63        if index < CAPACITY { &mut self.capacity_state[index] } else { &mut self.rate_state[index - CAPACITY] }
64    }
65}
66
/// The Poseidon cryptographic hash function over native field elements,
/// instantiated with a capacity of 1.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Poseidon<F: PrimeField, const RATE: usize> {
    // The shared Poseidon parameters (round constants, MDS matrix, round counts).
    parameters: Arc<PoseidonParameters<F, RATE, 1>>,
}
71
72impl<F: PrimeField, const RATE: usize> Poseidon<F, RATE> {
73    /// Initializes a new instance of the cryptographic hash function.
74    pub fn setup() -> Self {
75        Self { parameters: Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) }
76    }
77
78    /// Evaluate the cryptographic hash function over a list of field elements as input.
79    pub fn evaluate(&self, input: &[F]) -> F {
80        self.evaluate_many(input, 1)[0]
81    }
82
83    /// Evaluate the cryptographic hash function over a list of field elements as input,
84    /// and returns the specified number of field elements as output.
85    pub fn evaluate_many(&self, input: &[F], num_outputs: usize) -> Vec<F> {
86        let mut sponge = PoseidonSponge::<F, RATE, 1>::new_with_parameters(&self.parameters);
87        sponge.absorb_native_field_elements(input);
88        sponge.squeeze_native_field_elements(num_outputs).to_vec()
89    }
90
91    /// Evaluate the cryptographic hash function over a non-fixed-length vector,
92    /// in which the length also needs to be hashed.
93    pub fn evaluate_with_len(&self, input: &[F]) -> F {
94        self.evaluate(&[vec![F::from(input.len() as u128)], input.to_vec()].concat())
95    }
96
97    pub fn parameters(&self) -> &Arc<PoseidonParameters<F, RATE, 1>> {
98        &self.parameters
99    }
100}
101
/// A duplex sponge based on the Poseidon permutation.
///
/// This implementation of Poseidon is entirely from Fractal's implementation in [COS20][cos]
/// with small syntax changes.
///
/// [cos]: https://eprint.iacr.org/2019/1076
#[derive(Clone, Debug)]
pub struct PoseidonSponge<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
    /// Sponge Parameters
    parameters: Arc<PoseidonParameters<F, RATE, CAPACITY>>,
    /// Current sponge's state (current elements in the permutation block)
    state: State<F, RATE, CAPACITY>,
    /// Current mode (whether its absorbing or squeezing)
    pub mode: DuplexSpongeMode,
    /// A persistent lookup table of powers of two (`2^0 .. 2^(size_in_bits - 2)`),
    /// used when compressing nonnative limb pairs in `compress_elements`.
    adjustment_factor_lookup_table: Arc<[F]>,
}
119
impl<F: PrimeField, const RATE: usize> AlgebraicSponge<F, RATE> for PoseidonSponge<F, RATE, 1> {
    type Parameters = Arc<PoseidonParameters<F, RATE, 1>>;

    /// Samples the default Poseidon parameters for this field and rate.
    // NOTE(review): panics if `F` has no default parameters for `RATE`, mirroring `Poseidon::setup`.
    fn sample_parameters() -> Self::Parameters {
        Arc::new(F::default_poseidon_parameters::<RATE>().unwrap())
    }

    /// Initializes a sponge in absorbing mode with a zeroed state.
    fn new_with_parameters(parameters: &Self::Parameters) -> Self {
        Self {
            parameters: parameters.clone(),
            state: State::default(),
            mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 },
            adjustment_factor_lookup_table: {
                // Precompute the powers of two `2^0 .. 2^(capacity - 1)`; these are the
                // adjustment factors used by `compress_elements` to shift a limb left.
                let capacity = F::size_in_bits() - 1;
                let mut table = Vec::<F>::with_capacity(capacity);

                let mut cur = F::one();
                for _ in 0..capacity {
                    table.push(cur);
                    cur.double_in_place();
                }

                table.into()
            },
        }
    }

    /// Takes in field elements.
    fn absorb_native_field_elements<T: ToConstraintField<F>>(&mut self, elements: &[T]) {
        // Flatten each element into its constraint-field representation before absorbing.
        let input = elements.iter().flat_map(|e| e.to_field_elements().unwrap()).collect::<Vec<_>>();
        if !input.is_empty() {
            match self.mode {
                DuplexSpongeMode::Absorbing { mut next_absorb_index } => {
                    // The rate section is full; permute before absorbing more.
                    if next_absorb_index == RATE {
                        self.permute();
                        next_absorb_index = 0;
                    }
                    self.absorb_internal(next_absorb_index, &input);
                }
                DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => {
                    // Switching from squeezing back to absorbing requires a permutation first.
                    self.permute();
                    self.absorb_internal(0, &input);
                }
            }
        }
    }

    /// Takes in field elements.
    fn absorb_nonnative_field_elements<Target: PrimeField>(&mut self, elements: impl IntoIterator<Item = Target>) {
        Self::push_elements_to_sponge(self, elements, OptimizationType::Weight);
    }

    /// Squeezes out `num` nonnative field elements (full-width, minus the top bit).
    fn squeeze_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
        self.get_fe(num, false)
    }

    /// Squeezes out `num_elements` native field elements from the sponge.
    fn squeeze_native_field_elements(&mut self, num_elements: usize) -> SmallVec<[F; 10]> {
        if num_elements == 0 {
            return SmallVec::<[F; 10]>::new();
        }
        // Use the SmallVec's inline storage (capacity 10) when possible; otherwise
        // spill to the heap. Inline allocation is always length 10 and trimmed below.
        let mut output = if num_elements <= 10 {
            smallvec::smallvec_inline![F::zero(); 10]
        } else {
            smallvec::smallvec![F::zero(); num_elements]
        };

        match self.mode {
            DuplexSpongeMode::Absorbing { next_absorb_index: _ } => {
                // Switching from absorbing to squeezing requires a permutation first.
                self.permute();
                self.squeeze_internal(0, &mut output[..num_elements]);
            }
            DuplexSpongeMode::Squeezing { mut next_squeeze_index } => {
                // The rate section is exhausted; permute before squeezing more.
                if next_squeeze_index == RATE {
                    self.permute();
                    next_squeeze_index = 0;
                }
                self.squeeze_internal(next_squeeze_index, &mut output[..num_elements]);
            }
        }

        // Drop the unused inline padding when `num_elements < 10`.
        output.truncate(num_elements);
        output
    }

    /// Takes out field elements of 168 bits.
    fn squeeze_short_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
        self.get_fe(num, true)
    }
}
209
impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
    /// Adds the round constants (ARK) for `round_number` to every element of the state.
    #[inline]
    fn apply_ark(&mut self, round_number: usize) {
        for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) {
            *state_elem += ark_elem;
        }
    }

    /// Applies the S-box `x ↦ x^alpha` to the state.
    #[inline]
    fn apply_s_box(&mut self, is_full_round: bool) {
        if is_full_round {
            // Full rounds apply the S Box (x^alpha) to every element of state
            for elem in self.state.iter_mut() {
                *elem = elem.pow([self.parameters.alpha]);
            }
        } else {
            // Partial rounds apply the S Box (x^alpha) to just the first element of state
            self.state[0] = self.state[0].pow([self.parameters.alpha]);
        }
    }

    /// Multiplies the state vector by the MDS matrix.
    #[inline]
    fn apply_mds(&mut self) {
        // Each new state element is the inner product of the old state with one MDS row.
        let mut new_state = State::default();
        new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
            *new_elem = F::sum_of_products(self.state.iter(), mds_row.iter());
        });
        self.state = new_state;
    }

    /// Runs the full Poseidon permutation over the state: the partial rounds are
    /// sandwiched between two halves of the full rounds.
    #[inline]
    fn permute(&mut self) {
        // Determine the partial rounds range bound.
        let partial_rounds = self.parameters.partial_rounds;
        let full_rounds = self.parameters.full_rounds;
        let full_rounds_over_2 = full_rounds / 2;
        let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds);

        // Iterate through all rounds to permute.
        for i in 0..(partial_rounds + full_rounds) {
            let is_full_round = !partial_round_range.contains(&i);
            self.apply_ark(i);
            self.apply_s_box(is_full_round);
            self.apply_mds();
        }
    }

    /// Absorbs everything in elements, this does not end in an absorption.
    ///
    /// `rate_start` is the rate-section index at which absorption begins; the sponge's
    /// mode is updated to record where the next absorption should resume.
    #[inline]
    fn absorb_internal(&mut self, mut rate_start: usize, input: &[F]) {
        if !input.is_empty() {
            // The first chunk fills only the remainder of the rate section.
            let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
            let num_elements_remaining = input.len() - first_chunk_size;
            let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
            let rest_chunks = rest_chunk.chunks(RATE);
            // The total number of chunks is 1 (the first, possibly-short chunk), plus the
            // number of full `RATE`-sized chunks in the remainder, plus 1 for a trailing
            // partial chunk if the remainder is not a multiple of `RATE`.
            let total_num_chunks = 1 + // 1 for the first chunk
                // We add all the chunks that are perfectly divisible by `RATE`
                (num_elements_remaining / RATE) +
                // And also add 1 if the last chunk is non-empty
                // (i.e. if `num_elements_remaining` is not a multiple of `RATE`)
                usize::from((num_elements_remaining % RATE) != 0);

            // Absorb the input elements, `RATE` elements at a time, except for the first chunk, which
            // is of size `RATE - rate_start`.
            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
                for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state[rate_start..]) {
                    *state_elem += element;
                }
                // Are we in the last chunk?
                // If so, let's wrap up: record the resume index instead of permuting.
                if i == total_num_chunks - 1 {
                    self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
                    return;
                } else {
                    self.permute();
                }
                // All chunks after the first start at the beginning of the rate section.
                rate_start = 0;
            }
        }
    }

    /// Squeeze |output| many elements. This does not end in a squeeze.
    ///
    /// `rate_start` is the rate-section index at which squeezing begins; the sponge's
    /// mode is updated to record where the next squeeze should resume.
    #[inline]
    fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [F]) {
        let output_size = output.len();
        if output_size != 0 {
            // The first chunk reads only the remainder of the rate section.
            let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
            let num_output_remaining = output.len() - first_chunk_size;
            let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
            assert_eq!(rest_chunk.len(), num_output_remaining);
            let rest_chunks = rest_chunk.chunks_mut(RATE);
            // The total number of chunks is 1 (the first, possibly-short chunk), plus the
            // number of full `RATE`-sized chunks in the remainder, plus 1 for a trailing
            // partial chunk if the remainder is not a multiple of `RATE`.
            let total_num_chunks = 1 + // 1 for the first chunk
                // We add all the chunks that are perfectly divisible by `RATE`
                (num_output_remaining / RATE) +
                // And also add 1 if the last chunk is non-empty
                // (i.e. if `num_output_remaining` is not a multiple of `RATE`)
                usize::from((num_output_remaining % RATE) != 0);

            // Squeeze the output, `RATE` elements at a time, except for the first chunk,
            // which is of size `RATE - rate_start`.
            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
                let range = rate_start..(rate_start + chunk.len());
                debug_assert_eq!(
                    chunk.len(),
                    self.state.rate_state[range.clone()].len(),
                    "failed with squeeze {output_size} at rate {RATE} and rate_start {rate_start}"
                );
                chunk.copy_from_slice(&self.state.rate_state[range]);
                // Are we in the last chunk?
                // If so, let's wrap up: record the resume index instead of permuting.
                if i == total_num_chunks - 1 {
                    self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
                    return;
                } else {
                    self.permute();
                }
                // All chunks after the first start at the beginning of the rate section.
                rate_start = 0;
            }
        }
    }

    /// Compress every two elements if possible.
    /// Consumes an iterator of `(limb, num_of_additions)` pairs, both of which are F,
    /// and packs two adjacent limbs into one field element whenever their combined
    /// bit-width fits within the field's capacity.
    pub fn compress_elements<TargetField: PrimeField, I: Iterator<Item = (F, F)>>(
        &self,
        mut src_limbs: Peekable<I>,
        ty: OptimizationType,
    ) -> Vec<F> {
        // The number of bits that can be packed into one field element without overflow.
        let capacity = F::size_in_bits() - 1;
        let mut dest_limbs = Vec::<F>::new();

        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty);

        // Prepare a reusable vector to be used in overhead calculation.
        let mut num_bits = Vec::new();

        while let Some(first) = src_limbs.next() {
            let second = src_limbs.peek();

            // Maximum bit-width of each limb: the base limb size plus the overhead
            // contributed by its accumulated additions (`.1 + 1`).
            let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one(), &mut num_bits);
            let second_max_bits_per_limb = if let Some(second) = second {
                params.bits_per_limb + crate::overhead!(second.1 + F::one(), &mut num_bits)
            } else {
                0
            };

            if let Some(second) = second {
                if first_max_bits_per_limb + second_max_bits_per_limb <= capacity {
                    // Shift the first limb left by the second limb's width (via the
                    // precomputed power-of-two table) and add the second limb.
                    let adjustment_factor = &self.adjustment_factor_lookup_table[second_max_bits_per_limb];

                    dest_limbs.push(first.0 * adjustment_factor + second.0);
                    // Consume the second limb, since it was packed with the first.
                    src_limbs.next();
                } else {
                    dest_limbs.push(first.0);
                }
            } else {
                dest_limbs.push(first.0);
            }
        }

        dest_limbs
    }

    /// Convert a `TargetField` element into limbs (not constraints)
    /// This is an internal function that would be reused by a number of other functions
    pub fn get_limbs_representations<TargetField: PrimeField>(
        elem: &TargetField,
        optimization_type: OptimizationType,
    ) -> SmallVec<[F; 10]> {
        Self::get_limbs_representations_from_big_integer::<TargetField>(&elem.to_bigint(), optimization_type)
    }

    /// Obtain the limbs directly from a big int
    ///
    /// Limbs are returned in big-endian order (most significant limb first).
    pub fn get_limbs_representations_from_big_integer<TargetField: PrimeField>(
        elem: &<TargetField as PrimeField>::BigInteger,
        optimization_type: OptimizationType,
    ) -> SmallVec<[F; 10]> {
        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type);

        // Prepare a reusable vector for the BE bits.
        let mut cur_bits = Vec::new();
        // Push the lower limbs first
        let mut limbs: SmallVec<[F; 10]> = SmallVec::new();
        let mut cur = *elem;
        for _ in 0..params.num_limbs {
            cur.write_bits_be(&mut cur_bits); // `write_bits_be` is big endian
            let cur_mod_r =
                <F as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..])
                    .unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
            limbs.push(F::from_bigint(cur_mod_r).unwrap());
            // Shift out the bits just extracted to expose the next limb.
            cur.divn(params.bits_per_limb as u32);
            // Clear the vector after every iteration so its allocation can be reused.
            cur_bits.clear();
        }

        // then we reverse, so that the limbs are ``big limb first''
        limbs.reverse();

        limbs
    }

    /// Push elements to sponge, treated in the non-native field representations.
    pub fn push_elements_to_sponge<TargetField: PrimeField>(
        &mut self,
        src: impl IntoIterator<Item = TargetField>,
        ty: OptimizationType,
    ) {
        // Decompose each nonnative element into limbs, each paired with a
        // `num_of_additions` counter of one.
        let src_limbs = src
            .into_iter()
            .flat_map(|elem| {
                let limbs = Self::get_limbs_representations(&elem, ty);
                limbs.into_iter().map(|limb| (limb, F::one()))
                // specifically set to one, since most gadgets in the constraint world would not have zero noise (due to the relatively weak normal form testing in `alloc`)
            })
            .peekable();

        // Pack limb pairs where possible, then absorb the compressed limbs natively.
        let dest_limbs = self.compress_elements::<TargetField, _>(src_limbs, ty);
        self.absorb_native_field_elements(&dest_limbs);
    }

    /// obtain random bits from hashchain.
    /// not guaranteed to be uniformly distributed, should only be used in certain situations.
    pub fn get_bits(&mut self, num_bits: usize) -> Vec<bool> {
        // Each squeezed element contributes all but its highest usable bit.
        let bits_per_element = F::size_in_bits() - 1;
        let num_elements = num_bits.div_ceil(bits_per_element);

        let src_elements = self.squeeze_native_field_elements(num_elements);
        let mut dest_bits = Vec::<bool>::with_capacity(num_elements * bits_per_element);

        // Skip the representation's unused top bits plus one more bit,
        // leaving exactly `bits_per_element` usable bits per element.
        let skip = (F::Parameters::REPR_SHAVE_BITS + 1) as usize;
        for elem in src_elements.iter() {
            // discard the highest bit
            let elem_bits = elem.to_bigint().to_bits_be();
            dest_bits.extend_from_slice(&elem_bits[skip..]);
        }
        // Trim any excess bits from the final element.
        dest_bits.truncate(num_bits);

        dest_bits
    }

    /// obtain random field elements from hashchain.
    /// not guaranteed to be uniformly distributed, should only be used in certain situations.
    pub fn get_fe<TargetField: PrimeField>(
        &mut self,
        num_elements: usize,
        outputs_short_elements: bool,
    ) -> SmallVec<[TargetField; 10]> {
        let num_bits_per_nonnative = if outputs_short_elements {
            168
        } else {
            TargetField::size_in_bits() - 1 // also omit the highest bit
        };
        let bits = self.get_bits(num_bits_per_nonnative * num_elements);

        // Precompute the powers of two used to reassemble each element from its bits.
        let mut lookup_table = Vec::<TargetField>::with_capacity(num_bits_per_nonnative);
        let mut cur = TargetField::one();
        for _ in 0..num_bits_per_nonnative {
            lookup_table.push(cur);
            cur.double_in_place();
        }

        let dest_elements = bits
            .chunks_exact(num_bits_per_nonnative)
            .map(|per_nonnative_bits| {
                // technically, this can be done via BigInterger::from_bits; here, we use this method for consistency with the gadget counterpart
                let mut res = TargetField::zero();

                // Bits are big-endian, so reverse to pair each bit with its power of two.
                for (i, bit) in per_nonnative_bits.iter().rev().enumerate() {
                    if *bit {
                        res += &lookup_table[i];
                    }
                }
                res
            })
            .collect::<SmallVec<_>>();
        debug_assert_eq!(dest_elements.len(), num_elements);

        dest_elements
    }
}
493}