use super::*;
impl<E: Environment, const TYPE: u8, const VARIANT: usize> Hash for Keccak<E, TYPE, VARIANT> {
    type Input = Boolean<E>;
    type Output = Vec<Boolean<E>>;

    /// Hashes a bit string with the Keccak sponge construction and returns a `VARIANT`-bit digest.
    ///
    /// `TYPE` selects the padding rule (`0` = original Keccak, `1` = SHA-3). The capacity is
    /// `2 * VARIANT` bits and the rate is whatever remains of the `PERMUTATION_WIDTH`-bit state.
    ///
    /// Halts the circuit environment if `input` is empty.
    #[inline]
    fn hash(&self, input: &[Self::Input]) -> Self::Output {
        // rate = width - capacity, with capacity = 2 * digest length.
        let rate = PERMUTATION_WIDTH - 2 * VARIANT;
        debug_assert!(rate < PERMUTATION_WIDTH, "The bitrate must be less than the permutation width");
        debug_assert!(rate % 8 == 0, "The bitrate must be a multiple of 8");

        if input.is_empty() {
            E::halt("The input to the hash function must not be empty")
        }

        // The sponge starts from the all-zero state.
        let mut state = vec![Boolean::constant(false); PERMUTATION_WIDTH];

        let blocks = match TYPE {
            0 => Self::pad_keccak(input, rate),
            1 => Self::pad_sha3(input, rate),
            2.. => unreachable!("Invalid Keccak type"),
        };

        // Absorbing phase: XOR each rate-sized block into the leading bits of the state, then permute.
        for block in blocks {
            for (lane_bit, bit) in state.iter_mut().zip(block) {
                *lane_bit = &*lane_bit ^ &bit;
            }
            state = Self::permutation_f::<PERMUTATION_WIDTH, NUM_ROUNDS>(state, &self.round_constants, &self.rotl);
        }

        // Squeezing phase: emit `rate` bits at a time until at least `VARIANT` bits are available.
        let mut digest = state[..rate].to_vec();
        while digest.len() < VARIANT {
            state = Self::permutation_f::<PERMUTATION_WIDTH, NUM_ROUNDS>(state, &self.round_constants, &self.rotl);
            digest.extend_from_slice(&state[..rate]);
        }

        // Keep exactly the digest length.
        digest.truncate(VARIANT);
        digest
    }
}
impl<E: Environment, const TYPE: u8, const VARIANT: usize> Keccak<E, TYPE, VARIANT> {
    /// Pads `input` with the original Keccak multi-rate padding and splits it into `bitrate`-bit blocks.
    ///
    /// The input is zero-extended to a whole number of bytes. Then the `pad10*1` rule is applied:
    /// a `1` bit, as many `0` bits as needed, and a final `1` bit. The result is an exact multiple
    /// of `bitrate`.
    ///
    /// Reading each byte least-significant bit first, the suffix is
    /// `0x01 || 0x00… || 0x80`, or `0x81` when only one byte of the block is left.
    fn pad_keccak(input: &[Boolean<E>], bitrate: usize) -> Vec<Vec<Boolean<E>>> {
        debug_assert!(bitrate > 0, "The bitrate must be positive");

        let mut padded_input = Self::pad_to_byte_boundary(input);
        // Keccak has no domain-separation suffix, so `pad10*1` starts immediately with its leading `1`.
        padded_input.push(Boolean::constant(true));
        Self::pad_multi_rate_tail(padded_input, bitrate)
    }

    /// Pads `input` with the SHA-3 padding and splits it into `bitrate`-bit blocks.
    ///
    /// Identical to [`Self::pad_keccak`] except that the SHA-3 domain-separation suffix `01` is
    /// inserted before `pad10*1`. Reading each byte least-significant bit first, the suffix is
    /// `0x06 || 0x00… || 0x80`.
    fn pad_sha3(input: &[Boolean<E>], bitrate: usize) -> Vec<Vec<Boolean<E>>> {
        debug_assert!(bitrate > 1, "The bitrate must be greater than 1");

        let mut padded_input = Self::pad_to_byte_boundary(input);
        // The bits `0, 1` (domain suffix), `1` (start of pad10*1), `0` (first filler bit) form
        // the low nibble of the byte 0x06. Since the byte-aligned length is a multiple of 8,
        // the extra `0` can never overshoot the `bitrate - 1` target below.
        for bit in [false, true, true, false] {
            padded_input.push(Boolean::constant(bit));
        }
        Self::pad_multi_rate_tail(padded_input, bitrate)
    }

