1use thiserror::Error;
6
/// Length, in bytes, of the array wrapped by `PoseidonHash`.
pub const HASH_BYTES: usize = 32;
9
10#[derive(Error, Debug)]
13pub enum PoseidonSyscallError {
14 #[error("Invalid parameters.")]
15 InvalidParameters,
16 #[error("Invalid endianness.")]
17 InvalidEndianness,
18 #[error("Invalid number of inputs. Maximum allowed is 12.")]
19 InvalidNumberOfInputs,
20 #[error("Input is an empty slice.")]
21 EmptyInput,
22 #[error(
23 "Invalid length of the input. The length matching the modulus of the prime field is 32."
24 )]
25 InvalidInputLength,
26 #[error("Failed to convert bytest into a prime field element.")]
27 BytesToPrimeFieldElement,
28 #[error("Input is larger than the modulus of the prime field.")]
29 InputLargerThanModulus,
30 #[error("Failed to convert a vector of bytes into an array.")]
31 VecToArray,
32 #[error("Failed to convert the number of inputs from u64 to u8.")]
33 U64Tou8,
34 #[error("Failed to convert bytes to BigInt")]
35 BytesToBigInt,
36 #[error("Invalid width. Choose a width between 2 and 16 for 1 to 15 inputs.")]
37 InvalidWidthCircom,
38 #[error("Unexpected error")]
39 Unexpected,
40}
41
42impl From<u64> for PoseidonSyscallError {
43 fn from(error: u64) -> Self {
44 match error {
45 1 => PoseidonSyscallError::InvalidParameters,
46 2 => PoseidonSyscallError::InvalidEndianness,
47 3 => PoseidonSyscallError::InvalidNumberOfInputs,
48 4 => PoseidonSyscallError::EmptyInput,
49 5 => PoseidonSyscallError::InvalidInputLength,
50 6 => PoseidonSyscallError::BytesToPrimeFieldElement,
51 7 => PoseidonSyscallError::InputLargerThanModulus,
52 8 => PoseidonSyscallError::VecToArray,
53 9 => PoseidonSyscallError::U64Tou8,
54 10 => PoseidonSyscallError::BytesToBigInt,
55 11 => PoseidonSyscallError::InvalidWidthCircom,
56 _ => PoseidonSyscallError::Unexpected,
57 }
58 }
59}
60
61impl From<PoseidonSyscallError> for u64 {
62 fn from(error: PoseidonSyscallError) -> Self {
63 match error {
64 PoseidonSyscallError::InvalidParameters => 1,
65 PoseidonSyscallError::InvalidEndianness => 2,
66 PoseidonSyscallError::InvalidNumberOfInputs => 3,
67 PoseidonSyscallError::EmptyInput => 4,
68 PoseidonSyscallError::InvalidInputLength => 5,
69 PoseidonSyscallError::BytesToPrimeFieldElement => 6,
70 PoseidonSyscallError::InputLargerThanModulus => 7,
71 PoseidonSyscallError::VecToArray => 8,
72 PoseidonSyscallError::U64Tou8 => 9,
73 PoseidonSyscallError::BytesToBigInt => 10,
74 PoseidonSyscallError::InvalidWidthCircom => 11,
75 PoseidonSyscallError::Unexpected => 12,
76 }
77 }
78}
79
/// Parameter sets accepted by the Poseidon hashing functions.
///
/// Each variant's discriminant is the numeric identifier used by the `u64`
/// conversions for this type.
///
/// The common traits are derived so values can be copied, compared and
/// debug-printed (for example inside `Result`s in assertions).
#[repr(u64)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Parameters {
    /// Parameter set with identifier `0`.
    ///
    /// NOTE(review): the name suggests the BN254 field with an x^5 S-box;
    /// confirm against the hasher configuration before relying on it.
    Bn254X5 = 0,
}
108
109impl TryFrom<u64> for Parameters {
110 type Error = PoseidonSyscallError;
111
112 fn try_from(value: u64) -> Result<Self, Self::Error> {
113 match value {
114 x if x == Parameters::Bn254X5 as u64 => Ok(Parameters::Bn254X5),
115 _ => Err(PoseidonSyscallError::InvalidParameters),
116 }
117 }
118}
119
120impl From<Parameters> for u64 {
121 fn from(value: Parameters) -> Self {
122 match value {
123 Parameters::Bn254X5 => 0,
124 }
125 }
126}
127
/// Endianness selector accepted by the Poseidon hashing functions.
///
/// Each variant's discriminant is the numeric identifier used by the `u64`
/// conversions for this type; both are spelled out explicitly so they cannot
/// shift if variants are reordered.
///
/// The common traits are derived so values can be copied, compared and
/// debug-printed (for example inside `Result`s in assertions).
#[repr(u64)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Endianness {
    /// Big-endian, identifier `0`.
    BigEndian = 0,
    /// Little-endian, identifier `1`.
    LittleEndian = 1,
}
136
137impl TryFrom<u64> for Endianness {
138 type Error = PoseidonSyscallError;
139
140 fn try_from(value: u64) -> Result<Self, Self::Error> {
141 match value {
142 x if x == Endianness::BigEndian as u64 => Ok(Endianness::BigEndian),
143 x if x == Endianness::LittleEndian as u64 => Ok(Endianness::LittleEndian),
144 _ => Err(PoseidonSyscallError::InvalidEndianness),
145 }
146 }
147}
148
149impl From<Endianness> for u64 {
150 fn from(value: Endianness) -> Self {
151 match value {
152 Endianness::BigEndian => 0,
153 Endianness::LittleEndian => 1,
154 }
155 }
156}
157
158#[repr(transparent)]
160pub struct PoseidonHash(pub [u8; HASH_BYTES]);
161
162impl PoseidonHash {
163 pub fn new(hash_array: [u8; HASH_BYTES]) -> Self {
164 Self(hash_array)
165 }
166
167 pub fn to_bytes(&self) -> [u8; HASH_BYTES] {
168 self.0
169 }
170}
171
172#[cfg(target_os = "solana")]
173pub use solana_define_syscall::definitions::sol_poseidon;
174
175#[allow(unused_variables)]
207pub fn hashv(
208 parameters: Parameters,
211 endianness: Endianness,
212 vals: &[&[u8]],
213) -> Result<PoseidonHash, PoseidonSyscallError> {
214 #[cfg(not(target_os = "solana"))]
217 {
218 use {
219 ark_bn254::Fr,
220 light_poseidon::{Poseidon, PoseidonBytesHasher, PoseidonError},
221 };
222
223 #[allow(non_local_definitions)]
224 impl From<PoseidonError> for PoseidonSyscallError {
225 fn from(error: PoseidonError) -> Self {
226 match error {
227 PoseidonError::InvalidNumberOfInputs { .. } => {
228 PoseidonSyscallError::InvalidNumberOfInputs
229 }
230 PoseidonError::EmptyInput => PoseidonSyscallError::EmptyInput,
231 PoseidonError::InvalidInputLength { .. } => {
232 PoseidonSyscallError::InvalidInputLength
233 }
234 PoseidonError::BytesToPrimeFieldElement { .. } => {
235 PoseidonSyscallError::BytesToPrimeFieldElement
236 }
237 PoseidonError::InputLargerThanModulus => {
238 PoseidonSyscallError::InputLargerThanModulus
239 }
240 PoseidonError::VecToArray => PoseidonSyscallError::VecToArray,
241 PoseidonError::U64Tou8 => PoseidonSyscallError::U64Tou8,
242 PoseidonError::BytesToBigInt => PoseidonSyscallError::BytesToBigInt,
243 PoseidonError::InvalidWidthCircom { .. } => {
244 PoseidonSyscallError::InvalidWidthCircom
245 }
246 }
247 }
248 }
249
250 let mut hasher =
251 Poseidon::<Fr>::new_circom(vals.len()).map_err(PoseidonSyscallError::from)?;
252 let res = match endianness {
253 Endianness::BigEndian => hasher.hash_bytes_be(vals),
254 Endianness::LittleEndian => hasher.hash_bytes_le(vals),
255 }
256 .map_err(PoseidonSyscallError::from)?;
257
258 Ok(PoseidonHash(res))
259 }
260 #[cfg(target_os = "solana")]
262 {
263 let mut hash_result = [0; HASH_BYTES];
264 let result = unsafe {
265 sol_poseidon(
266 parameters.into(),
267 endianness.into(),
268 vals as *const _ as *const u8,
269 vals.len() as u64,
270 &mut hash_result as *mut _ as *mut u8,
271 )
272 };
273
274 match result {
275 0 => Ok(PoseidonHash::new(hash_result)),
276 _ => Err(PoseidonSyscallError::Unexpected),
277 }
278 }
279}
280
281pub fn hash(
312 parameters: Parameters,
313 endianness: Endianness,
314 val: &[u8],
315) -> Result<PoseidonHash, PoseidonSyscallError> {
316 hashv(parameters, endianness, &[val])
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// A single input of 32 `1` bytes, hashed in big-endian mode.
    #[test]
    fn test_poseidon_input_ones_be() {
        let ones = [1u8; 32];
        let expected: [u8; HASH_BYTES] = [
            5, 191, 172, 229, 129, 238, 97, 119, 204, 25, 198, 197, 99, 99, 166, 136, 130, 241,
            30, 132, 7, 172, 99, 157, 185, 145, 224, 210, 127, 27, 117, 230,
        ];

        let digest = hash(Parameters::Bn254X5, Endianness::BigEndian, &ones).unwrap();

        assert_eq!(digest.to_bytes(), expected);
    }

    /// A single input of 32 `1` bytes, hashed in little-endian mode. The
    /// expected bytes are the big-endian result in reverse order.
    #[test]
    fn test_poseidon_input_ones_le() {
        let ones = [1u8; 32];
        let expected: [u8; HASH_BYTES] = [
            230, 117, 27, 127, 210, 224, 145, 185, 157, 99, 172, 7, 132, 30, 241, 130, 136,
            166, 99, 99, 197, 198, 25, 204, 119, 97, 238, 129, 229, 172, 191, 5,
        ];

        let digest = hash(Parameters::Bn254X5, Endianness::LittleEndian, &ones).unwrap();

        assert_eq!(digest.to_bytes(), expected);
    }

    /// Two inputs (32 `1` bytes, then 32 `2` bytes) in big-endian mode.
    #[test]
    fn test_poseidon_input_ones_twos_be() {
        let ones = [1u8; 32];
        let twos = [2u8; 32];
        let expected: [u8; HASH_BYTES] = [
            13, 84, 225, 147, 143, 138, 140, 28, 125, 235, 94, 3, 85, 242, 99, 25, 32, 123,
            132, 254, 156, 162, 206, 27, 38, 231, 53, 200, 41, 130, 25, 144,
        ];

        let inputs = [&ones[..], &twos[..]];
        let digest = hashv(Parameters::Bn254X5, Endianness::BigEndian, &inputs).unwrap();

        assert_eq!(digest.to_bytes(), expected);
    }

    /// Two inputs (32 `1` bytes, then 32 `2` bytes) in little-endian mode.
    /// The expected bytes are the big-endian result in reverse order.
    #[test]
    fn test_poseidon_input_ones_twos_le() {
        let ones = [1u8; 32];
        let twos = [2u8; 32];
        let expected: [u8; HASH_BYTES] = [
            144, 25, 130, 41, 200, 53, 231, 38, 27, 206, 162, 156, 254, 132, 123, 32, 25, 99,
            242, 85, 3, 94, 235, 125, 28, 140, 138, 143, 147, 225, 84, 13,
        ];

        let inputs = [&ones[..], &twos[..]];
        let digest = hashv(Parameters::Bn254X5, Endianness::LittleEndian, &inputs).unwrap();

        assert_eq!(digest.to_bytes(), expected);
    }

    /// Hashes 1 through 12 copies of the same input (31 zero bytes followed
    /// by a single `1`) in big-endian mode; entry `i` of the table is the
    /// expected result for `i + 1` copies.
    #[test]
    fn test_poseidon_input_one() {
        let mut input = [0u8; 32];
        input[31] = 1;

        let expected_hashes = [
            [
                41, 23, 97, 0, 234, 169, 98, 189, 193, 254, 108, 101, 77, 106, 60, 19, 14, 150,
                164, 209, 22, 139, 51, 132, 139, 137, 125, 197, 2, 130, 1, 51,
            ],
            [
                0, 122, 243, 70, 226, 211, 4, 39, 158, 121, 224, 169, 243, 2, 63, 119, 18, 148,
                167, 138, 203, 112, 231, 63, 144, 175, 226, 124, 173, 64, 30, 129,
            ],
            [
                2, 192, 6, 110, 16, 167, 42, 189, 43, 51, 195, 178, 20, 203, 62, 129, 188, 177,
                182, 227, 9, 97, 205, 35, 194, 2, 177, 134, 115, 191, 37, 67,
            ],
            [
                8, 44, 156, 55, 10, 13, 36, 244, 65, 111, 188, 65, 74, 55, 104, 31, 120, 68, 45,
                39, 216, 99, 133, 153, 28, 23, 214, 252, 12, 75, 125, 113,
            ],
            [
                16, 56, 150, 5, 174, 104, 141, 79, 20, 219, 133, 49, 34, 196, 125, 102, 168, 3,
                199, 43, 65, 88, 156, 177, 191, 134, 135, 65, 178, 6, 185, 187,
            ],
            [
                42, 115, 246, 121, 50, 140, 62, 171, 114, 74, 163, 229, 189, 191, 80, 179, 144, 53,
                215, 114, 159, 19, 91, 151, 9, 137, 15, 133, 197, 220, 94, 118,
            ],
            [
                34, 118, 49, 10, 167, 243, 52, 58, 40, 66, 20, 19, 157, 157, 169, 89, 190, 42, 49,
                178, 199, 8, 165, 248, 25, 84, 178, 101, 229, 58, 48, 184,
            ],
            [
                23, 126, 20, 83, 196, 70, 225, 176, 125, 43, 66, 51, 66, 81, 71, 9, 92, 79, 202,
                187, 35, 61, 35, 11, 109, 70, 162, 20, 217, 91, 40, 132,
            ],
            [
                14, 143, 238, 47, 228, 157, 163, 15, 222, 235, 72, 196, 46, 187, 68, 204, 110, 231,
                5, 95, 97, 251, 202, 94, 49, 59, 138, 95, 202, 131, 76, 71,
            ],
            [
                46, 196, 198, 94, 99, 120, 171, 140, 115, 48, 133, 79, 74, 112, 119, 193, 255, 146,
                96, 228, 72, 133, 196, 184, 29, 209, 49, 173, 58, 134, 205, 150,
            ],
            [
                0, 113, 61, 65, 236, 166, 53, 241, 23, 212, 236, 188, 235, 95, 58, 102, 220, 65,
                66, 235, 112, 181, 103, 101, 188, 53, 143, 27, 236, 64, 187, 155,
            ],
            [
                20, 57, 11, 224, 186, 239, 36, 155, 212, 124, 101, 221, 172, 101, 194, 229, 46,
                133, 19, 192, 129, 193, 205, 114, 201, 128, 6, 9, 142, 154, 143, 190,
            ],
        ];

        for (index, expected) in expected_hashes.iter().enumerate() {
            // `index + 1` copies of the same 32-byte input.
            let copies = index + 1;
            let inputs: Vec<&[u8]> = vec![&input[..]; copies];

            let digest = hashv(Parameters::Bn254X5, Endianness::BigEndian, &inputs).unwrap();

            assert_eq!(digest.to_bytes(), *expected);
        }
    }
}