solana_sdk/
epoch_rewards_hasher.rs

1use {
2    siphasher::sip::SipHasher13,
3    solana_sdk::{hash::Hash, pubkey::Pubkey},
4    std::hash::Hasher,
5};
6
7#[derive(Debug, Clone)]
8pub struct EpochRewardsHasher {
9    hasher: SipHasher13,
10    partitions: usize,
11}
12
13impl EpochRewardsHasher {
14    /// Use SipHasher13 keyed on the `seed` for calculating epoch reward partition
15    pub fn new(partitions: usize, seed: &Hash) -> Self {
16        let mut hasher = SipHasher13::new();
17        hasher.write(seed.as_ref());
18        Self { hasher, partitions }
19    }
20
21    /// Return partition index (0..partitions) by hashing `address` with the `hasher`
22    pub fn hash_address_to_partition(self, address: &Pubkey) -> usize {
23        let Self {
24            mut hasher,
25            partitions,
26        } = self;
27        hasher.write(address.as_ref());
28        let hash64 = hasher.finish();
29
30        hash_to_partition(hash64, partitions)
31    }
32}
33
/// Map a 64-bit `hash` to a partition index in `0..partitions`.
///
/// This is not a modulo reduction. The hash is scaled into the target range as
/// `(hash * partitions) / 2^64`, i.e. `(rand_int * DESIRED_RANGE_MAX) / (RAND_MAX + 1)`,
/// computed exactly in `u128`.
///
/// Each partition therefore owns one contiguous run of hash values. Partition
/// `k` owns `ceil(k * 2^64 / partitions)..ceil((k + 1) * 2^64 / partitions)`.
///
/// When `partitions` divides 2^64, every partition receives exactly the same
/// number of hash values. Otherwise partition sizes differ by at most one value.
///
/// If `partitions` is 0 the result is 0; there is no valid index in that case.
// The multiplication cannot overflow: both factors are at most `u64::MAX`
// (`usize` is at most 64 bits on supported targets), so the product is below
// 2^128. The shift amount is the constant 64, which is below 128.
#[allow(clippy::arithmetic_side_effects)]
fn hash_to_partition(hash: u64, partitions: usize) -> usize {
    let product = u128::from(hash) * partitions as u128;
    // Dividing by 2^64 means keeping the high 64 bits. The quotient is strictly
    // less than `partitions` (or 0 when `partitions` is 0), so narrowing back to
    // `usize` never truncates.
    (product >> u64::BITS) as usize
}
44
#[cfg(test)]
mod tests {
    #![allow(clippy::arithmetic_side_effects)]
    use {super::*, std::ops::RangeInclusive};

    #[test]
    fn test_get_equal_partition_range() {
        // With 2 partitions the inclusive ranges are 0..=(MAX / 2) and
        // (MAX / 2 + 1)..=MAX, where MAX is u64::MAX. Inclusive bounds are easy
        // to get wrong, so pin both ends of both ranges explicitly.
        let range = get_equal_partition_range(0, 2);
        assert_eq!(*range.start(), 0);
        assert_eq!(*range.end(), u64::MAX / 2);
        let range = get_equal_partition_range(1, 2);
        assert_eq!(*range.start(), u64::MAX / 2 + 1);
        assert_eq!(*range.end(), u64::MAX);
    }

    #[test]
    fn test_hash_to_partitions() {
        let partitions = 16;
        assert_eq!(hash_to_partition(0, partitions), 0);
        // u64::MAX / 16 == 2^60 - 1 is the last hash in partition 0.
        assert_eq!(hash_to_partition(u64::MAX / 16, partitions), 0);
        // 2^60 is the first hash in partition 1.
        assert_eq!(hash_to_partition(u64::MAX / 16 + 1, partitions), 1);
        assert_eq!(hash_to_partition(u64::MAX / 16 * 2, partitions), 1);
        // u64::MAX / 16 * 2 + 1 == 2^61 - 1 is the last hash in partition 1 ...
        assert_eq!(hash_to_partition(u64::MAX / 16 * 2 + 1, partitions), 1);
        // ... and 2^61 is the first hash in partition 2.
        assert_eq!(hash_to_partition(u64::MAX / 16 * 2 + 2, partitions), 2);
        assert_eq!(hash_to_partition(u64::MAX - 1, partitions), partitions - 1);
        assert_eq!(hash_to_partition(u64::MAX, partitions), partitions - 1);
    }

    #[test]
    fn test_hash_to_partitions_single_partition() {
        // With a single partition every hash must land in partition 0, and that
        // partition must cover the whole u64 range.
        for hash in [0, 1, u64::MAX / 2, u64::MAX / 2 + 1, u64::MAX - 1, u64::MAX] {
            assert_eq!(hash_to_partition(hash, 1), 0);
        }
        test_partitions(0, 1);
    }