    /// Returns a copy of `input` zero-extended to the next multiple of 8 bits (a whole number of bytes).
    fn pad_to_byte_boundary(input: &[Boolean<E>]) -> Vec<Boolean<E>> {
        let mut bits = input.to_vec();
        bits.resize((input.len() + 7) / 8 * 8, Boolean::constant(false));
        bits
    }

    /// Completes the `pad10*1` rule and splits the padded bits into `bitrate`-bit blocks.
    ///
    /// Appends `0` bits until the length is congruent to `bitrate - 1` modulo `bitrate`, then a
    /// final `1` bit, so every returned block has exactly `bitrate` bits.
    fn pad_multi_rate_tail(mut padded_input: Vec<Boolean<E>>, bitrate: usize) -> Vec<Vec<Boolean<E>>> {
        while (padded_input.len() % bitrate) != (bitrate - 1) {
            padded_input.push(Boolean::constant(false));
        }
        padded_input.push(Boolean::constant(true));
        padded_input.chunks(bitrate).map(|block| block.to_vec()).collect()
    }

    /// Applies the Keccak-f permutation to a `WIDTH`-bit state and returns the permuted state.
    ///
    /// Packing: bit `i` of `input` becomes bit `i % 64` (little-endian) of lane `i / 64`.
    /// One [`Self::round`] is applied per round constant, and the lanes are then unpacked back
    /// into bits in the same order. `rotl` is forwarded unchanged to every round.
    fn permutation_f<const WIDTH: usize, const NUM_ROUNDS: usize>(
        input: Vec<Boolean<E>>,
        round_constants: &[U64<E>],
        rotl: &[usize],
    ) -> Vec<Boolean<E>> {
        debug_assert_eq!(input.len(), WIDTH, "The input vector must have {WIDTH} bits");
        debug_assert_eq!(
            round_constants.len(),
            NUM_ROUNDS,
            "The round constants vector must have {NUM_ROUNDS} elements"
        );

        // Pack the bit-level state into 64-bit lanes.
        let mut a = input.chunks(64).map(U64::from_bits_le).collect::<Vec<_>>();
        // Run one round per round constant (at most NUM_ROUNDS).
        for round_constant in round_constants.iter().take(NUM_ROUNDS) {
            a = Self::round(a, round_constant, rotl);
        }
        // Unpack the lanes back into a flat little-endian bit vector.
        let mut bits = Vec::with_capacity(input.len());
        a.iter().for_each(|e| e.write_bits_le(&mut bits));
        bits
    }

