// snarkvm_algorithms/crypto_hash/poseidon.rs

use crate::{AlgebraicSponge, DuplexSpongeMode, nonnative_params::*};
use snarkvm_fields::{FieldParameters, PoseidonParameters, PrimeField, ToConstraintField};
use snarkvm_utilities::{BigInteger, FromBits, ToBits};

use smallvec::SmallVec;
use std::{
    iter::Peekable,
    ops::{Index, IndexMut},
    sync::Arc,
};
26
/// The Poseidon sponge state: `CAPACITY` capacity elements followed by
/// `RATE` rate elements.
#[derive(Copy, Clone, Debug)]
pub struct State<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
    /// Capacity portion of the state; absorb/squeeze never write or read it directly.
    capacity_state: [F; CAPACITY],
    /// Rate portion of the state; input is added here and output is copied from here.
    rate_state: [F; RATE],
}
32
33impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Default for State<F, RATE, CAPACITY> {
34 fn default() -> Self {
35 Self { capacity_state: [F::zero(); CAPACITY], rate_state: [F::zero(); RATE] }
36 }
37}
38
39impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> State<F, RATE, CAPACITY> {
40 pub fn iter(&self) -> impl Iterator<Item = &F> + Clone {
42 self.capacity_state.iter().chain(self.rate_state.iter())
43 }
44
45 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut F> {
47 self.capacity_state.iter_mut().chain(self.rate_state.iter_mut())
48 }
49}
50
51impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Index<usize> for State<F, RATE, CAPACITY> {
52 type Output = F;
53
54 fn index(&self, index: usize) -> &Self::Output {
55 assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
56 if index < CAPACITY { &self.capacity_state[index] } else { &self.rate_state[index - CAPACITY] }
57 }
58}
59
60impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> IndexMut<usize> for State<F, RATE, CAPACITY> {
61 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
62 assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
63 if index < CAPACITY { &mut self.capacity_state[index] } else { &mut self.rate_state[index - CAPACITY] }
64 }
65}
66
/// A Poseidon hash over the prime field `F` with rate `RATE` and capacity 1.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Poseidon<F: PrimeField, const RATE: usize> {
    /// The Poseidon round parameters, shared via `Arc` so clones are cheap.
    parameters: Arc<PoseidonParameters<F, RATE, 1>>,
}
71
72impl<F: PrimeField, const RATE: usize> Poseidon<F, RATE> {
73 pub fn setup() -> Self {
75 Self { parameters: Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) }
76 }
77
78 pub fn evaluate(&self, input: &[F]) -> F {
80 self.evaluate_many(input, 1)[0]
81 }
82
83 pub fn evaluate_many(&self, input: &[F], num_outputs: usize) -> Vec<F> {
86 let mut sponge = PoseidonSponge::<F, RATE, 1>::new_with_parameters(&self.parameters);
87 sponge.absorb_native_field_elements(input);
88 sponge.squeeze_native_field_elements(num_outputs).to_vec()
89 }
90
91 pub fn evaluate_with_len(&self, input: &[F]) -> F {
94 self.evaluate(&[vec![F::from(input.len() as u128)], input.to_vec()].concat())
95 }
96
97 pub fn parameters(&self) -> &Arc<PoseidonParameters<F, RATE, 1>> {
98 &self.parameters
99 }
100}
101
/// A duplex sponge built on the Poseidon permutation.
#[derive(Clone, Debug)]
pub struct PoseidonSponge<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
    /// Sponge (round) parameters, shared via `Arc`.
    parameters: Arc<PoseidonParameters<F, RATE, CAPACITY>>,
    /// Current sponge state: capacity elements followed by rate elements.
    state: State<F, RATE, CAPACITY>,
    /// Whether the sponge is currently absorbing or squeezing, together with
    /// the next rate index to use in that mode.
    pub mode: DuplexSpongeMode,
    /// Lookup table of `2^i` for `i` in `0..F::size_in_bits() - 1`, used by
    /// `compress_elements` to shift a limb left when packing two limbs into
    /// one field element.
    adjustment_factor_lookup_table: Arc<[F]>,
}
119
impl<F: PrimeField, const RATE: usize> AlgebraicSponge<F, RATE> for PoseidonSponge<F, RATE, 1> {
    type Parameters = Arc<PoseidonParameters<F, RATE, 1>>;

    /// Returns the field's default Poseidon parameters for this `RATE`.
    /// Panics if none are defined.
    fn sample_parameters() -> Self::Parameters {
        Arc::new(F::default_poseidon_parameters::<RATE>().unwrap())
    }