    /// Check the boundaries of `partition`'s range (as computed by
    /// `get_equal_partition_range`).
    ///
    /// Both ends of the range must map to `partition` under `hash_to_partition`.
    /// The hashes just outside the range must map to the neighbouring
    /// partitions; at the outer edges, the range must reach 0 and `u64::MAX`.
    /// `partition` is clamped to the last valid index, `partitions - 1`.
    fn test_partitions(partition: usize, partitions: usize) {
        let partition = partition.min(partitions - 1);
        let range = get_equal_partition_range(partition, partitions);
        // beginning and end of this partition
        assert_eq!(hash_to_partition(*range.start(), partitions), partition);
        assert_eq!(hash_to_partition(*range.end(), partitions), partition);
        if partition < partitions - 1 {
            // first index in next partition
            assert_eq!(
                hash_to_partition(*range.end() + 1, partitions),
                partition + 1
            );
        } else {
            assert_eq!(*range.end(), u64::MAX);
        }
        if partition > 0 {
            // last index in previous partition
            assert_eq!(
                hash_to_partition(*range.start() - 1, partitions),
                partition - 1
            );
        } else {
            assert_eq!(*range.start(), 0);
        }
    }

    #[test]
    fn test_hash_to_partitions_equal_ranges() {
        for partitions in [2, 4, 8, 16, 4096] {
            assert_eq!(hash_to_partition(0, partitions), 0);
            for partition in [0, 1, 2, partitions - 1] {
                test_partitions(partition, partitions);
            }

            // Powers of two divide 2^64 evenly, so every range has the same length.
            let range = get_equal_partition_range(0, partitions);
            for partition in 1..partitions {
                let this_range = get_equal_partition_range(partition, partitions);
                assert_eq!(
                    this_range.end() - this_range.start(),
                    range.end() - range.start()
                );
            }
        }
        // Verify partition counts that do not divide 2^64 evenly. Partition
        // sizes then differ by at most 1 from any other partition.
        for partitions in [3, 19, 1019, 4095] {
            for partition in [0, 1, 2, partitions - 1] {
                test_partitions(partition, partitions);
            }
            let expected_len_of_partition =
                ((u128::from(u64::MAX) + 1) / partitions as u128) as u64;
            for partition in 0..partitions {
                let this_range = get_equal_partition_range(partition, partitions);
                // `len` is `end - start`, i.e. one less than the number of hashes in
                // the partition. Sizes are floor(2^64 / partitions) or one more, so
                // `len` equals `expected_len_of_partition` or is one less than it.
                let len = this_range.end() - this_range.start();
                assert!(
                    len == expected_len_of_partition || len + 1 == expected_len_of_partition,
                    "{}, {}, {}, {}",
                    expected_len_of_partition,
                    len,
                    partition,
                    partitions
                );
            }
        }
    }

    /// Return the inclusive range `start..=end` of `u64` hash values that belong
    /// to `partition` when all 2^64 values (u64::MAX + 1) are split into
    /// `partitions` contiguous ranges.
    ///
    /// All ranges have equal length when 2^64 is evenly divisible by
    /// `partitions`. Otherwise the lengths differ by at most one.
    fn get_equal_partition_range(partition: usize, partitions: usize) -> RangeInclusive<u64> {
        let max_inclusive = u128::from(u64::MAX);
        let max_plus_1 = max_inclusive + 1;
        let partition = partition as u128;
        let partitions = partitions as u128;
        // Floor of partition * 2^64 / partitions. The block below turns it into
        // the ceiling, which is the first hash mapping to `partition`.
        let mut start = max_plus_1 * partition / partitions;
        if partition > 0 && start * partitions / max_plus_1 == partition - 1 {
            // `partitions` does not divide partition * 2^64 evenly, so the floor
            // landed on the last hash of the previous partition; step past it.
            start += 1;
        }

        let mut end_inclusive = start + max_plus_1 / partitions - 1;
        if partition < partitions.saturating_sub(1) {
            let next = end_inclusive + 1;
            if next * partitions / max_plus_1 == partition {
                // This partition holds one more hash than floor(2^64 / partitions),
                // so its inclusive end is one further along.
                end_inclusive += 1;
            }
        } else {
            // The last partition always extends to u64::MAX.
            end_inclusive = max_inclusive;
        }
        RangeInclusive::new(start as u64, end_inclusive as u64)
    }

    /// Make sure that each call to `hash_address_to_partition` starts from the
    /// initial seeded state, and that `clone` correctly copies that state.
    #[test]
    fn test_hasher_copy() {
        let seed = Hash::new_unique();
        let partitions = 10;
        let hasher = EpochRewardsHasher::new(partitions, &seed);

        let pk = Pubkey::new_unique();

        let b1 = hasher.clone().hash_address_to_partition(&pk);
        let b2 = hasher.hash_address_to_partition(&pk);
        assert_eq!(b1, b2);

        // b1 must equal the result of hashing the seed bytes followed by the
        // pubkey bytes with a fresh SipHasher13, i.e. the seed is part of the
        // hashed input.
        let mut hasher = SipHasher13::new();
        hasher.write(seed.as_ref());
        hasher.write(pk.as_ref());
        let partition = hash_to_partition(hasher.finish(), partitions);
        assert_eq!(partition, b1);
    }
}