snarkvm_circuit_algorithms/poseidon/
hash_many.rs

// Copyright 2024 Aleo Network Foundation
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16use super::*;
17
18impl<E: Environment, const RATE: usize> HashMany for Poseidon<E, RATE> {
19    type Input = Field<E>;
20    type Output = Field<E>;
21
22    #[inline]
23    fn hash_many(&self, input: &[Self::Input], num_outputs: u16) -> Vec<Self::Output> {
24        // Construct the preimage: [ DOMAIN || LENGTH(INPUT) || [0; RATE-2] || INPUT ].
25        let mut preimage = Vec::with_capacity(RATE + input.len());
26        preimage.push(self.domain.clone());
27        preimage.push(Field::constant(console::Field::from_u128(input.len() as u128)));
28        preimage.resize(RATE, Field::zero()); // Pad up to RATE.
29        preimage.extend_from_slice(input);
30
31        // Initialize a new sponge.
32        let mut state = vec![Field::zero(); RATE + CAPACITY];
33        let mut mode = DuplexSpongeMode::Absorbing { next_absorb_index: 0 };
34
35        // Absorb the input and squeeze the output.
36        self.absorb(&mut state, &mut mode, &preimage);
37        self.squeeze(&mut state, &mut mode, num_outputs)
38    }
39}
40
41#[allow(clippy::needless_borrow)]
42impl<E: Environment, const RATE: usize> Poseidon<E, RATE> {
43    /// Absorbs the input elements into state.
44    #[inline]
45    fn absorb(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, input: &[Field<E>]) {
46        if !input.is_empty() {
47            // Determine the absorb index.
48            let (mut absorb_index, should_permute) = match *mode {
49                DuplexSpongeMode::Absorbing { next_absorb_index } => match next_absorb_index == RATE {
50                    true => (0, true),
51                    false => (next_absorb_index, false),
52                },
53                DuplexSpongeMode::Squeezing { .. } => (0, true),
54            };
55
56            // Proceed to permute the state, if necessary.
57            if should_permute {
58                self.permute(state);
59            }
60
61            let mut remaining = input;
62            loop {
63                // Compute the starting index.
64                let start = CAPACITY + absorb_index;
65
66                // Check if we can exit the loop.
67                if absorb_index + remaining.len() <= RATE {
68                    // Absorb the state elements into the input.
69                    remaining.iter().enumerate().for_each(|(i, element)| state[start + i] += element);
70                    // Update the sponge mode.
71                    *mode = DuplexSpongeMode::Absorbing { next_absorb_index: absorb_index + remaining.len() };
72                    return;
73                }
74
75                // Otherwise, proceed to absorb `(rate - absorb_index)` elements.
76                let num_absorbed = RATE - absorb_index;
77                remaining.iter().enumerate().take(num_absorbed).for_each(|(i, element)| state[start + i] += element);
78
79                // Permute the state.
80                self.permute(state);
81
82                // Repeat with the updated input slice and absorb index.
83                remaining = &remaining[num_absorbed..];
84                absorb_index = 0;
85            }
86        }
87    }
88
89    /// Squeeze the specified number of state elements into the output.
90    #[inline]
91    fn squeeze(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, num_outputs: u16) -> Vec<Field<E>> {
92        let mut output = vec![Field::zero(); num_outputs as usize];
93        if num_outputs != 0 {
94            self.squeeze_internal(state, mode, &mut output);
95        }
96        output
97    }
98
99    /// Squeeze the state elements into the output.
100    #[inline]
101    fn squeeze_internal(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, output: &mut [Field<E>]) {
102        // Determine the squeeze index.
103        let (mut squeeze_index, should_permute) = match *mode {
104            DuplexSpongeMode::Absorbing { .. } => (0, true),
105            DuplexSpongeMode::Squeezing { next_squeeze_index } => match next_squeeze_index == RATE {
106                true => (0, true),
107                false => (next_squeeze_index, false),
108            },
109        };
110
111        // Proceed to permute the state, if necessary.
112        if should_permute {
113            self.permute(state);
114        }
115
116        let mut remaining = output;
117        loop {
118            // Compute the starting index.
119            let start = CAPACITY + squeeze_index;
120
121            // Check if we can exit the loop.
122            if squeeze_index + remaining.len() <= RATE {
123                // Store the state elements into the output.
124                remaining.clone_from_slice(&state[start..(start + remaining.len())]);
125                // Update the sponge mode.
126                *mode = DuplexSpongeMode::Squeezing { next_squeeze_index: squeeze_index + remaining.len() };
127                return;
128            }
129
130            // Otherwise, proceed to squeeze `(rate - squeeze_index)` elements.
131            let num_squeezed = RATE - squeeze_index;
132            remaining[..num_squeezed].clone_from_slice(&state[start..(start + num_squeezed)]);
133
134            // Permute.
135            self.permute(state);
136
137            // Repeat with the updated output slice and squeeze index.
138            remaining = &mut remaining[num_squeezed..];
139            squeeze_index = 0;
140        }
141    }
142
143    /// Apply the additive round keys in-place.
144    #[inline]
145    fn apply_ark(&self, state: &mut [Field<E>], round: usize) {
146        for (i, element) in state.iter_mut().enumerate() {
147            *element += &self.ark[round][i];
148        }
149    }
150
151    /// Apply the S-Box based on whether it is a full round or partial round.
152    #[inline]
153    fn apply_s_box(&self, state: &mut [Field<E>], is_full_round: bool) {
154        if is_full_round {
155            // Full rounds apply the S Box (x^alpha) to every element of state
156            for element in state.iter_mut() {
157                *element = (&*element).pow(&self.alpha);
158            }
159        } else {
160            // Partial rounds apply the S Box (x^alpha) to just the first element of state
161            state[0] = (&state[0]).pow(&self.alpha);
162        }
163    }
164
165    /// Apply the Maximally Distance Separating (MDS) matrix in-place.
166    #[inline]
167    fn apply_mds(&self, state: &mut [Field<E>]) {
168        let mut new_state = Vec::with_capacity(state.len());
169        for i in 0..state.len() {
170            let mut accumulator = Field::zero();
171            for (j, element) in state.iter().enumerate() {
172                accumulator += element * &self.mds[i][j];
173            }
174            new_state.push(accumulator);
175        }
176        state.clone_from_slice(&new_state);
177    }
178
179    /// Apply the permutation for all rounds in-place.
180    #[inline]
181    fn permute(&self, state: &mut [Field<E>]) {
182        // Determine the partial rounds range bound.
183        let full_rounds_over_2 = self.full_rounds / 2;
184        let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + self.partial_rounds);
185
186        // Iterate through all rounds to permute.
187        for i in 0..(self.partial_rounds + self.full_rounds) {
188            let is_full_round = !partial_round_range.contains(&i);
189            self.apply_ark(state, i);
190            self.apply_s_box(state, is_full_round);
191            self.apply_mds(state);
192        }
193    }
194}
195
#[cfg(all(test, feature = "console"))]
mod tests {
    use super::*;
    use snarkvm_circuit_types::environment::Circuit;