    /// Initializes a sponge in absorbing mode with an all-zero state.
    fn new_with_parameters(parameters: &Self::Parameters) -> Self {
        Self {
            parameters: parameters.clone(),
            state: State::default(),
            mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 },
            adjustment_factor_lookup_table: {
                // Precompute 2^i for i in 0..(F::size_in_bits() - 1): one power
                // of two per usable bit position below the field's bit length.
                let capacity = F::size_in_bits() - 1;
                let mut table = Vec::<F>::with_capacity(capacity);

                let mut cur = F::one();
                for _ in 0..capacity {
                    table.push(cur);
                    cur.double_in_place();
                }

                table.into()
            },
        }
    }

    /// Absorbs native field elements into the sponge.
    ///
    /// Each element is first expanded via `ToConstraintField`; if the combined
    /// expansion is empty, this is a no-op. If the sponge was squeezing, it is
    /// permuted and switched back to absorbing from rate index 0.
    fn absorb_native_field_elements<T: ToConstraintField<F>>(&mut self, elements: &[T]) {
        let input = elements.iter().flat_map(|e| e.to_field_elements().unwrap()).collect::<Vec<_>>();
        if !input.is_empty() {
            match self.mode {
                DuplexSpongeMode::Absorbing { mut next_absorb_index } => {
                    // The rate is already full: permute before absorbing more.
                    if next_absorb_index == RATE {
                        self.permute();
                        next_absorb_index = 0;
                    }
                    self.absorb_internal(next_absorb_index, &input);
                }
                DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => {
                    // Direction change: permute, then absorb from the start of the rate.
                    self.permute();
                    self.absorb_internal(0, &input);
                }
            }
        }
    }

    /// Absorbs non-native field elements by splitting each into limbs of `F`
    /// and absorbing the (compressed) limbs.
    fn absorb_nonnative_field_elements<Target: PrimeField>(&mut self, elements: impl IntoIterator<Item = Target>) {
        Self::push_elements_to_sponge(self, elements, OptimizationType::Weight);
    }

    /// Squeezes `num` non-native field elements from the sponge.
    fn squeeze_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
        self.get_fe(num, false)
    }

    /// Squeezes `num_elements` native field elements from the sponge.
    fn squeeze_native_field_elements(&mut self, num_elements: usize) -> SmallVec<[F; 10]> {
        if num_elements == 0 {
            return SmallVec::<[F; 10]>::new();
        }
        // Use the inline SmallVec capacity when possible; the buffer is
        // truncated to exactly `num_elements` before returning.
        let mut output = if num_elements <= 10 {
            smallvec::smallvec_inline![F::zero(); 10]
        } else {
            smallvec::smallvec![F::zero(); num_elements]
        };

        match self.mode {
            DuplexSpongeMode::Absorbing { next_absorb_index: _ } => {
                // Direction change: permute, then squeeze from the start of the rate.
                self.permute();
                self.squeeze_internal(0, &mut output[..num_elements]);
            }
            DuplexSpongeMode::Squeezing { mut next_squeeze_index } => {
                // The rate is exhausted: permute before squeezing more output.
                if next_squeeze_index == RATE {
                    self.permute();
                    next_squeeze_index = 0;
                }
                self.squeeze_internal(next_squeeze_index, &mut output[..num_elements]);
            }
        }

        output.truncate(num_elements);
        output
    }

    /// Squeezes `num` short (168-bit) non-native field elements from the sponge.
    fn squeeze_short_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
        self.get_fe(num, true)
    }
}
209
impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
    /// Adds the round constants (ARK) for `round_number` to every state element.
    #[inline]
    fn apply_ark(&mut self, round_number: usize) {
        for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) {
            *state_elem += ark_elem;
        }
    }

    /// Applies the S-box `x -> x^alpha`: to every state element in a full
    /// round, and only to the first state element (index 0) in a partial round.
    #[inline]
    fn apply_s_box(&mut self, is_full_round: bool) {
        if is_full_round {
            // Full rounds exponentiate the entire state.
            for elem in self.state.iter_mut() {
                *elem = elem.pow([self.parameters.alpha]);
            }
        } else {
            // Partial rounds exponentiate only the first state element.
            self.state[0] = self.state[0].pow([self.parameters.alpha]);
        }
    }

    /// Multiplies the state by the MDS matrix: each new state element is the
    /// inner product of one MDS row with the current state.
    #[inline]
    fn apply_mds(&mut self) {
        let mut new_state = State::default();
        new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
            *new_elem = F::sum_of_products(self.state.iter(), mds_row.iter());
        });
        self.state = new_state;
    }

    /// Runs the Poseidon permutation: `full_rounds / 2` full rounds, then
    /// `partial_rounds` partial rounds, then the remaining full rounds.
    /// Each round applies ARK, then the S-box, then the MDS matrix.
    #[inline]
    fn permute(&mut self) {
        let partial_rounds = self.parameters.partial_rounds;
        let full_rounds = self.parameters.full_rounds;
        let full_rounds_over_2 = full_rounds / 2;
        let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds);

        for i in 0..(partial_rounds + full_rounds) {
            let is_full_round = !partial_round_range.contains(&i);
            self.apply_ark(i);
            self.apply_s_box(is_full_round);
            self.apply_mds();
        }
    }

    /// Adds `input` into the rate portion of the state, starting at
    /// `rate_start`, permuting whenever a rate-sized chunk has been absorbed.
    /// Leaves the sponge in `Absorbing` mode pointing just past the last
    /// element written.
    #[inline]
    fn absorb_internal(&mut self, mut rate_start: usize, input: &[F]) {
        if !input.is_empty() {
            // The first chunk fills the remainder of the rate; later chunks are RATE-sized.
            let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
            let num_elements_remaining = input.len() - first_chunk_size;
            let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
            let rest_chunks = rest_chunk.chunks(RATE);
            let total_num_chunks = 1 + (num_elements_remaining / RATE) +
                usize::from((num_elements_remaining % RATE) != 0);

            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
                // Add this chunk into the rate state.
                for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state[rate_start..]) {
                    *state_elem += element;
                }
                if i == total_num_chunks - 1 {
                    // Last chunk: record where the next absorb should continue.
                    self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
                    return;
                } else {
                    // More input remains: permute before the next chunk.
                    self.permute();
                }
                // All chunks after the first start at the beginning of the rate.
                rate_start = 0;
            }
        }
    }

    /// Copies rate-state elements into `output`, starting at `rate_start`,
    /// permuting whenever a rate-sized chunk has been read. Leaves the sponge
    /// in `Squeezing` mode pointing just past the last element read.
    #[inline]
    fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [F]) {
        let output_size = output.len();
        if output_size != 0 {
            // The first chunk drains the remainder of the rate; later chunks are RATE-sized.
            let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
            let num_output_remaining = output.len() - first_chunk_size;
            let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
            assert_eq!(rest_chunk.len(), num_output_remaining);
            let rest_chunks = rest_chunk.chunks_mut(RATE);
            let total_num_chunks = 1 + (num_output_remaining / RATE) +
                usize::from((num_output_remaining % RATE) != 0);

            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
                let range = rate_start..(rate_start + chunk.len());
                debug_assert_eq!(
                    chunk.len(),
                    self.state.rate_state[range.clone()].len(),
                    "failed with squeeze {output_size} at rate {RATE} and rate_start {rate_start}"
                );
                chunk.copy_from_slice(&self.state.rate_state[range]);
                if i == total_num_chunks - 1 {
                    // Last chunk: record where the next squeeze should continue.
                    self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
                    return;
                } else {
                    // More output remains: permute before the next chunk.
                    self.permute();
                }
                // All chunks after the first start at the beginning of the rate.
                rate_start = 0;
            }
        }
    }

    /// Compresses a stream of `(limb, max_value)` pairs into as few native
    /// field elements as possible: whenever two adjacent limbs fit together
    /// within `F::size_in_bits() - 1` bits, the first limb is shifted left
    /// (multiplied by a power of two from the lookup table) and the second is
    /// added into the same element.
    pub fn compress_elements<TargetField: PrimeField, I: Iterator<Item = (F, F)>>(
        &self,
        mut src_limbs: Peekable<I>,
        ty: OptimizationType,
    ) -> Vec<F> {
        let capacity = F::size_in_bits() - 1;
        let mut dest_limbs = Vec::<F>::new();

        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty);

        // Scratch buffer reused by the `overhead!` computations below.
        let mut num_bits = Vec::new();

        while let Some(first) = src_limbs.next() {
            let second = src_limbs.peek();

            // Upper bound on the bits needed per limb: the base limb width plus
            // the overhead implied by its maximum value (the `.1` component).
            let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one(), &mut num_bits);
            let second_max_bits_per_limb = if let Some(second) = second {
                params.bits_per_limb + crate::overhead!(second.1 + F::one(), &mut num_bits)
            } else {
                0
            };

            if let Some(second) = second {
                if first_max_bits_per_limb + second_max_bits_per_limb <= capacity {
                    // Both limbs fit: pack `first * 2^second_max_bits + second`
                    // and consume the second limb as well.
                    let adjustment_factor = &self.adjustment_factor_lookup_table[second_max_bits_per_limb];

                    dest_limbs.push(first.0 * adjustment_factor + second.0);
                    src_limbs.next();
                } else {
                    dest_limbs.push(first.0);
                }
            } else {
                dest_limbs.push(first.0);
            }
        }

        dest_limbs
    }

    /// Splits a non-native field element into limbs of `F`.
    pub fn get_limbs_representations<TargetField: PrimeField>(
        elem: &TargetField,
        optimization_type: OptimizationType,
    ) -> SmallVec<[F; 10]> {
        Self::get_limbs_representations_from_big_integer::<TargetField>(&elem.to_bigint(), optimization_type)
    }

    /// Splits a non-native big integer into `params.num_limbs` limbs of
    /// `params.bits_per_limb` bits each, most-significant limb first.
    pub fn get_limbs_representations_from_big_integer<TargetField: PrimeField>(
        elem: &<TargetField as PrimeField>::BigInteger,
        optimization_type: OptimizationType,
    ) -> SmallVec<[F; 10]> {
        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type);

        // Scratch buffer for the big-endian bits of `cur`, reused per limb.
        let mut cur_bits = Vec::new();
        let mut limbs: SmallVec<[F; 10]> = SmallVec::new();
        let mut cur = *elem;
        for _ in 0..params.num_limbs {
            cur.write_bits_be(&mut cur_bits);
            // Take the lowest `bits_per_limb` bits as this limb's value.
            let cur_mod_r =
                <F as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..])
                    .unwrap();
            limbs.push(F::from_bigint(cur_mod_r).unwrap());
            // Shift right to expose the next limb.
            cur.divn(params.bits_per_limb as u32);
            cur_bits.clear();
        }

        // Limbs were produced least-significant first; flip to most-significant first.
        limbs.reverse();

        limbs
    }

    /// Converts each source element into limbs of `F`, compresses adjacent
    /// limbs where possible, and absorbs the result into the sponge.
    pub fn push_elements_to_sponge<TargetField: PrimeField>(
        &mut self,
        src: impl IntoIterator<Item = TargetField>,
        ty: OptimizationType,
    ) {
        // Pair each limb with `F::one()` as its maximum-value bound for
        // `compress_elements`.
        let src_limbs = src
            .into_iter()
            .flat_map(|elem| {
                let limbs = Self::get_limbs_representations(&elem, ty);
                limbs.into_iter().map(|limb| (limb, F::one()))
            })
            .peekable();

        let dest_limbs = self.compress_elements::<TargetField, _>(src_limbs, ty);
        self.absorb_native_field_elements(&dest_limbs);
    }

    /// Squeezes `num_bits` bits by squeezing native field elements and taking
    /// `F::size_in_bits() - 1` big-endian bits from each (skipping the
    /// representation's shaved bits plus one more leading bit).
    pub fn get_bits(&mut self, num_bits: usize) -> Vec<bool> {
        let bits_per_element = F::size_in_bits() - 1;
        let num_elements = num_bits.div_ceil(bits_per_element);

        let src_elements = self.squeeze_native_field_elements(num_elements);
        let mut dest_bits = Vec::<bool>::with_capacity(num_elements * bits_per_element);

        // Skip the unused leading bits of the representation plus one more,
        // leaving exactly `bits_per_element` usable bits per element.
        let skip = (F::Parameters::REPR_SHAVE_BITS + 1) as usize;
        for elem in src_elements.iter() {
            let elem_bits = elem.to_bigint().to_bits_be();
            dest_bits.extend_from_slice(&elem_bits[skip..]);
        }
        dest_bits.truncate(num_bits);

        dest_bits
    }

    /// Squeezes `num_elements` non-native field elements, each assembled from
    /// `TargetField::size_in_bits() - 1` squeezed bits, or from 168 bits when
    /// `outputs_short_elements` is set.
    pub fn get_fe<TargetField: PrimeField>(
        &mut self,
        num_elements: usize,
        outputs_short_elements: bool,
    ) -> SmallVec<[TargetField; 10]> {
        let num_bits_per_nonnative = if outputs_short_elements {
            168
        } else {
            // Stay strictly below the target field's bit length.
            TargetField::size_in_bits() - 1
        };
        let bits = self.get_bits(num_bits_per_nonnative * num_elements);

        // Lookup table of 2^i in the target field, one entry per bit position.
        let mut lookup_table = Vec::<TargetField>::with_capacity(num_bits_per_nonnative);
        let mut cur = TargetField::one();
        for _ in 0..num_bits_per_nonnative {
            lookup_table.push(cur);
            cur.double_in_place();
        }

        let dest_elements = bits
            .chunks_exact(num_bits_per_nonnative)
            .map(|per_nonnative_bits| {
                // Interpret each chunk as the big-endian bits of one target
                // field element (reversed iteration pairs bit i with 2^i).
                let mut res = TargetField::zero();

                for (i, bit) in per_nonnative_bits.iter().rev().enumerate() {
                    if *bit {
                        res += &lookup_table[i];
                    }
                }
                res
            })
            .collect::<SmallVec<_>>();
        debug_assert_eq!(dest_elements.len(), num_elements);

        dest_elements
    }
}