snarkvm_console_algorithms/poseidon/
mod.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16mod helpers;
17
18mod hash;
19mod hash_many;
20mod hash_to_group;
21mod hash_to_scalar;
22mod prf;
23
24use crate::{Elligator2, poseidon::helpers::*};
25use snarkvm_console_types::prelude::*;
26use snarkvm_fields::{PoseidonDefaultField, PoseidonParameters};
27
28use std::sync::Arc;
29
/// Sponge capacity shared by every Poseidon instantiation in this module; it fixes the
/// `CAPACITY` parameter of the stored `PoseidonParameters`.
const CAPACITY: usize = 1;
31
32/// Poseidon2 is a cryptographic hash function of input rate 2.
33pub type Poseidon2<E> = Poseidon<E, 2>;
34/// Poseidon4 is a cryptographic hash function of input rate 4.
35pub type Poseidon4<E> = Poseidon<E, 4>;
36/// Poseidon8 is a cryptographic hash function of input rate 8.
37pub type Poseidon8<E> = Poseidon<E, 8>;
38
39#[derive(Clone, Debug, PartialEq)]
40pub struct Poseidon<E: Environment, const RATE: usize> {
41    /// The domain separator for the Poseidon hash function.
42    domain: Field<E>,
43    /// The Poseidon parameters for hashing.
44    parameters: Arc<PoseidonParameters<E::Field, RATE, CAPACITY>>,
45}
46
47impl<E: Environment, const RATE: usize> Poseidon<E, RATE> {
48    /// Initializes a new instance of Poseidon.
49    pub fn setup(domain: &str) -> Result<Self> {
50        // Ensure the given domain is within the allowed size in bits.
51        let num_bits = domain.len().saturating_mul(8);
52        let max_bits = Field::<E>::size_in_data_bits();
53        ensure!(num_bits <= max_bits, "Domain cannot exceed {max_bits} bits, found {num_bits} bits");
54
55        Ok(Self {
56            domain: Field::<E>::new_domain_separator(domain),
57            parameters: Arc::new(E::Field::default_poseidon_parameters::<RATE>()?),
58        })
59    }
60
61    /// Returns the domain separator for the hash function.
62    pub fn domain(&self) -> Field<E> {
63        self.domain
64    }
65
66    /// Returns the Poseidon parameters for hashing.
67    pub fn parameters(&self) -> &Arc<PoseidonParameters<E::Field, RATE, CAPACITY>> {
68        &self.parameters
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_console_types::environment::Console;
    use snarkvm_curves::edwards_bls12::Fq;
    use snarkvm_fields::{PoseidonDefaultField, PoseidonGrainLFSR};

    use std::{path::PathBuf, sync::Arc};

    type CurrentEnvironment = Console;

    /// Field-element literals shared by the rate-4 and rate-8 hash suites: hash cases 5, 6, and 7
    /// use the first 3, 4, and 5 of these elements, respectively.
    const SHARED_INPUTS: [&str; 5] = [
        "3801852864665033841774715284518384682376829752661853198612247855579120198106field",
        "8354898322875240371401674517397790035008442020361740574117886421279083828480field",
        "4810388512520169167962815122521832339992376865086300759308552937986944510606field",
        "1806278863067630397941269234951941896370617486625414347832536440203404317871field",
        "4017177598231920767921734423139954103557056461408532722673217828464276314809field",
    ];

    /// Returns the path to the `resources` folder for this module, creating the folder if needed.
    fn resources_path() -> PathBuf {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("src");
        path.push("poseidon");
        path.push("resources");

        // `create_dir_all` succeeds if the folder already exists, including when another test
        // thread creates it concurrently.
        std::fs::create_dir_all(&path)
            .unwrap_or_else(|error| panic!("Failed to create resources folder: {path:?}: {error}"));
        path
    }

    /// Loads `test_folder/test_file.snap` and asserts that the `Debug` rendering of `candidate` matches it.
    #[track_caller]
    fn assert_snapshot<S1: Into<String>, S2: Into<String>, C: Debug>(test_folder: S1, test_file: S2, candidate: C) {
        let mut path = resources_path();
        path.push(test_folder.into());

        // Several tests share one folder (e.g. `test_hash`) and run in parallel. A check-then-`create_dir`
        // can therefore fail with `AlreadyExists`; `create_dir_all` treats concurrent creation as success.
        std::fs::create_dir_all(&path)
            .unwrap_or_else(|error| panic!("Failed to create test folder: {path:?}: {error}"));

        path.push(test_file.into());
        path.set_extension("snap");

        // Create an empty snapshot file if none exists yet.
        if !path.exists() {
            std::fs::File::create(&path).unwrap_or_else(|error| panic!("Failed to create file: {path:?}: {error}"));
        }

        expect_test::expect_file![path].assert_eq(&format!("{candidate:?}"));
    }

    /// Parses each `"<digits>field"` literal into a field element, panicking on malformed input.
    fn fields(literals: &[&str]) -> Vec<Field<CurrentEnvironment>> {
        literals.iter().map(|literal| Field::<CurrentEnvironment>::from_str(literal).unwrap()).collect()
    }

    /// Converts small integers into field elements.
    fn small_fields(values: &[u8]) -> Vec<Field<CurrentEnvironment>> {
        values.iter().map(|&value| Field::<CurrentEnvironment>::from_u8(value)).collect()
    }

    /// Hashes `input` with `poseidon` and checks the digest against the
    /// `test_hash/rate_{RATE}_test_{index}` snapshot.
    #[track_caller]
    fn check_hash<const RATE: usize>(
        poseidon: &Poseidon<CurrentEnvironment, RATE>,
        index: u16,
        input: &[Field<CurrentEnvironment>],
    ) {
        assert_snapshot("test_hash", format!("rate_{RATE}_test_{index}"), poseidon.hash(input).unwrap());
    }

    /// Runs hash cases 0 through 4, which every rate's suite shares.
    fn check_common_hash_cases<const RATE: usize>(poseidon: &Poseidon<CurrentEnvironment, RATE>) {
        check_hash(poseidon, 0, &[]);
        check_hash(poseidon, 1, &small_fields(&[0]));
        check_hash(poseidon, 2, &small_fields(&[1]));
        check_hash(poseidon, 3, &small_fields(&[0, 1]));
        check_hash(poseidon, 4, &small_fields(&[7, 6]));
    }

    /// Runs hash cases 5 through 7, which hash growing prefixes of `SHARED_INPUTS`.
    fn check_shared_input_cases<const RATE: usize>(poseidon: &Poseidon<CurrentEnvironment, RATE>) {
        for (index, len) in [(5, 3), (6, 4), (7, 5)] {
            check_hash(poseidon, index, &fields(&SHARED_INPUTS[..len]));
        }
    }

    #[test]
    fn test_grain_lfsr() -> Result<()> {
        let mut lfsr = PoseidonGrainLFSR::new(false, 253, 3, 8, 31);
        assert_snapshot("test_grain_lfsr", "first_sample", lfsr.get_field_elements_rejection_sampling::<Fq>(1)?);
        assert_snapshot("test_grain_lfsr", "second_sample", lfsr.get_field_elements_rejection_sampling::<Fq>(1)?);
        Ok(())
    }

    #[test]
    fn test_sponge() {
        const RATE: usize = 2;
        let parameters = Arc::new(Fq::default_poseidon_parameters::<RATE>().unwrap());

        for absorb in 0..10 {
            for squeeze in 0..10 {
                let iteration = format!("absorb_{absorb}_squeeze_{squeeze}");

                let mut sponge = PoseidonSponge::<CurrentEnvironment, RATE, CAPACITY>::new(&parameters);
                sponge.absorb(&vec![Field::<CurrentEnvironment>::from_u64(1237812u64); absorb]);

                // The expected index is `absorb % RATE`, except that a non-zero multiple of RATE maps to RATE.
                let next_absorb_index = if absorb % RATE != 0 || absorb == 0 { absorb % RATE } else { RATE };
                assert_eq!(sponge.mode, DuplexSpongeMode::Absorbing { next_absorb_index }, "{iteration}");

                assert_snapshot("test_sponge", &iteration, sponge.squeeze(u16::try_from(squeeze).unwrap()));

                if squeeze == 0 {
                    // Squeezing nothing leaves the sponge in its absorbing state.
                    assert_eq!(sponge.mode, DuplexSpongeMode::Absorbing { next_absorb_index }, "{iteration}");
                } else {
                    let next_squeeze_index = if squeeze % RATE != 0 { squeeze % RATE } else { RATE };
                    assert_eq!(sponge.mode, DuplexSpongeMode::Squeezing { next_squeeze_index }, "{iteration}");
                }
            }
        }
    }

    #[test]
    fn test_parameters() {
        fn single_rate_test<const RATE: usize>() {
            let parameters = Fq::default_poseidon_parameters::<RATE>().unwrap();
            assert_snapshot("test_parameters", format!("rate_{RATE}_ark"), parameters.ark);
            assert_snapshot("test_parameters", format!("rate_{RATE}_mds"), parameters.mds);
        }
        // Optimized for constraints.
        single_rate_test::<2>();
        single_rate_test::<3>();
        single_rate_test::<4>();
        single_rate_test::<5>();
        single_rate_test::<6>();
        single_rate_test::<7>();
        single_rate_test::<8>();
    }

    #[test]
    fn test_suite_hash2() {
        let poseidon2 = Poseidon2::<CurrentEnvironment>::setup("Poseidon2").unwrap();
        check_common_hash_cases(&poseidon2);
    }

    #[test]
    fn test_suite_hash4() {
        let poseidon4 = Poseidon4::<CurrentEnvironment>::setup("Poseidon4").unwrap();
        check_common_hash_cases(&poseidon4);
        check_shared_input_cases(&poseidon4);
    }

    #[test]
    fn test_suite_hash8() {
        let poseidon8 = Poseidon8::<CurrentEnvironment>::setup("Poseidon8").unwrap();
        check_common_hash_cases(&poseidon8);
        check_shared_input_cases(&poseidon8);
        check_hash(
            &poseidon8,
            8,
            &fields(&[
                "2241061724039470158487229089505123379386376040366677537043719491567584322339field",
                "4450395467941419565906844040025562669400620759737863109185235386261110553073field",
                "3763549180544198711495347718218896634621699987767108409942867882747700142403field",
                "1834649076610684411560795826346579299134200286711220272747136514724202486145field",
                "3330794675297759513930533281299019673013197332462213086257974185952740704073field",
                "5929621997900969559642343088519370677943323262633114245367700983937202243619field",
                "8211311402459203356251863974142333868284569297703150729090604853345946857386field",
            ]),
        );
        check_hash(
            &poseidon8,
            9,
            &fields(&[
                "160895951580389706659907027483151875213333010019551276998320919296228647317field",
                "8334099740396373026754940038411748941117628023990297711605274995172393663866field",
                "6508516067551208838086421306235504440162527555399726948591414865066786644888field",
                "5260580011132523115913756761919139190330166964648541423363604516046903841683field",
                "1066299182733912299977577599302716102002738653010828827086884529157392046228field",
                "1977519953625589014039847898215240724041194773120013187722954068145627219929field",
                "1618348632868002512910764605250139381231860094469042556990470848701700964713field",
                "1157459381876765943377450451674060447297483544491073402235960067133285590974field",
            ]),
        );
        check_hash(
            &poseidon8,
            10,
            &fields(&[
                "3912308888616251672812272013988802988420414245857866136212784631403027079860field",
                "4100923705771018951561873336835055979905965765839649442185404560120892958216field",
                "5701101373789959818781445339314572139971317958997296225671698446757742149719field",
                "5785597627944719799683455467917641287692417422465938462034769734951914291948field",
                "214818498460401597228033958287537426429167258531438668351703993840760770582field",
                "4497884203527978976088488455523871581608892729212445595385399904032800522087field",
                "4010331535874074900042223641934450423780782982190514529696596753456937384201field",
                "6067637133445382691713836557146174628934072680692724940823629181144890569742field",
                "5966421531117752671625849775894572561179958822813329961720805067254995723444field",
            ]),
        );
    }
}