1use super::*;
17
18impl<E: Environment, const RATE: usize> HashMany for Poseidon<E, RATE> {
19 type Input = Field<E>;
20 type Output = Field<E>;
21
22 #[inline]
23 fn hash_many(&self, input: &[Self::Input], num_outputs: u16) -> Vec<Self::Output> {
24 let mut preimage = Vec::with_capacity(RATE + input.len());
26 preimage.push(self.domain.clone());
27 preimage.push(Field::constant(console::Field::from_u128(input.len() as u128)));
28 preimage.resize(RATE, Field::zero()); preimage.extend_from_slice(input);
30
31 let mut state = vec![Field::zero(); RATE + CAPACITY];
33 let mut mode = DuplexSpongeMode::Absorbing { next_absorb_index: 0 };
34
35 self.absorb(&mut state, &mut mode, &preimage);
37 self.squeeze(&mut state, &mut mode, num_outputs)
38 }
39}
40
41#[allow(clippy::needless_borrow)]
42impl<E: Environment, const RATE: usize> Poseidon<E, RATE> {
43 #[inline]
45 fn absorb(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, input: &[Field<E>]) {
46 if !input.is_empty() {
47 let (mut absorb_index, should_permute) = match *mode {
49 DuplexSpongeMode::Absorbing { next_absorb_index } => match next_absorb_index == RATE {
50 true => (0, true),
51 false => (next_absorb_index, false),
52 },
53 DuplexSpongeMode::Squeezing { .. } => (0, true),
54 };
55
56 if should_permute {
58 self.permute(state);
59 }
60
61 let mut remaining = input;
62 loop {
63 let start = CAPACITY + absorb_index;
65
66 if absorb_index + remaining.len() <= RATE {
68 remaining.iter().enumerate().for_each(|(i, element)| state[start + i] += element);
70 *mode = DuplexSpongeMode::Absorbing { next_absorb_index: absorb_index + remaining.len() };
72 return;
73 }
74
75 let num_absorbed = RATE - absorb_index;
77 remaining.iter().enumerate().take(num_absorbed).for_each(|(i, element)| state[start + i] += element);
78
79 self.permute(state);
81
82 remaining = &remaining[num_absorbed..];
84 absorb_index = 0;
85 }
86 }
87 }
88
89 #[inline]
91 fn squeeze(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, num_outputs: u16) -> Vec<Field<E>> {
92 let mut output = vec![Field::zero(); num_outputs as usize];
93 if num_outputs != 0 {
94 self.squeeze_internal(state, mode, &mut output);
95 }
96 output
97 }
98
99 #[inline]
101 fn squeeze_internal(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, output: &mut [Field<E>]) {
102 let (mut squeeze_index, should_permute) = match *mode {
104 DuplexSpongeMode::Absorbing { .. } => (0, true),
105 DuplexSpongeMode::Squeezing { next_squeeze_index } => match next_squeeze_index == RATE {
106 true => (0, true),
107 false => (next_squeeze_index, false),
108 },
109 };
110
111 if should_permute {
113 self.permute(state);
114 }
115
116 let mut remaining = output;
117 loop {
118 let start = CAPACITY + squeeze_index;
120
121 if squeeze_index + remaining.len() <= RATE {
123 remaining.clone_from_slice(&state[start..(start + remaining.len())]);
125 *mode = DuplexSpongeMode::Squeezing { next_squeeze_index: squeeze_index + remaining.len() };
127 return;
128 }
129
130 let num_squeezed = RATE - squeeze_index;
132 remaining[..num_squeezed].clone_from_slice(&state[start..(start + num_squeezed)]);
133
134 self.permute(state);
136
137 remaining = &mut remaining[num_squeezed..];
139 squeeze_index = 0;
140 }
141 }
142
143 #[inline]
145 fn apply_ark(&self, state: &mut [Field<E>], round: usize) {
146 for (i, element) in state.iter_mut().enumerate() {
147 *element += &self.ark[round][i];
148 }
149 }
150
151 #[inline]
153 fn apply_s_box(&self, state: &mut [Field<E>], is_full_round: bool) {
154 if is_full_round {
155 for element in state.iter_mut() {
157 *element = (&*element).pow(&self.alpha);
158 }
159 } else {
160 state[0] = (&state[0]).pow(&self.alpha);
162 }
163 }
164
165 #[inline]
167 fn apply_mds(&self, state: &mut [Field<E>]) {
168 let mut new_state = Vec::with_capacity(state.len());
169 for i in 0..state.len() {
170 let mut accumulator = Field::zero();
171 for (j, element) in state.iter().enumerate() {
172 accumulator += element * &self.mds[i][j];
173 }
174 new_state.push(accumulator);
175 }
176 state.clone_from_slice(&new_state);
177 }
178
179 #[inline]
181 fn permute(&self, state: &mut [Field<E>]) {
182 let full_rounds_over_2 = self.full_rounds / 2;
184 let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + self.partial_rounds);
185
186 for i in 0..(self.partial_rounds + self.full_rounds) {
188 let is_full_round = !partial_round_range.contains(&i);
189 self.apply_ark(state, i);
190 self.apply_s_box(state, is_full_round);
191 self.apply_mds(state);
192 }
193 }
194}
195
#[cfg(all(test, feature = "console"))]
mod tests {
    use super::*;
    use snarkvm_circuit_types::environment::Circuit;

