snarkvm_circuit_algorithms/keccak/
hash.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use super::*;
17
18impl<E: Environment, const TYPE: u8, const VARIANT: usize> Hash for Keccak<E, TYPE, VARIANT> {
19    type Input = Boolean<E>;
20    type Output = Vec<Boolean<E>>;
21
22    /// Returns the Keccak hash of the given input as bits.
23    #[inline]
24    fn hash(&self, input: &[Self::Input]) -> Self::Output {
25        // The bitrate `r`.
26        // The capacity is twice the digest length (i.e. twice the variant, where the variant is in {224, 256, 384, 512}),
27        // and the bit rate is the width (1600 in our case) minus the capacity.
28        let bitrate = PERMUTATION_WIDTH - 2 * VARIANT;
29        debug_assert!(bitrate < PERMUTATION_WIDTH, "The bitrate must be less than the permutation width");
30        debug_assert!(bitrate % 8 == 0, "The bitrate must be a multiple of 8");
31
32        // Ensure the input is not empty.
33        if input.is_empty() {
34            E::halt("The input to the hash function must not be empty")
35        }
36
37        // The root state `s` is defined as `0^b`.
38        let mut s = vec![Boolean::constant(false); PERMUTATION_WIDTH];
39
40        // The padded blocks `P`.
41        let padded_blocks = match TYPE {
42            0 => Self::pad_keccak(input, bitrate),
43            1 => Self::pad_sha3(input, bitrate),
44            2.. => unreachable!("Invalid Keccak type"),
45        };
46
47        /* The first part of the sponge construction (the absorbing phase):
48         *
49         * for i = 0 to |P| − 1 do
50         *   s = s ⊕ (P_i || 0^c) # Note: |P_i| + c == b, since |P_i| == r
51         *   s = f(s)
52         * end for
53         */
54        for block in padded_blocks {
55            // s = s ⊕ (P_i || 0^c)
56            for (j, bit) in block.into_iter().enumerate() {
57                s[j] = &s[j] ^ &bit;
58            }
59            // s = f(s)
60            s = Self::permutation_f::<PERMUTATION_WIDTH, NUM_ROUNDS>(s, &self.round_constants, &self.rotl);
61        }
62
63        /* The second part of the sponge construction (the squeezing phase):
64         *
65         * Z = s[0..r-1]
66         * while |Z| < d do // d is the digest length
67         *   s = f(s)
68         *   Z = Z || s[0..r-1]
69         * end while
70         * return Z[0..d-1]
71         */
72        // Z = s[0..r-1]
73        let mut z = s[..bitrate].to_vec();
74        // while |Z| < l do
75        while z.len() < VARIANT {
76            // s = f(s)
77            s = Self::permutation_f::<PERMUTATION_WIDTH, NUM_ROUNDS>(s, &self.round_constants, &self.rotl);
78            // Z = Z || s[0..r-1]
79            z.extend(s.iter().take(bitrate).cloned());
80        }
81        // return Z[0..d-1]
82        z.truncate(VARIANT);
83        z
84    }
85}
86
87impl<E: Environment, const TYPE: u8, const VARIANT: usize> Keccak<E, TYPE, VARIANT> {
88    /// In Keccak, `pad` is a multi-rate padding, defined as `pad(M) = M || 0x01 || 0x00…0x00 || 0x80`,
89    /// where `M` is the input data, and `0x01 || 0x00…0x00 || 0x80` is the padding.
90    /// The padding extends the input data to a multiple of the bitrate `r`, defined as `r = b - c`,
91    /// where `b` is the width of the permutation, and `c` is the capacity.
92    fn pad_keccak(input: &[Boolean<E>], bitrate: usize) -> Vec<Vec<Boolean<E>>> {
93        debug_assert!(bitrate > 0, "The bitrate must be positive");
94
95        // Resize the input to a multiple of 8.
96        let mut padded_input = input.to_vec();
97        padded_input.resize((input.len() + 7) / 8 * 8, Boolean::constant(false));
98
99        // Step 1: Append the bit "1" to the message.
100        padded_input.push(Boolean::constant(true));
101
102        // Step 2: Append "0" bits until the length of the message is congruent to r-1 mod r.
103        while (padded_input.len() % bitrate) != (bitrate - 1) {
104            padded_input.push(Boolean::constant(false));
105        }
106
107        // Step 3: Append the bit "1" to the message.
108        padded_input.push(Boolean::constant(true));
109
110        // Construct the padded blocks.
111        let mut result = Vec::new();
112        for block in padded_input.chunks(bitrate) {
113            result.push(block.to_vec());
114        }
115        result
116    }
117
118    /// In SHA-3, `pad` is a SHAKE, defined as `pad(M) = M || 0x06 || 0x00…0x00 || 0x80`,
119    /// where `M` is the input data, and `0x06 || 0x00…0x00 || 0x80` is the padding.
120    /// The padding extends the input data to a multiple of the bitrate `r`, defined as `r = b - c`,
121    /// where `b` is the width of the permutation, and `c` is the capacity.
122    fn pad_sha3(input: &[Boolean<E>], bitrate: usize) -> Vec<Vec<Boolean<E>>> {
123        debug_assert!(bitrate > 1, "The bitrate must be greater than 1");
124
125        // Resize the input to a multiple of 8.
126        let mut padded_input = input.to_vec();
127        padded_input.resize((input.len() + 7) / 8 * 8, Boolean::constant(false));
128
129        // Step 1: Append the "0x06" byte to the message.
130        padded_input.push(Boolean::constant(false));
131        padded_input.push(Boolean::constant(true));
132        padded_input.push(Boolean::constant(true));
133        padded_input.push(Boolean::constant(false));
134
135        // Step 2: Append "0" bits until the length of the message is congruent to r-1 mod r.
136        while (padded_input.len() % bitrate) != (bitrate - 1) {
137            padded_input.push(Boolean::constant(false));
138        }
139
140        // Step 3: Append the bit "1" to the message.
141        padded_input.push(Boolean::constant(true));
142
143        // Construct the padded blocks.
144        let mut result = Vec::new();
145        for block in padded_input.chunks(bitrate) {
146            result.push(block.to_vec());
147        }
148        result
149    }
150
151    /// The permutation `f` is a function that takes a fixed-length input and produces a fixed-length output,
152    /// defined as `f = Keccak-f[b]`, where `b := 25 * 2^l` is the width of the permutation,
153    /// and `l` is the log width of the permutation.
154    ///
155    /// The round function `Rnd` is applied `12 + 2l` times, where `l` is the log width of the permutation.
156    fn permutation_f<const WIDTH: usize, const NUM_ROUNDS: usize>(
157        input: Vec<Boolean<E>>,
158        round_constants: &[U64<E>],
159        rotl: &[usize],
160    ) -> Vec<Boolean<E>> {
161        debug_assert_eq!(input.len(), WIDTH, "The input vector must have {WIDTH} bits");
162        debug_assert_eq!(
163            round_constants.len(),
164            NUM_ROUNDS,
165            "The round constants vector must have {NUM_ROUNDS} elements"
166        );
167
168        // Partition the input into 64-bit chunks.
169        let mut a = input.chunks(64).map(U64::from_bits_le).collect::<Vec<_>>();
170        // Permute the input.
171        for round_constant in round_constants.iter().take(NUM_ROUNDS) {
172            a = Self::round(a, round_constant, rotl);
173        }
174        // Return the permuted input.
175        let mut bits = Vec::with_capacity(input.len());
176        a.iter().for_each(|e| e.write_bits_le(&mut bits));
177        bits
178    }
179
180    /// The round function `Rnd` is defined as follows:
181    /// ```text
182    /// Rnd = ι ◦ χ ◦ π ◦ ρ ◦ θ
183    /// ```
184    /// where `◦` denotes function composition.
185    fn round(a: Vec<U64<E>>, round_constant: &U64<E>, rotl: &[usize]) -> Vec<U64<E>> {
186        debug_assert_eq!(a.len(), MODULO * MODULO, "The input vector 'a' must have {} elements", MODULO * MODULO);
187
188        /* The first part of Algorithm 1, θ:
189         *
190         * for x = 0 to 4 do
191         *   C[x] = a[x, 0]
192         *   for y = 1 to 4 do
193         *     C[x] = C[x] ⊕ a[x, y]
194         *   end for
195         * end for
196         */
197        let mut c = Vec::with_capacity(MODULO);
198        for x in 0..MODULO {
199            c.push(&a[x] ^ &a[x + MODULO] ^ &a[x + (2 * MODULO)] ^ &a[x + (3 * MODULO)] ^ &a[x + (4 * MODULO)]);
200        }
201
202        /* The second part of Algorithm 1, θ:
203         *
204         * for x = 0 to 4 do
205         *   D[x] = C[x−1] ⊕ ROT(C[x+1],1)
206         *   for y = 0 to 4 do
207         *     A[x, y] = a[x, y] ⊕ D[x]
208         *   end for
209         * end for
210         */
211        let mut d = Vec::with_capacity(MODULO);
212        for x in 0..MODULO {
213            d.push(&c[(x + 4) % MODULO] ^ Self::rotate_left(&c[(x + 1) % MODULO], 63));
214        }
215        let mut a_1 = Vec::with_capacity(MODULO * MODULO);
216        for y in 0..MODULO {
217            for x in 0..MODULO {
218                a_1.push(&a[x + (y * MODULO)] ^ &d[x]);
219            }
220        }
221
222        /* Algorithm 3, π:
223         *
224         * for x = 0 to 4 do
225         *   for y = 0 to 4 do
226         *     (X, Y) = (y, (2*x + 3*y) mod 5)
227         *     A[X, Y] = a[x, y]
228         *   end for
229         * end for
230         *
231         * Algorithm 2, ρ:
232         *
233         * A[0, 0] = a[0, 0]
234         * (x, y) = (1, 0)
235         * for t = 0 to 23 do
236         *   A[x, y] = ROT(a[x, y], (t + 1)(t + 2)/2)
237         *   (x, y) = (y, (2*x + 3*y) mod 5)
238         * end for
239         */
240        let mut a_2 = a_1.clone();
241        for y in 0..MODULO {
242            for x in 0..MODULO {
243                // This step combines the π and ρ steps into one.
244                a_2[y + ((((2 * x) + (3 * y)) % MODULO) * MODULO)] =
245                    Self::rotate_left(&a_1[x + (y * MODULO)], rotl[x + (y * MODULO)]);
246            }
247        }
248
249        /* Algorithm 4, χ:
250         *
251         * for y = 0 to 4 do
252         *   for x = 0 to 4 do
253         *     A[x, y] = a[x, y] ⊕ ((¬a[x+1, y]) ∧ a[x+2, y])
254         *   end for
255         * end for
256         */
257        let mut a_3 = Vec::with_capacity(MODULO * MODULO);
258        for y in 0..MODULO {
259            for x in 0..MODULO {
260                let a = &a_2[x + (y * MODULO)];
261                let b = &a_2[((x + 1) % MODULO) + (y * MODULO)];
262                let c = &a_2[((x + 2) % MODULO) + (y * MODULO)];
263                a_3.push(a ^ ((!b) & c));
264            }
265        }
266
267        /* ι:
268         *
269         * A[0, 0] = A[0, 0] ⊕ RC
270         */
271        a_3[0] = &a_3[0] ^ round_constant;
272        a_3
273    }
274
275    /// Performs a rotate left operation on the given `u64` value.
276    fn rotate_left(value: &U64<E>, n: usize) -> U64<E> {
277        // Perform the rotation.
278        let mut bits_le = value.to_bits_le();
279        bits_le.rotate_left(n);
280        // Return the rotated value.
281        U64::from_bits_le(&bits_le)
282    }
283}
284
#[cfg(all(test, feature = "console"))]
mod tests {
    use super::*;
    use console::Rng;
    use snarkvm_circuit_types::environment::Circuit;

