1use super::*;
17
18impl<E: Environment, const TYPE: u8, const VARIANT: usize> Hash for Keccak<E, TYPE, VARIANT> {
19 type Input = Boolean<E>;
20 type Output = Vec<Boolean<E>>;
21
22 #[inline]
24 fn hash(&self, input: &[Self::Input]) -> Self::Output {
25 let bitrate = PERMUTATION_WIDTH - 2 * VARIANT;
29 debug_assert!(bitrate < PERMUTATION_WIDTH, "The bitrate must be less than the permutation width");
30 debug_assert!(bitrate % 8 == 0, "The bitrate must be a multiple of 8");
31
32 if input.is_empty() {
34 E::halt("The input to the hash function must not be empty")
35 }
36
37 let mut s = vec![Boolean::constant(false); PERMUTATION_WIDTH];
39
40 let padded_blocks = match TYPE {
42 0 => Self::pad_keccak(input, bitrate),
43 1 => Self::pad_sha3(input, bitrate),
44 2.. => unreachable!("Invalid Keccak type"),
45 };
46
47 for block in padded_blocks {
55 for (j, bit) in block.into_iter().enumerate() {
57 s[j] = &s[j] ^ &bit;
58 }
59 s = Self::permutation_f::<PERMUTATION_WIDTH, NUM_ROUNDS>(s, &self.round_constants, &self.rotl);
61 }
62
63 let mut z = s[..bitrate].to_vec();
74 while z.len() < VARIANT {
76 s = Self::permutation_f::<PERMUTATION_WIDTH, NUM_ROUNDS>(s, &self.round_constants, &self.rotl);
78 z.extend(s.iter().take(bitrate).cloned());
80 }
81 z.truncate(VARIANT);
83 z
84 }
85}
86
87impl<E: Environment, const TYPE: u8, const VARIANT: usize> Keccak<E, TYPE, VARIANT> {
88 fn pad_keccak(input: &[Boolean<E>], bitrate: usize) -> Vec<Vec<Boolean<E>>> {
93 debug_assert!(bitrate > 0, "The bitrate must be positive");
94
95 let mut padded_input = input.to_vec();
97 padded_input.resize((input.len() + 7) / 8 * 8, Boolean::constant(false));
98
99 padded_input.push(Boolean::constant(true));
101
102 while (padded_input.len() % bitrate) != (bitrate - 1) {
104 padded_input.push(Boolean::constant(false));
105 }
106
107 padded_input.push(Boolean::constant(true));
109
110 let mut result = Vec::new();
112 for block in padded_input.chunks(bitrate) {
113 result.push(block.to_vec());
114 }
115 result
116 }
117
118 fn pad_sha3(input: &[Boolean<E>], bitrate: usize) -> Vec<Vec<Boolean<E>>> {
123 debug_assert!(bitrate > 1, "The bitrate must be greater than 1");
124
125 let mut padded_input = input.to_vec();
127 padded_input.resize((input.len() + 7) / 8 * 8, Boolean::constant(false));
128
129 padded_input.push(Boolean::constant(false));
131 padded_input.push(Boolean::constant(true));
132 padded_input.push(Boolean::constant(true));
133 padded_input.push(Boolean::constant(false));
134
135 while (padded_input.len() % bitrate) != (bitrate - 1) {
137 padded_input.push(Boolean::constant(false));
138 }
139
140 padded_input.push(Boolean::constant(true));
142
143 let mut result = Vec::new();
145 for block in padded_input.chunks(bitrate) {
146 result.push(block.to_vec());
147 }
148 result
149 }
150
151 fn permutation_f<const WIDTH: usize, const NUM_ROUNDS: usize>(
157 input: Vec<Boolean<E>>,
158 round_constants: &[U64<E>],
159 rotl: &[usize],
160 ) -> Vec<Boolean<E>> {
161 debug_assert_eq!(input.len(), WIDTH, "The input vector must have {WIDTH} bits");
162 debug_assert_eq!(
163 round_constants.len(),
164 NUM_ROUNDS,
165 "The round constants vector must have {NUM_ROUNDS} elements"
166 );
167
168 let mut a = input.chunks(64).map(U64::from_bits_le).collect::<Vec<_>>();
170 for round_constant in round_constants.iter().take(NUM_ROUNDS) {
172 a = Self::round(a, round_constant, rotl);
173 }
174 let mut bits = Vec::with_capacity(input.len());
176 a.iter().for_each(|e| e.write_bits_le(&mut bits));
177 bits
178 }
179
180 fn round(a: Vec<U64<E>>, round_constant: &U64<E>, rotl: &[usize]) -> Vec<U64<E>> {
186 debug_assert_eq!(a.len(), MODULO * MODULO, "The input vector 'a' must have {} elements", MODULO * MODULO);
187
188 let mut c = Vec::with_capacity(MODULO);
198 for x in 0..MODULO {
199 c.push(&a[x] ^ &a[x + MODULO] ^ &a[x + (2 * MODULO)] ^ &a[x + (3 * MODULO)] ^ &a[x + (4 * MODULO)]);
200 }
201
202 let mut d = Vec::with_capacity(MODULO);
212 for x in 0..MODULO {
213 d.push(&c[(x + 4) % MODULO] ^ Self::rotate_left(&c[(x + 1) % MODULO], 63));
214 }
215 let mut a_1 = Vec::with_capacity(MODULO * MODULO);
216 for y in 0..MODULO {
217 for x in 0..MODULO {
218 a_1.push(&a[x + (y * MODULO)] ^ &d[x]);
219 }
220 }
221
222 let mut a_2 = a_1.clone();
241 for y in 0..MODULO {
242 for x in 0..MODULO {
243 a_2[y + ((((2 * x) + (3 * y)) % MODULO) * MODULO)] =
245 Self::rotate_left(&a_1[x + (y * MODULO)], rotl[x + (y * MODULO)]);
246 }
247 }
248
249 let mut a_3 = Vec::with_capacity(MODULO * MODULO);
258 for y in 0..MODULO {
259 for x in 0..MODULO {
260 let a = &a_2[x + (y * MODULO)];
261 let b = &a_2[((x + 1) % MODULO) + (y * MODULO)];
262 let c = &a_2[((x + 2) % MODULO) + (y * MODULO)];
263 a_3.push(a ^ ((!b) & c));
264 }
265 }
266
267 a_3[0] = &a_3[0] ^ round_constant;
272 a_3
273 }
274
275 fn rotate_left(value: &U64<E>, n: usize) -> U64<E> {
277 let mut bits_le = value.to_bits_le();
279 bits_le.rotate_left(n);
280 U64::from_bits_le(&bits_le)
282 }
283}
284
#[cfg(all(test, feature = "console"))]
mod tests {
    use super::*;
    use console::Rng;
    use snarkvm_circuit_types::environment::Circuit;