    /// Applies one Keccak-f round to the lane array `a` and returns the new lane array.
    ///
    /// `a` holds `MODULO * MODULO` 64-bit lanes, and lane `(x, y)` is stored at index
    /// `x + y * MODULO`. MODULO is presumably 5; the `(x + 4) % MODULO` "previous column" trick
    /// below relies on that. The round is ι ∘ χ ∘ π ∘ ρ ∘ θ:
    /// - θ: every lane is XORed with the parity of the previous column and the 1-bit-rotated
    ///   parity of the next column;
    /// - ρ + π (fused): lane `(x, y)` is rotated via [`Self::rotate_left`] by `rotl[x + y * MODULO]`
    ///   and moved to position `(y, 2x + 3y)`;
    /// - χ: `out[x, y] = a[x, y] ^ (!a[x + 1, y] & a[x + 2, y])` along each row;
    /// - ι: `round_constant` is XORed into lane `(0, 0)`.
    fn round(a: Vec<U64<E>>, round_constant: &U64<E>, rotl: &[usize]) -> Vec<U64<E>> {
        debug_assert_eq!(a.len(), MODULO * MODULO, "The input vector 'a' must have {} elements", MODULO * MODULO);

        // θ, step 1: column parities c[x] = a[x, 0] ^ a[x, 1] ^ ... ^ a[x, 4].
        let mut c = Vec::with_capacity(MODULO);
        for x in 0..MODULO {
            c.push(&a[x] ^ &a[x + MODULO] ^ &a[x + (2 * MODULO)] ^ &a[x + (3 * MODULO)] ^ &a[x + (4 * MODULO)]);
        }
        // θ, step 2: d[x] = c[x - 1] ^ ROT(c[x + 1], 1).
        // `rotate_left(.., 63)` on little-endian bits is an integer rotation left by 1.
        let mut d = Vec::with_capacity(MODULO);
        for x in 0..MODULO {
            d.push(&c[(x + 4) % MODULO] ^ Self::rotate_left(&c[(x + 1) % MODULO], 63));
        }
        // θ, step 3: XOR d[x] into every lane of column x.
        let mut a_1 = Vec::with_capacity(MODULO * MODULO);
        for y in 0..MODULO {
            for x in 0..MODULO {
                a_1.push(&a[x + (y * MODULO)] ^ &d[x]);
            }
        }
        // ρ + π: the clone only provides a correctly sized vector. (x, y) -> (y, 2x + 3y) is a
        // bijection on the grid when MODULO is 5, so every entry is overwritten.
        let mut a_2 = a_1.clone();
        for y in 0..MODULO {
            for x in 0..MODULO {
                a_2[y + ((((2 * x) + (3 * y)) % MODULO) * MODULO)] =
                    Self::rotate_left(&a_1[x + (y * MODULO)], rotl[x + (y * MODULO)]);
            }
        }
        // χ: the only non-linear step, combining each lane with its two right-hand row neighbours.
        let mut a_3 = Vec::with_capacity(MODULO * MODULO);
        for y in 0..MODULO {
            for x in 0..MODULO {
                let a = &a_2[x + (y * MODULO)];
                let b = &a_2[((x + 1) % MODULO) + (y * MODULO)];
                let c = &a_2[((x + 2) % MODULO) + (y * MODULO)];
                a_3.push(a ^ ((!b) & c));
            }
        }
        // ι: break round symmetry by XORing the round constant into lane (0, 0).
        a_3[0] = &a_3[0] ^ round_constant;
        a_3
    }

    /// Rotates the little-endian bit vector of `value` left by `n` positions.
    ///
    /// Note the direction: `slice::rotate_left(n)` moves bit `i + n` to position `i`. On
    /// little-endian bits that is an *integer* rotate-right by `n`, i.e. an integer rotate-left
    /// by `(64 - n) % 64`. Callers must pass complementary offsets; for example, θ uses `63` to
    /// rotate left by 1. Presumably the `rotl` table is built with the same convention
    /// (TODO confirm where it is constructed).
    ///
    /// # Panics
    /// Panics if `n` exceeds the bit width of `U64` (64), per `slice::rotate_left`.
    fn rotate_left(value: &U64<E>, n: usize) -> U64<E> {
        let mut bits_le = value.to_bits_le();
        bits_le.rotate_left(n);
        U64::from_bits_le(&bits_le)
    }
}
#[cfg(all(test, feature = "console"))]
mod tests {
    use super::*;
    use console::Rng;
    use snarkvm_circuit_types::environment::Circuit;

    const ITERATIONS: usize = 3;

    /// Input lengths (in bits) hashed with `Mode::Constant`; constant inputs must cost nothing.
    const CONSTANT_INPUT_LENGTHS: [usize; 19] =
        [1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256, 511, 512, 513, 1023, 1024, 1025];

    /// `(input length in bits, expected private variables)` for public and private inputs.
    /// In these modes the expected number of constraints equals the number of private variables.
    const NON_CONSTANT_COSTS: [(usize, u64); 15] = [
        (1, 138157),
        (2, 139108),
        (3, 139741),
        (4, 140318),
        (5, 140879),
        (6, 141350),
        (7, 141787),
        (8, 142132),
        (16, 144173),
        (32, 145394),
        (64, 146650),
        (128, 149248),
        (256, 150848),
        (512, 151424),
        (1024, 152448),
    ];

