1mod helpers;
17
18mod hash;
19mod hash_many;
20mod hash_to_group;
21mod hash_to_scalar;
22mod prf;
23
24use crate::{Elligator2, poseidon::helpers::*};
25use snarkvm_console_types::prelude::*;
26use snarkvm_fields::{PoseidonDefaultField, PoseidonParameters};
27
28use std::sync::Arc;
29
/// The `CAPACITY` const-generic argument used for the `PoseidonParameters` in this module.
const CAPACITY: usize = 1;

/// [`Poseidon`] instantiated with `RATE = 2`.
pub type Poseidon2<E> = Poseidon<E, 2>;
/// [`Poseidon`] instantiated with `RATE = 4`.
pub type Poseidon4<E> = Poseidon<E, 4>;
/// [`Poseidon`] instantiated with `RATE = 8`.
pub type Poseidon8<E> = Poseidon<E, 8>;
38
/// A Poseidon instance over the environment `E`, generic over its `RATE`.
///
/// Holds a domain separator and the Poseidon parameters for `RATE`.
#[derive(Clone, Debug, PartialEq)]
pub struct Poseidon<E: Environment, const RATE: usize> {
    /// The domain separator, derived from the string passed to `setup`.
    domain: Field<E>,
    /// The Poseidon parameters for this `RATE`, held behind an `Arc`.
    parameters: Arc<PoseidonParameters<E::Field, RATE, CAPACITY>>,
}
46
47impl<E: Environment, const RATE: usize> Poseidon<E, RATE> {
48 pub fn setup(domain: &str) -> Result<Self> {
50 let num_bits = domain.len().saturating_mul(8);
52 let max_bits = Field::<E>::size_in_data_bits();
53 ensure!(num_bits <= max_bits, "Domain cannot exceed {max_bits} bits, found {num_bits} bits");
54
55 Ok(Self {
56 domain: Field::<E>::new_domain_separator(domain),
57 parameters: Arc::new(E::Field::default_poseidon_parameters::<RATE>()?),
58 })
59 }
60
61 pub fn domain(&self) -> Field<E> {
63 self.domain
64 }
65
66 pub fn parameters(&self) -> &Arc<PoseidonParameters<E::Field, RATE, CAPACITY>> {
68 &self.parameters
69 }
70}
71
72#[cfg(test)]
73mod tests {
74 use super::*;
75 use snarkvm_console_types::environment::Console;
76 use snarkvm_curves::edwards_bls12::Fq;
77 use snarkvm_fields::{PoseidonDefaultField, PoseidonGrainLFSR};
78
79 type CurrentEnvironment = Console;
80
81 use std::{path::PathBuf, sync::Arc};
82
83 fn resources_path() -> PathBuf {
85 let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
87 path.push("src");
88 path.push("poseidon");
89 path.push("resources");
90
91 if !path.exists() {
93 std::fs::create_dir_all(&path).unwrap_or_else(|_| panic!("Failed to create resources folder: {path:?}"));
94 }
95 path
97 }
98
99 #[track_caller]
101 fn assert_snapshot<S1: Into<String>, S2: Into<String>, C: Debug>(test_folder: S1, test_file: S2, candidate: C) {
102 let mut path = resources_path();
104 path.push(test_folder.into());
105
106 if !path.exists() {
108 std::fs::create_dir(&path).unwrap_or_else(|_| panic!("Failed to create test folder: {path:?}"));
109 }
110
111 path.push(test_file.into());
113 path.set_extension("snap");
114
115 if !path.exists() {
117 std::fs::File::create(&path).unwrap_or_else(|_| panic!("Failed to create file: {path:?}"));
118 }
119
120 expect_test::expect_file![path].assert_eq(&format!("{candidate:?}"));
122 }
123
/// Snapshots two consecutive rejection-sampled field elements drawn from one Grain LFSR.
#[test]
fn test_grain_lfsr() -> Result<()> {
    let mut generator = PoseidonGrainLFSR::new(false, 253, 3, 8, 31);

    // The generator is stateful, so the samples are drawn and recorded strictly in this order.
    for sample_name in ["first_sample", "second_sample"] {
        let sample = generator.get_field_elements_rejection_sampling::<Fq>(1)?;
        assert_snapshot("test_grain_lfsr", sample_name, sample);
    }

    Ok(())
}
131
/// Exercises the sponge for every combination of 0..10 absorbed and 0..10 squeezed elements,
/// checking the reported mode after each phase and snapshotting the squeezed output.
#[test]
fn test_sponge() {
    const RATE: usize = 2;

    /// Index the test expects the sponge to report after processing `count` elements:
    /// `count % RATE` when `count` is zero or not a multiple of `RATE`, otherwise `RATE`.
    fn expected_index(count: usize) -> usize {
        if count == 0 || count % RATE != 0 { count % RATE } else { RATE }
    }

    let parameters = Arc::new(Fq::default_poseidon_parameters::<RATE>().unwrap());

    for absorb in 0..10 {
        for squeeze in 0..10 {
            let iteration = format!("absorb_{absorb}_squeeze_{squeeze}");

            // The argument is a reference to the shared parameters; the source had it
            // mis-encoded as a pilcrow character, which is not valid Rust.
            let mut sponge = PoseidonSponge::<CurrentEnvironment, RATE, CAPACITY>::new(&parameters);
            sponge.absorb(&vec![Field::<CurrentEnvironment>::from_u64(1237812u64); absorb]);

            let next_absorb_index = expected_index(absorb);
            assert_eq!(sponge.mode, DuplexSpongeMode::Absorbing { next_absorb_index }, "{iteration}");

            assert_snapshot("test_sponge", &iteration, sponge.squeeze(u16::try_from(squeeze).unwrap()));

            // Squeezing zero elements is expected to leave the sponge in its absorbing mode.
            if squeeze == 0 {
                assert_eq!(sponge.mode, DuplexSpongeMode::Absorbing { next_absorb_index }, "{iteration}");
            } else {
                let next_squeeze_index = expected_index(squeeze);
                assert_eq!(sponge.mode, DuplexSpongeMode::Squeezing { next_squeeze_index }, "{iteration}");
            }
        }
    }
}
157
/// Snapshots the `ark` and `mds` members of the default Poseidon parameters for rates 2 through 8.
#[test]
fn test_parameters() {
    /// Records the `ark` and `mds` snapshots for a single `RATE`.
    fn snapshot_rate<const RATE: usize>() {
        let params = Fq::default_poseidon_parameters::<RATE>().unwrap();

        let ark_name = format!("rate_{RATE}_ark");
        assert_snapshot("test_parameters", ark_name, params.ark);

        let mds_name = format!("rate_{RATE}_mds");
        assert_snapshot("test_parameters", mds_name, params.mds);
    }

    // Const generics need one explicit instantiation per rate.
    snapshot_rate::<2>();
    snapshot_rate::<3>();
    snapshot_rate::<4>();
    snapshot_rate::<5>();
    snapshot_rate::<6>();
    snapshot_rate::<7>();
    snapshot_rate::<8>();
}
174
/// Snapshots `Poseidon2` hashes of a handful of small inputs.
#[test]
fn test_suite_hash2() {
    // Each case lists the byte values that are lifted into field elements; the position is the test index.
    let cases: [&[u8]; 5] = [&[], &[0], &[1], &[0, 1], &[7, 6]];

    for (index, bytes) in cases.iter().enumerate() {
        let input: Vec<Field<CurrentEnvironment>> = bytes.iter().map(|&byte| Field::<Console>::from_u8(byte)).collect();

        // A fresh hasher is set up for every case.
        let hasher = Poseidon2::<Console>::setup("Poseidon2").unwrap();
        assert_snapshot("test_hash", format!("rate_2_test_{index}"), hasher.hash(&input).unwrap());
    }
}
187
/// Snapshots `Poseidon4` hashes of small byte-valued inputs and of three to five fixed field literals.
#[test]
fn test_suite_hash4() {
    /// Field literals shared by cases 5, 6 and 7, which use the first 3, 4 and 5 entries respectively.
    const SHARED_LITERALS: [&str; 5] = [
        "3801852864665033841774715284518384682376829752661853198612247855579120198106field",
        "8354898322875240371401674517397790035008442020361740574117886421279083828480field",
        "4810388512520169167962815122521832339992376865086300759308552937986944510606field",
        "1806278863067630397941269234951941896370617486625414347832536440203404317871field",
        "4017177598231920767921734423139954103557056461408532722673217828464276314809field",
    ];

    /// Lifts each byte into a field element.
    fn from_bytes(values: &[u8]) -> Vec<Field<CurrentEnvironment>> {
        values.iter().map(|&value| Field::<Console>::from_u8(value)).collect()
    }

    /// Parses each literal into a field element, panicking on the first invalid one.
    fn from_literals(literals: &[&str]) -> Vec<Field<CurrentEnvironment>> {
        literals.iter().map(|&literal| Field::<Console>::from_str(literal).unwrap()).collect()
    }

    /// Hashes `input` with a freshly set-up `Poseidon4` and records the result as case `index`.
    fn snapshot_case(index: u8, input: Vec<Field<CurrentEnvironment>>) {
        let hasher = Poseidon4::<Console>::setup("Poseidon4").unwrap();
        assert_snapshot("test_hash", format!("rate_4_test_{index}"), hasher.hash(&input).unwrap());
    }

    snapshot_case(0, from_bytes(&[]));
    snapshot_case(1, from_bytes(&[0]));
    snapshot_case(2, from_bytes(&[1]));
    snapshot_case(3, from_bytes(&[0, 1]));
    snapshot_case(4, from_bytes(&[7, 6]));
    snapshot_case(5, from_literals(&SHARED_LITERALS[..3]));
    snapshot_case(6, from_literals(&SHARED_LITERALS[..4]));
    snapshot_case(7, from_literals(&SHARED_LITERALS));
}
254
/// Snapshots `Poseidon8` hashes of small byte-valued inputs and of three to nine fixed field literals.
#[test]
fn test_suite_hash8() {
    /// Field literals shared by cases 5, 6 and 7, which use the first 3, 4 and 5 entries respectively.
    const SHARED_LITERALS: [&str; 5] = [
        "3801852864665033841774715284518384682376829752661853198612247855579120198106field",
        "8354898322875240371401674517397790035008442020361740574117886421279083828480field",
        "4810388512520169167962815122521832339992376865086300759308552937986944510606field",
        "1806278863067630397941269234951941896370617486625414347832536440203404317871field",
        "4017177598231920767921734423139954103557056461408532722673217828464276314809field",
    ];

    /// The seven literals of case 8.
    const CASE_8_LITERALS: [&str; 7] = [
        "2241061724039470158487229089505123379386376040366677537043719491567584322339field",
        "4450395467941419565906844040025562669400620759737863109185235386261110553073field",
        "3763549180544198711495347718218896634621699987767108409942867882747700142403field",
        "1834649076610684411560795826346579299134200286711220272747136514724202486145field",
        "3330794675297759513930533281299019673013197332462213086257974185952740704073field",
        "5929621997900969559642343088519370677943323262633114245367700983937202243619field",
        "8211311402459203356251863974142333868284569297703150729090604853345946857386field",
    ];

    /// The eight literals of case 9.
    const CASE_9_LITERALS: [&str; 8] = [
        "160895951580389706659907027483151875213333010019551276998320919296228647317field",
        "8334099740396373026754940038411748941117628023990297711605274995172393663866field",
        "6508516067551208838086421306235504440162527555399726948591414865066786644888field",
        "5260580011132523115913756761919139190330166964648541423363604516046903841683field",
        "1066299182733912299977577599302716102002738653010828827086884529157392046228field",
        "1977519953625589014039847898215240724041194773120013187722954068145627219929field",
        "1618348632868002512910764605250139381231860094469042556990470848701700964713field",
        "1157459381876765943377450451674060447297483544491073402235960067133285590974field",
    ];

    /// The nine literals of case 10.
    const CASE_10_LITERALS: [&str; 9] = [
        "3912308888616251672812272013988802988420414245857866136212784631403027079860field",
        "4100923705771018951561873336835055979905965765839649442185404560120892958216field",
        "5701101373789959818781445339314572139971317958997296225671698446757742149719field",
        "5785597627944719799683455467917641287692417422465938462034769734951914291948field",
        "214818498460401597228033958287537426429167258531438668351703993840760770582field",
        "4497884203527978976088488455523871581608892729212445595385399904032800522087field",
        "4010331535874074900042223641934450423780782982190514529696596753456937384201field",
        "6067637133445382691713836557146174628934072680692724940823629181144890569742field",
        "5966421531117752671625849775894572561179958822813329961720805067254995723444field",
    ];

    /// Lifts each byte into a field element.
    fn from_bytes(values: &[u8]) -> Vec<Field<CurrentEnvironment>> {
        values.iter().map(|&value| Field::<Console>::from_u8(value)).collect()
    }

    /// Parses each literal into a field element, panicking on the first invalid one.
    fn from_literals(literals: &[&str]) -> Vec<Field<CurrentEnvironment>> {
        literals.iter().map(|&literal| Field::<Console>::from_str(literal).unwrap()).collect()
    }

    /// Hashes `input` with a freshly set-up `Poseidon8` and records the result as case `index`.
    fn snapshot_case(index: u16, input: Vec<Field<CurrentEnvironment>>) {
        let hasher = Poseidon8::<Console>::setup("Poseidon8").unwrap();
        assert_snapshot("test_hash", format!("rate_8_test_{index}"), hasher.hash(&input).unwrap());
    }

    snapshot_case(0, from_bytes(&[]));
    snapshot_case(1, from_bytes(&[0]));
    snapshot_case(2, from_bytes(&[1]));
    snapshot_case(3, from_bytes(&[0, 1]));
    snapshot_case(4, from_bytes(&[7, 6]));
    snapshot_case(5, from_literals(&SHARED_LITERALS[..3]));
    snapshot_case(6, from_literals(&SHARED_LITERALS[..4]));
    snapshot_case(7, from_literals(&SHARED_LITERALS));
    snapshot_case(8, from_literals(&CASE_8_LITERALS));
    snapshot_case(9, from_literals(&CASE_9_LITERALS));
    snapshot_case(10, from_literals(&CASE_10_LITERALS));
}
423}