    /// How many random preimages `check_hash` hashes for each input length.
    const ITERATIONS: usize = 3;

    /// Input lengths (in bits) hashed in `Mode::Constant`, where no variables or constraints are expected.
    const CONSTANT_INPUT_SIZES: [usize; 19] =
        [1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256, 511, 512, 513, 1023, 1024, 1025];

    /// `(input length in bits, expected private variables = expected constraints)` for Keccak-256
    /// when the input bits are public or private.
    const VARIABLE_COSTS: [(usize, u64); 15] = [
        (1, 138157),
        (2, 139108),
        (3, 139741),
        (4, 140318),
        (5, 140879),
        (6, 141350),
        (7, 141787),
        (8, 142132),
        (16, 144173),
        (32, 145394),
        (64, 146650),
        (128, 149248),
        (256, 150848),
        (512, 151424),
        (1024, 152448),
    ];

    /// Asserts that the console hasher `$console` and the circuit hasher `$circuit` agree on random
    /// preimages of assorted lengths. Both expressions are re-evaluated for every length, so each
    /// length gets fresh hashers after the previous `Circuit::reset()`.
    macro_rules! check_equivalence {
        ($console:expr, $circuit:expr) => {
            use console::Hash as H;

            let rng = &mut TestRng::default();

            // Fixed lengths followed by five random ones; all are drawn before any preimage bits.
            let input_sizes: Vec<_> = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 32, 64, 128, 256, 512, 1024]
                .into_iter()
                .chain((0..5).map(|_| rng.gen_range(1..1024)))
                .collect();

            for num_inputs in input_sizes {
                println!("Checking equivalence for {num_inputs} inputs");

                let preimage: Vec<bool> = (0..num_inputs).map(|_| Uniform::rand(rng)).collect();
                let circuit_preimage: Vec<_> =
                    preimage.iter().map(|&bit| Boolean::<Circuit>::new(Mode::Private, bit)).collect();

                let expected = $console.hash(&preimage).expect("Failed to hash console input");
                let candidate = $circuit.hash(&circuit_preimage);
                assert_eq!(expected, candidate.eject_value());

                Circuit::reset();
            }
        };
    }

    /// Hashes `ITERATIONS` random `num_inputs`-bit preimages with Keccak-256 in `mode`, checking the
    /// digest against the console implementation and the scope's variable and constraint counts.
    fn check_hash(
        mode: Mode,
        num_inputs: usize,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
        rng: &mut TestRng,
    ) {
        use console::Hash as H;

        let native = console::Keccak256::default();
        let keccak = Keccak256::<Circuit>::new();

        for i in 0..ITERATIONS {
            let preimage: Vec<bool> = (0..num_inputs).map(|_| Uniform::rand(rng)).collect();
            let circuit_preimage: Vec<_> = preimage.iter().map(|&bit| Boolean::<Circuit>::new(mode, bit)).collect();

            let expected = native.hash(&preimage).expect("Failed to hash native input");

            Circuit::scope(format!("Keccak {mode} {i}"), || {
                let candidate = keccak.hash(&circuit_preimage);
                assert_eq!(expected, candidate.eject_value());
                let case = format!("(mode = {mode}, num_inputs = {num_inputs})");
                assert_scope!(case, num_constants, num_public, num_private, num_constraints);
            });
            Circuit::reset();
        }
    }

    #[test]
    fn test_keccak_256_hash_constant() {
        let mut rng = TestRng::default();
        for num_inputs in CONSTANT_INPUT_SIZES {
            check_hash(Mode::Constant, num_inputs, 0, 0, 0, 0, &mut rng);
        }
    }

    #[test]
    fn test_keccak_256_hash_public() {
        let mut rng = TestRng::default();
        for (num_inputs, cost) in VARIABLE_COSTS {
            check_hash(Mode::Public, num_inputs, 0, 0, cost, cost, &mut rng);
        }
    }

    #[test]
    fn test_keccak_256_hash_private() {
        let mut rng = TestRng::default();
        for (num_inputs, cost) in VARIABLE_COSTS {
            check_hash(Mode::Private, num_inputs, 0, 0, cost, cost, &mut rng);
        }
    }

    #[test]
    fn test_keccak_224_equivalence() {
        check_equivalence!(console::Keccak224::default(), Keccak224::<Circuit>::new());
    }

    #[test]
    fn test_keccak_256_equivalence() {
        check_equivalence!(console::Keccak256::default(), Keccak256::<Circuit>::new());
    }

    #[test]
    fn test_keccak_384_equivalence() {
        check_equivalence!(console::Keccak384::default(), Keccak384::<Circuit>::new());
    }

    #[test]
    fn test_keccak_512_equivalence() {
        check_equivalence!(console::Keccak512::default(), Keccak512::<Circuit>::new());
    }

    #[test]
    fn test_sha3_224_equivalence() {
        check_equivalence!(console::Sha3_224::default(), Sha3_224::<Circuit>::new());
    }

    #[test]
    fn test_sha3_256_equivalence() {
        check_equivalence!(console::Sha3_256::default(), Sha3_256::<Circuit>::new());
    }

    #[test]
    fn test_sha3_384_equivalence() {
        check_equivalence!(console::Sha3_384::default(), Sha3_384::<Circuit>::new());
    }

    #[test]
    fn test_sha3_512_equivalence() {
        check_equivalence!(console::Sha3_512::default(), Sha3_512::<Circuit>::new());
    }
}