    // Hashes random private inputs of many lengths with both the console (native) and circuit
    // implementations and asserts that the digests agree.
    macro_rules! check_equivalence {
        ($console:expr, $circuit:expr) => {
            use console::Hash as H;

            let rng = &mut TestRng::default();

            // Fixed lengths first, then five random lengths in [1, 1024).
            let mut lengths = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 32, 64, 128, 256, 512, 1024];
            lengths.extend((0..5).map(|_| rng.gen_range(1..1024)));

            for num_inputs in lengths {
                println!("Checking equivalence for {num_inputs} inputs");

                let native_input = (0..num_inputs).map(|_| Uniform::rand(rng)).collect::<Vec<bool>>();
                let input = native_input.iter().map(|v| Boolean::<Circuit>::new(Mode::Private, *v)).collect::<Vec<_>>();

                let expected = $console.hash(&native_input).expect("Failed to hash console input");
                let candidate = $circuit.hash(&input);
                assert_eq!(expected, candidate.eject_value());

                Circuit::reset();
            }
        };
    }

    /// Hashes `ITERATIONS` random inputs of length `num_inputs` in `mode` with Keccak-256,
    /// checking the digest against the console implementation and the circuit cost against the
    /// expected counts.
    fn check_hash(
        mode: Mode,
        num_inputs: usize,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
        rng: &mut TestRng,
    ) {
        use console::Hash as H;

        let console_keccak = console::Keccak256::default();
        let circuit_keccak = Keccak256::<Circuit>::new();

        for i in 0..ITERATIONS {
            let native_input = (0..num_inputs).map(|_| Uniform::rand(rng)).collect::<Vec<bool>>();
            let input = native_input.iter().map(|v| Boolean::<Circuit>::new(mode, *v)).collect::<Vec<_>>();
            let expected = console_keccak.hash(&native_input).expect("Failed to hash native input");

            Circuit::scope(format!("Keccak {mode} {i}"), || {
                let candidate = circuit_keccak.hash(&input);
                assert_eq!(expected, candidate.eject_value());

                let case = format!("(mode = {mode}, num_inputs = {num_inputs})");
                assert_scope!(case, num_constants, num_public, num_private, num_constraints);
            });
            Circuit::reset();
        }
    }

    #[test]
    fn test_keccak_256_hash_constant() {
        let mut rng = TestRng::default();
        for num_inputs in CONSTANT_INPUT_LENGTHS {
            check_hash(Mode::Constant, num_inputs, 0, 0, 0, 0, &mut rng);
        }
    }

    #[test]
    fn test_keccak_256_hash_public() {
        let mut rng = TestRng::default();
        for (num_inputs, cost) in NON_CONSTANT_COSTS {
            check_hash(Mode::Public, num_inputs, 0, 0, cost, cost, &mut rng);
        }
    }

    #[test]
    fn test_keccak_256_hash_private() {
        let mut rng = TestRng::default();
        for (num_inputs, cost) in NON_CONSTANT_COSTS {
            check_hash(Mode::Private, num_inputs, 0, 0, cost, cost, &mut rng);
        }
    }

    #[test]
    fn test_keccak_224_equivalence() {
        check_equivalence!(console::Keccak224::default(), Keccak224::<Circuit>::new());
    }

    #[test]
    fn test_keccak_256_equivalence() {
        check_equivalence!(console::Keccak256::default(), Keccak256::<Circuit>::new());
    }

    #[test]
    fn test_keccak_384_equivalence() {
        check_equivalence!(console::Keccak384::default(), Keccak384::<Circuit>::new());
    }

    #[test]
    fn test_keccak_512_equivalence() {
        check_equivalence!(console::Keccak512::default(), Keccak512::<Circuit>::new());
    }

    #[test]
    fn test_sha3_224_equivalence() {
        check_equivalence!(console::Sha3_224::default(), Sha3_224::<Circuit>::new());
    }

    #[test]
    fn test_sha3_256_equivalence() {
        check_equivalence!(console::Sha3_256::default(), Sha3_256::<Circuit>::new());
    }

    #[test]
    fn test_sha3_384_equivalence() {
        check_equivalence!(console::Sha3_384::default(), Sha3_384::<Circuit>::new());
    }

    #[test]
    fn test_sha3_512_equivalence() {
        check_equivalence!(console::Sha3_512::default(), Sha3_512::<Circuit>::new());
    }
}