    /// How many random inputs `check_hash` hashes for each input length.
    const ITERATIONS: usize = 3;

    /// Input lengths (in bits) exercised by `test_keccak_256_hash_constant`.
    const CONSTANT_INPUT_SIZES: [usize; 19] =
        [1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256, 511, 512, 513, 1023, 1024, 1025];

    /// `(num_inputs, num_private, num_constraints)` expected for Keccak-256 over public or
    /// private inputs; both tests share this table because their expected costs are identical.
    const NON_CONSTANT_COSTS: [(usize, u64, u64); 15] = [
        (1, 138157, 138157),
        (2, 139108, 139108),
        (3, 139741, 139741),
        (4, 140318, 140318),
        (5, 140879, 140879),
        (6, 141350, 141350),
        (7, 141787, 141787),
        (8, 142132, 142132),
        (16, 144173, 144173),
        (32, 145394, 145394),
        (64, 146650, 146650),
        (128, 149248, 149248),
        (256, 150848, 150848),
        (512, 151424, 151424),
        (1024, 152448, 152448),
    ];

    /// Asserts that the console hasher `$console` and the circuit hasher `$circuit` agree on
    /// random private inputs of several fixed lengths plus five random lengths in `1..1024`.
    macro_rules! check_equivalence {
        ($console:expr, $circuit:expr) => {
            use console::Hash as H;

            let rng = &mut TestRng::default();

            let mut input_sizes = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 32, 64, 128, 256, 512, 1024];
            input_sizes.extend((0..5).map(|_| rng.gen_range(1..1024)));

            for num_inputs in input_sizes {
                println!("Checking equivalence for {num_inputs} inputs");

                let native_input = (0..num_inputs).map(|_| Uniform::rand(rng)).collect::<Vec<bool>>();
                let circuit_input: Vec<_> =
                    native_input.iter().map(|&bit| Boolean::<Circuit>::new(Mode::Private, bit)).collect();

                let expected = $console.hash(&native_input).expect("Failed to hash console input");
                let candidate = $circuit.hash(&circuit_input);
                assert_eq!(expected, candidate.eject_value());

                Circuit::reset();
            }
        };
    }

    /// Hashes `ITERATIONS` random `num_inputs`-bit inputs with Keccak-256 in `mode`, checking the
    /// digest against the console implementation and the circuit cost against the expected
    /// constant/public/private/constraint counts.
    fn check_hash(
        mode: Mode,
        num_inputs: usize,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
        rng: &mut TestRng,
    ) {
        use console::Hash as H;

        let console_hasher = console::Keccak256::default();
        let circuit_hasher = Keccak256::<Circuit>::new();

        for i in 0..ITERATIONS {
            let native_input: Vec<bool> = (0..num_inputs).map(|_| Uniform::rand(rng)).collect();
            let circuit_input: Vec<_> = native_input.iter().map(|&bit| Boolean::<Circuit>::new(mode, bit)).collect();

            let expected = console_hasher.hash(&native_input).expect("Failed to hash native input");

            Circuit::scope(format!("Keccak {mode} {i}"), || {
                let candidate = circuit_hasher.hash(&circuit_input);
                assert_eq!(expected, candidate.eject_value());
                let case = format!("(mode = {mode}, num_inputs = {num_inputs})");
                assert_scope!(case, num_constants, num_public, num_private, num_constraints);
            });
            Circuit::reset();
        }
    }

    #[test]
    fn test_keccak_256_hash_constant() {
        let mut rng = TestRng::default();

        for num_inputs in CONSTANT_INPUT_SIZES {
            check_hash(Mode::Constant, num_inputs, 0, 0, 0, 0, &mut rng);
        }
    }

    #[test]
    fn test_keccak_256_hash_public() {
        let mut rng = TestRng::default();

        for (num_inputs, num_private, num_constraints) in NON_CONSTANT_COSTS {
            check_hash(Mode::Public, num_inputs, 0, 0, num_private, num_constraints, &mut rng);
        }
    }

    #[test]
    fn test_keccak_256_hash_private() {
        let mut rng = TestRng::default();

        for (num_inputs, num_private, num_constraints) in NON_CONSTANT_COSTS {
            check_hash(Mode::Private, num_inputs, 0, 0, num_private, num_constraints, &mut rng);
        }
    }

    #[test]
    fn test_keccak_224_equivalence() {
        check_equivalence!(console::Keccak224::default(), Keccak224::<Circuit>::new());
    }

    #[test]
    fn test_keccak_256_equivalence() {
        check_equivalence!(console::Keccak256::default(), Keccak256::<Circuit>::new());
    }

    #[test]
    fn test_keccak_384_equivalence() {
        check_equivalence!(console::Keccak384::default(), Keccak384::<Circuit>::new());
    }

    #[test]
    fn test_keccak_512_equivalence() {
        check_equivalence!(console::Keccak512::default(), Keccak512::<Circuit>::new());
    }

    #[test]
    fn test_sha3_224_equivalence() {
        check_equivalence!(console::Sha3_224::default(), Sha3_224::<Circuit>::new());
    }

    #[test]
    fn test_sha3_256_equivalence() {
        check_equivalence!(console::Sha3_256::default(), Sha3_256::<Circuit>::new());
    }

    #[test]
    fn test_sha3_384_equivalence() {
        check_equivalence!(console::Sha3_384::default(), Sha3_384::<Circuit>::new());
    }

    #[test]
    fn test_sha3_512_equivalence() {
        check_equivalence!(console::Sha3_512::default(), Sha3_512::<Circuit>::new());
    }
}