solana_sdk/
epoch_rewards_hasher.rs1use {
2 siphasher::sip::SipHasher13,
3 solana_sdk::{hash::Hash, pubkey::Pubkey},
4 std::hash::Hasher,
5};
6
7#[derive(Debug, Clone)]
8pub struct EpochRewardsHasher {
9 hasher: SipHasher13,
10 partitions: usize,
11}
12
13impl EpochRewardsHasher {
14 pub fn new(partitions: usize, seed: &Hash) -> Self {
16 let mut hasher = SipHasher13::new();
17 hasher.write(seed.as_ref());
18 Self { hasher, partitions }
19 }
20
21 pub fn hash_address_to_partition(self, address: &Pubkey) -> usize {
23 let Self {
24 mut hasher,
25 partitions,
26 } = self;
27 hasher.write(address.as_ref());
28 let hash64 = hasher.finish();
29
30 hash_to_partition(hash64, partitions)
31 }
32}
33
/// Maps a 64-bit hash onto a partition index without modulo bias.
///
/// `hash` is treated as the fraction `hash / 2^64` of the `u64` range and scaled
/// by `partitions`, i.e. the result is `floor(hash * partitions / 2^64)`. Each
/// partition therefore owns one contiguous run of hash values. The result is in
/// `0..partitions` for `partitions > 0`, and is 0 when `partitions` is 0.
// A u64 times a usize (at most 64 bits wide) is below 2^128, so the u128
// multiplication can never overflow; clippy just cannot prove that.
#[allow(clippy::arithmetic_side_effects)]
fn hash_to_partition(hash: u64, partitions: usize) -> usize {
    let scaled = u128::from(hash) * partitions as u128;
    // Dividing by 2^64 is exactly "keep the high 64 bits".
    (scaled >> u64::BITS) as usize
}
44
#[cfg(test)]
mod tests {
    // Test helpers do u64/u128 arithmetic on values known not to overflow.
    #![allow(clippy::arithmetic_side_effects)]

    use {super::*, std::ops::RangeInclusive};

    #[test]
    fn test_get_equal_partition_range() {
        let half = u64::MAX / 2;

        let lower = get_equal_partition_range(0, 2);
        assert_eq!(*lower.start(), 0);
        assert_eq!(*lower.end(), half);

        let upper = get_equal_partition_range(1, 2);
        assert_eq!(*upper.start(), half + 1);
        assert_eq!(*upper.end(), u64::MAX);
    }

    #[test]
    fn test_hash_to_partitions() {
        let partitions = 16;
        let sixteenth = u64::MAX / 16;
        let last = partitions - 1;
        // (hash, expected partition): around the first boundaries and at the top.
        let cases = [
            (0, 0),
            (sixteenth, 0),
            (sixteenth + 1, 1),
            (sixteenth * 2, 1),
            (sixteenth * 2 + 1, 1),
            (u64::MAX - 1, last),
            (u64::MAX, last),
        ];
        for (hash, expected) in cases {
            assert_eq!(hash_to_partition(hash, partitions), expected);
        }
    }

    /// Checks that both ends of `partition`'s range map back to `partition`.
    ///
    /// Also checks that the hashes just outside the range map to the
    /// neighbouring partitions, or that the range reaches 0 / `u64::MAX`.
    /// `partition` is clamped to the last valid index first.
    fn test_partitions(partition: usize, partitions: usize) {
        let partition = partition.min(partitions - 1);
        let range = get_equal_partition_range(partition, partitions);
        let (first, last) = (*range.start(), *range.end());

        assert_eq!(hash_to_partition(first, partitions), partition);
        assert_eq!(hash_to_partition(last, partitions), partition);

        if partition + 1 < partitions {
            assert_eq!(hash_to_partition(last + 1, partitions), partition + 1);
        } else {
            assert_eq!(last, u64::MAX);
        }

        if partition == 0 {
            assert_eq!(first, 0);
        } else {
            assert_eq!(hash_to_partition(first - 1, partitions), partition - 1);
        }
    }

    #[test]
    fn test_hash_to_partitions_equal_ranges() {
        // Powers of two divide 2^64 exactly, so every range has the same length.
        for partitions in [2, 4, 8, 16, 4096] {
            assert_eq!(hash_to_partition(0, partitions), 0);
            for partition in [0, 1, 2, partitions - 1] {
                test_partitions(partition, partitions);
            }

            let first = get_equal_partition_range(0, partitions);
            let width = first.end() - first.start();
            for partition in 1..partitions {
                let range = get_equal_partition_range(partition, partitions);
                assert_eq!(range.end() - range.start(), width);
            }
        }

        // These counts leave a remainder, so range lengths may differ by one.
        for partitions in [3, 19, 1019, 4095] {
            for partition in [0, 1, 2, partitions - 1] {
                test_partitions(partition, partitions);
            }
            let expected_len_of_partition =
                ((u128::from(u64::MAX) + 1) / partitions as u128) as u64;
            for partition in 0..partitions {
                let range = get_equal_partition_range(partition, partitions);
                let len = range.end() - range.start();
                assert!(
                    len == expected_len_of_partition || len + 1 == expected_len_of_partition,
                    "{}, {}, {}, {}",
                    expected_len_of_partition,
                    len,
                    partition,
                    partitions
                );
            }
        }
    }

    /// Returns the inclusive range of hashes that `hash_to_partition` assigns
    /// to `partition` out of `partitions`.
    ///
    /// The range is derived from boundary arithmetic rather than by scanning,
    /// so the tests can cross-check `hash_to_partition` against it.
    fn get_equal_partition_range(partition: usize, partitions: usize) -> RangeInclusive<u64> {
        // 2^64: the number of distinct u64 hash values.
        let hash_space = u128::from(u64::MAX) + 1;
        let (p, n) = (partition as u128, partitions as u128);
        // Partition a hash lands in, using the same formula as hash_to_partition.
        let owner = |hash: u128| hash * n / hash_space;

        // Floor of the ideal boundary; if it still belongs to the previous
        // partition, the first hash of ours is the next one.
        let mut start = hash_space * p / n;
        if p > 0 && owner(start) == p - 1 {
            start += 1;
        }

        let end_inclusive = if p < n.saturating_sub(1) {
            let candidate = start + hash_space / n - 1;
            // The floor length may stop one short; extend if the next hash is ours.
            if owner(candidate + 1) == p {
                candidate + 1
            } else {
                candidate
            }
        } else {
            u128::from(u64::MAX)
        };
        RangeInclusive::new(start as u64, end_inclusive as u64)
    }

    #[test]
    fn test_hasher_copy() {
        let seed = Hash::new_unique();
        let partitions = 10;
        let hasher = EpochRewardsHasher::new(partitions, &seed);
        let pk = Pubkey::new_unique();

        // Hashing through a clone must leave the original's seeded state intact.
        let b1 = hasher.clone().hash_address_to_partition(&pk);
        let b2 = hasher.hash_address_to_partition(&pk);
        assert_eq!(b1, b2);

        // Reproduce the result by hand: seed bytes first, then address bytes.
        let mut reference = SipHasher13::new();
        reference.write(seed.as_ref());
        reference.write(pk.as_ref());
        assert_eq!(hash_to_partition(reference.finish(), partitions), b1);
    }
}