    use anyhow::Result;

    const DOMAIN: &str = "PoseidonCircuit0";
    const ITERATIONS: usize = 10;
    const RATE: u16 = 4;

    /// Expected `(num_inputs, count)` pairs for public/private inputs when squeezing
    /// `1..=RATE` outputs, where `count` is both the private-variable and the constraint count.
    const COSTS_UP_TO_RATE_OUTPUTS: [(usize, u64); 6] =
        [(1, 335), (2, 340), (3, 345), (4, 350), (5, 705), (6, 705)];

    /// Expected `(num_inputs, count)` pairs for public/private inputs when squeezing
    /// `RATE + 1..=2 * RATE` outputs, where `count` is both the private-variable and the
    /// constraint count.
    const COSTS_UP_TO_TWICE_RATE_OUTPUTS: [(usize, u64); 6] =
        [(1, 690), (2, 695), (3, 700), (4, 705), (5, 1060), (6, 1060)];

    /// Hashes `ITERATIONS` random inputs of length `num_inputs` both natively and in-circuit.
    ///
    /// Asserts that the circuit output matches the native output element by element, and that
    /// the circuit scope has the given constant/public/private/constraint counts.
    fn check_hash_many(
        mode: Mode,
        num_inputs: usize,
        num_outputs: u16,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
        rng: &mut TestRng,
    ) -> Result<()> {
        use console::HashMany as H;

        let native = console::Poseidon::<<Circuit as Environment>::Network, { RATE as usize }>::setup(DOMAIN)?;
        let poseidon = Poseidon::<Circuit, { RATE as usize }>::constant(native.clone());

        for i in 0..ITERATIONS {
            let native_input = (0..num_inputs)
                .map(|_| console::Field::<<Circuit as Environment>::Network>::rand(rng))
                .collect::<Vec<_>>();
            let input = native_input.iter().map(|v| Field::<Circuit>::new(mode, *v)).collect::<Vec<_>>();

            let expected = native.hash_many(&native_input, num_outputs);

            Circuit::scope(format!("Poseidon {mode} {i} {num_outputs}"), || {
                let candidate = poseidon.hash_many(&input, num_outputs);
                for (expected_element, candidate_element) in expected.iter().zip_eq(&candidate) {
                    assert_eq!(*expected_element, candidate_element.eject_value());
                }
                let case = format!("(mode = {mode}, num_inputs = {num_inputs}, num_outputs = {num_outputs})");
                assert_scope!(case, num_constants, num_public, num_private, num_constraints);
            });
            Circuit::reset();
        }
        Ok(())
    }

    /// Runs the expectations shared by public and private inputs for the given `mode`.
    ///
    /// The calls are issued in a fixed order: zero inputs first, then the two cost tables,
    /// each swept over its range of output counts.
    fn check_hash_many_with_variable_inputs(mode: Mode, rng: &mut TestRng) -> Result<()> {
        // With no inputs, no private variables or constraints are expected.
        for num_outputs in 0..=RATE {
            check_hash_many(mode, 0, num_outputs, 1, 0, 0, 0, rng)?;
        }
        for num_outputs in 1..=RATE {
            for &(num_inputs, cost) in COSTS_UP_TO_RATE_OUTPUTS.iter() {
                check_hash_many(mode, num_inputs, num_outputs, 1, 0, cost, cost, rng)?;
            }
        }
        for num_outputs in (RATE + 1)..=(RATE * 2) {
            for &(num_inputs, cost) in COSTS_UP_TO_TWICE_RATE_OUTPUTS.iter() {
                check_hash_many(mode, num_inputs, num_outputs, 1, 0, cost, cost, rng)?;
            }
        }
        Ok(())
    }

    #[test]
    fn test_hash_many_constant() -> Result<()> {
        let mut rng = TestRng::default();

        for num_inputs in 0..=RATE {
            for num_outputs in 0..=RATE {
                check_hash_many(Mode::Constant, num_inputs as usize, num_outputs, 1, 0, 0, 0, &mut rng)?;
            }
        }
        Ok(())
    }

    #[test]
    fn test_hash_many_public() -> Result<()> {
        let mut rng = TestRng::default();
        check_hash_many_with_variable_inputs(Mode::Public, &mut rng)
    }

    #[test]
    fn test_hash_many_private() -> Result<()> {
        let mut rng = TestRng::default();
        check_hash_many_with_variable_inputs(Mode::Private, &mut rng)
    }
}