    use anyhow::Result;

    const DOMAIN: &str = "PoseidonCircuit0";
    const ITERATIONS: usize = 10;
    const RATE: u16 = 4;

    /// Expected `(num_inputs, num_private, num_constraints)` for non-constant inputs when
    /// `num_outputs` lies in `1..=RATE`.
    const SINGLE_RATE_COSTS: [(usize, u64, u64); 6] =
        [(1, 335, 335), (2, 340, 340), (3, 345, 345), (4, 350, 350), (5, 705, 705), (6, 705, 705)];

    /// Expected `(num_inputs, num_private, num_constraints)` for non-constant inputs when
    /// `num_outputs` lies in `(RATE + 1)..=(RATE * 2)`.
    const DOUBLE_RATE_COSTS: [(usize, u64, u64); 6] =
        [(1, 690, 690), (2, 695, 695), (3, 700, 700), (4, 705, 705), (5, 1060, 1060), (6, 1060, 1060)];

    /// Compares the circuit `hash_many` against the native implementation on `ITERATIONS`
    /// random inputs, and asserts the circuit cost of every run.
    fn check_hash_many(
        mode: Mode,
        num_inputs: usize,
        num_outputs: u16,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
        rng: &mut TestRng,
    ) -> Result<()> {
        use console::HashMany as H;

        type Native = <Circuit as Environment>::Network;

        let native = console::Poseidon::<Native, { RATE as usize }>::setup(DOMAIN)?;
        let poseidon = Poseidon::<Circuit, { RATE as usize }>::constant(native.clone());

        for i in 0..ITERATIONS {
            // Sample a random preimage and inject it into the circuit with the requested mode.
            let native_input: Vec<console::Field<Native>> =
                (0..num_inputs).map(|_| console::Field::<Native>::rand(rng)).collect();
            let input: Vec<Field<Circuit>> =
                native_input.iter().map(|value| Field::<Circuit>::new(mode, *value)).collect();

            // Reference digest from the native implementation.
            let expected = native.hash_many(&native_input, num_outputs);

            // Circuit digest, checked element by element and costed within its own scope.
            Circuit::scope(format!("Poseidon {mode} {i} {num_outputs}"), || {
                let candidate = poseidon.hash_many(&input, num_outputs);
                expected.iter().zip_eq(&candidate).for_each(|(expected_element, candidate_element)| {
                    assert_eq!(*expected_element, candidate_element.eject_value());
                });
                let case = format!("(mode = {mode}, num_inputs = {num_inputs}, num_outputs = {num_outputs})");
                assert_scope!(case, num_constants, num_public, num_private, num_constraints);
            });
            Circuit::reset();
        }
        Ok(())
    }

    /// Runs the schedule of cases shared by the public and private input modes.
    fn check_non_constant(mode: Mode, rng: &mut TestRng) -> Result<()> {
        // An empty input costs no private variables or constraints, whatever the output count.
        for num_outputs in 0..=RATE {
            check_hash_many(mode, 0, num_outputs, 1, 0, 0, 0, rng)?;
        }
        for num_outputs in 1..=RATE {
            for &(num_inputs, num_private, num_constraints) in SINGLE_RATE_COSTS.iter() {
                check_hash_many(mode, num_inputs, num_outputs, 1, 0, num_private, num_constraints, rng)?;
            }
        }
        for num_outputs in (RATE + 1)..=(RATE * 2) {
            for &(num_inputs, num_private, num_constraints) in DOUBLE_RATE_COSTS.iter() {
                check_hash_many(mode, num_inputs, num_outputs, 1, 0, num_private, num_constraints, rng)?;
            }
        }
        Ok(())
    }

    #[test]
    fn test_hash_many_constant() -> Result<()> {
        let mut rng = TestRng::default();

        // Constant inputs are expected to cost no private variables or constraints.
        for num_inputs in 0..=RATE {
            for num_outputs in 0..=RATE {
                check_hash_many(Mode::Constant, usize::from(num_inputs), num_outputs, 1, 0, 0, 0, &mut rng)?;
            }
        }
        Ok(())
    }

    #[test]
    fn test_hash_many_public() -> Result<()> {
        check_non_constant(Mode::Public, &mut TestRng::default())
    }

    #[test]
    fn test_hash_many_private() -> Result<()> {
        check_non_constant(Mode::Private, &mut TestRng::default())
    }
}