// twenty_first/util_types/sponge.rs

use std::fmt::Debug;
use std::iter;

use itertools::Itertools;
use num_traits::ConstOne;
use num_traits::ConstZero;

use crate::math::b_field_element::BFieldElement;

/// The number of [`BFieldElement`]s that [`Sponge::absorb`] consumes and that
/// [`Sponge::squeeze`] produces per call.
pub const RATE: usize = 10;

/// The hasher [`Domain`] differentiates between the modes of hashing.
///
/// The main purpose of declaring the domain is to prevent collisions between different types of
/// hashing: the way the hash function's internal state (e.g., a sponge state's capacity) is
/// initialized depends on the domain.
///
/// `Domain` is a fieldless enum, so it is cheap to copy and can be used as a key in hash-based
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Domain {
    /// The `VariableLength` domain is used for hashing objects that potentially serialize to more
    /// than [`RATE`] field elements.
    VariableLength,

    /// The `FixedLength` domain is used for hashing objects that always fit within [`RATE`] field
    /// elements, e.g., a pair of [`Digest`](crate::math::digest::Digest)s.
    FixedLength,
}

/// A [cryptographic sponge][sponge]. Should only be based on a cryptographic permutation, e.g.,
/// [`Tip5`][tip5].
///
/// [sponge]: https://keccak.team/files/CSF-0.1.pdf
/// [tip5]: crate::prelude::Tip5
pub trait Sponge: Clone + Debug + Default + Send + Sync {
    /// The number of field elements the sponge absorbs or squeezes at a time.
    ///
    /// NOTE(review): the method signatures below use the module-level [`RATE`], not `Self::RATE`
    /// (array lengths cannot depend on a trait's associated constant on stable Rust).
    /// Implementors are presumably expected to keep the two equal — confirm.
    const RATE: usize;

    /// Construct a fresh sponge. The initial state is defined by the implementor.
    fn init() -> Self;

    /// Absorb exactly [`RATE`] field elements into the sponge.
    fn absorb(&mut self, input: [BFieldElement; RATE]);

    /// Squeeze [`RATE`] field elements out of the sponge.
    fn squeeze(&mut self) -> [BFieldElement; RATE];

    /// Pad `input` and absorb all of it, [`RATE`] elements at a time.
    ///
    /// The input is padded with a single `BFieldElement::ONE` followed by as many
    /// `BFieldElement::ZERO`s as needed to reach the next multiple of [`RATE`]. Because the
    /// padding always contains the `1`, at least one block is absorbed (even for empty input),
    /// and inputs that differ only in trailing zeros are padded to different sequences.
    fn pad_and_absorb_all(&mut self, input: &[BFieldElement]) {
        // The padding `[1, 0, 0, …]` is at least one element long, so the padded length is the
        // smallest multiple of `RATE` strictly greater than `input.len()`.
        let padded_length = (input.len() + 1).next_multiple_of(RATE);
        let padding = iter::once(BFieldElement::ONE).chain(iter::repeat(BFieldElement::ZERO));
        let padded_input = input.iter().copied().chain(padding).take(padded_length);

        for chunk in padded_input.chunks(RATE).into_iter() {
            // Fill a stack-allocated block directly instead of collecting each chunk into a
            // heap-allocated `Vec` first. `zip_eq` panics if a chunk is not exactly `RATE`
            // long, which cannot happen because `padded_length` is a multiple of `RATE`.
            let mut block = [BFieldElement::ZERO; RATE];
            for (slot, element) in block.iter_mut().zip_eq(chunk) {
                *slot = element;
            }
            self.absorb(block);
        }
    }
}

#[cfg(test)]
mod tests {
    use std::ops::Mul;

    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::Rng;

    use super::*;
    use crate::math::digest::Digest;
    use crate::math::tip5::Tip5;
    use crate::math::x_field_element::EXTENSION_DEGREE;
    use crate::prelude::BFieldCodec;
    use crate::prelude::XFieldElement;

    /// Sanity-check the `BFieldCodec` encoding of `T`.
    ///
    /// The two given extremes must encode to distinct sequences of equal length. For two
    /// randomly sampled values, the encodings must agree exactly when the values do.
    fn encode_prop<T>(smallest: T, largest: T)
    where
        T: Eq + BFieldCodec,
        StandardUniform: Distribution<T>,
    {
        let (encoded_min, encoded_max) = (smallest.encode(), largest.encode());
        assert_ne!(encoded_min, encoded_max);
        assert_eq!(encoded_min.len(), encoded_max.len());

        let mut rng = rand::rng();
        let first: T = rng.random();
        let second: T = rng.random();

        if first == second {
            assert_eq!(first.encode(), second.encode());
        } else {
            assert_ne!(first.encode(), second.encode());
        }
    }

    #[test]
    fn to_sequence_test() {
        let max_bfe = BFieldElement::new(BFieldElement::MAX);
        let max_xfe = XFieldElement::new([max_bfe; EXTENSION_DEGREE]);
        let max_digest = Digest::new([max_bfe; Digest::LEN]);

        // primitive types
        encode_prop(false, true);
        encode_prop(0u32, u32::MAX);
        encode_prop(0u64, u64::MAX);

        // field elements and digests
        encode_prop(BFieldElement::ZERO, max_bfe);
        encode_prop(XFieldElement::ZERO, max_xfe);
        encode_prop(Digest::ALL_ZERO, max_digest);

        // wide integers
        encode_prop(0u128, u128::MAX);
    }

    /// Sample `num_indices` indices below `max` and check both the count and the bound.
    fn sample_indices_prop(max: u32, num_indices: usize) {
        let mut tip5 = Tip5::randomly_seeded();
        let sampled = tip5.sample_indices(max, num_indices);
        assert_eq!(num_indices, sampled.len());
        assert!(sampled.iter().all(|&idx| idx < max));
    }

    #[test]
    fn sample_indices_test() {
        [
            (2, 0),
            (4, 1),
            (8, 9),
            (16, 10),
            (32, 11),
            (64, 19),
            (128, 20),
            (256, 21),
            (512, 65),
        ]
        .into_iter()
        .for_each(|(upper_bound, count)| sample_indices_prop(upper_bound, count));
    }

    #[test]
    fn sample_scalars_test() {
        let mut tip5 = Tip5::randomly_seeded();
        let product = (0..=4)
            .map(|amount| {
                let scalars = tip5.sample_scalars(amount);
                assert_eq!(amount, scalars.len());
                scalars
                    .into_iter()
                    .fold(XFieldElement::ONE, XFieldElement::mul)
            })
            .fold(XFieldElement::ONE, XFieldElement::mul);

        // A zero product would require a zero scalar to have been sampled.
        assert_ne!(product, XFieldElement::ZERO); // false failure with prob ~2^{-192}
    }
}