1use std::{cmp::Ordering, marker::PhantomData};
2
3use crate::{
4 engine::{Engine, GF_ORDER},
5 rate::{
6 DecoderWork, EncoderWork, HighRateDecoder, HighRateEncoder, LowRateDecoder, LowRateEncoder,
7 Rate, RateDecoder, RateEncoder,
8 },
9 DecoderResult, EncoderResult, Error,
10};
11
12fn use_high_rate(original_count: usize, recovery_count: usize) -> Result<bool, Error> {
16 if original_count > GF_ORDER || recovery_count > GF_ORDER {
17 return Err(Error::UnsupportedShardCount {
18 original_count,
19 recovery_count,
20 });
21 }
22
23 let original_count_pow2 = original_count.next_power_of_two();
24 let recovery_count_pow2 = recovery_count.next_power_of_two();
25
26 let smaller_pow2 = std::cmp::min(original_count_pow2, recovery_count_pow2);
27 let larger = std::cmp::max(original_count, recovery_count);
28
29 if original_count == 0 || recovery_count == 0 || smaller_pow2 + larger > GF_ORDER {
30 return Err(Error::UnsupportedShardCount {
31 original_count,
32 recovery_count,
33 });
34 }
35
36 match original_count_pow2.cmp(&recovery_count_pow2) {
37 Ordering::Less => {
38 Ok(false)
42 }
43
44 Ordering::Greater => {
45 Ok(true)
49 }
50
51 Ordering::Equal => {
52 if original_count <= recovery_count {
56 Ok(true)
58 } else {
59 Ok(false)
61 }
62 }
63 }
64}
65
/// Rate that supports any shard counts accepted by `use_high_rate` and
/// dispatches each encoder/decoder to either the high-rate or the low-rate
/// implementation depending on those counts.
///
/// The `PhantomData<E>` field only records the engine type; no engine is stored here.
pub struct DefaultRate<E: Engine>(PhantomData<E>);
71
72impl<E: Engine> Rate<E> for DefaultRate<E> {
73 type RateEncoder = DefaultRateEncoder<E>;
74 type RateDecoder = DefaultRateDecoder<E>;
75
76 fn supports(original_count: usize, recovery_count: usize) -> bool {
77 use_high_rate(original_count, recovery_count).is_ok()
78 }
79}
80
/// The concrete encoder currently wrapped by a [`DefaultRateEncoder`].
enum InnerEncoder<E: Engine> {
    /// High-rate encoder, used when `use_high_rate` returns `Ok(true)`.
    High(HighRateEncoder<E>),
    /// Low-rate encoder, used when `use_high_rate` returns `Ok(false)`.
    Low(LowRateEncoder<E>),

    /// Empty placeholder (the `Default` value). `reset` swaps it in with
    /// `std::mem::take` so it can consume the previous encoder by value. It stays
    /// in place if `reset` returns an error after that swap, and the other
    /// methods panic when they encounter it.
    None,
}
91
92impl<E: Engine> Default for InnerEncoder<E> {
93 fn default() -> Self {
94 InnerEncoder::None
95 }
96}
97
/// Encoder for [`DefaultRate`].
///
/// It wraps either a high-rate or a low-rate encoder, chosen by
/// `use_high_rate`, and forwards every call to it.
pub struct DefaultRateEncoder<E: Engine>(InnerEncoder<E>);
109
110impl<E: Engine> RateEncoder<E> for DefaultRateEncoder<E> {
111 type Rate = DefaultRate<E>;
112
113 fn add_original_shard<T: AsRef<[u8]>>(&mut self, original_shard: T) -> Result<(), Error> {
114 match &mut self.0 {
115 InnerEncoder::High(high) => high.add_original_shard(original_shard),
116 InnerEncoder::Low(low) => low.add_original_shard(original_shard),
117 InnerEncoder::None => unreachable!(),
118 }
119 }
120
121 fn encode(&mut self) -> Result<EncoderResult, Error> {
122 match &mut self.0 {
123 InnerEncoder::High(high) => high.encode(),
124 InnerEncoder::Low(low) => low.encode(),
125 InnerEncoder::None => unreachable!(),
126 }
127 }
128
129 fn into_parts(self) -> (E, EncoderWork) {
130 match self.0 {
131 InnerEncoder::High(high) => high.into_parts(),
132 InnerEncoder::Low(low) => low.into_parts(),
133 InnerEncoder::None => unreachable!(),
134 }
135 }
136
137 fn new(
138 original_count: usize,
139 recovery_count: usize,
140 shard_bytes: usize,
141 engine: E,
142 work: Option<EncoderWork>,
143 ) -> Result<Self, Error> {
144 let inner = if use_high_rate(original_count, recovery_count)? {
145 InnerEncoder::High(HighRateEncoder::new(
146 original_count,
147 recovery_count,
148 shard_bytes,
149 engine,
150 work,
151 )?)
152 } else {
153 InnerEncoder::Low(LowRateEncoder::new(
154 original_count,
155 recovery_count,
156 shard_bytes,
157 engine,
158 work,
159 )?)
160 };
161
162 Ok(Self(inner))
163 }
164
165 fn reset(
166 &mut self,
167 original_count: usize,
168 recovery_count: usize,
169 shard_bytes: usize,
170 ) -> Result<(), Error> {
171 let new_rate_is_high = use_high_rate(original_count, recovery_count)?;
172
173 self.0 = match std::mem::take(&mut self.0) {
174 InnerEncoder::High(mut high) => {
175 if new_rate_is_high {
176 high.reset(original_count, recovery_count, shard_bytes)?;
177 InnerEncoder::High(high)
178 } else {
179 let (engine, work) = high.into_parts();
180 InnerEncoder::Low(LowRateEncoder::new(
181 original_count,
182 recovery_count,
183 shard_bytes,
184 engine,
185 Some(work),
186 )?)
187 }
188 }
189
190 InnerEncoder::Low(mut low) => {
191 if new_rate_is_high {
192 let (engine, work) = low.into_parts();
193 InnerEncoder::High(HighRateEncoder::new(
194 original_count,
195 recovery_count,
196 shard_bytes,
197 engine,
198 Some(work),
199 )?)
200 } else {
201 low.reset(original_count, recovery_count, shard_bytes)?;
202 InnerEncoder::Low(low)
203 }
204 }
205
206 InnerEncoder::None => unreachable!(),
207 };
208
209 Ok(())
210 }
211}
212
/// The concrete decoder currently wrapped by a [`DefaultRateDecoder`].
enum InnerDecoder<E: Engine> {
    /// High-rate decoder, used when `use_high_rate` returns `Ok(true)`.
    High(HighRateDecoder<E>),
    /// Low-rate decoder, used when `use_high_rate` returns `Ok(false)`.
    Low(LowRateDecoder<E>),

    /// Empty placeholder (the `Default` value). `reset` swaps it in with
    /// `std::mem::take` so it can consume the previous decoder by value. It stays
    /// in place if `reset` returns an error after that swap, and the other
    /// methods panic when they encounter it.
    None,
}
223
224impl<E: Engine> Default for InnerDecoder<E> {
225 fn default() -> Self {
226 InnerDecoder::None
227 }
228}
229
/// Decoder for [`DefaultRate`].
///
/// It wraps either a high-rate or a low-rate decoder, chosen by
/// `use_high_rate`, and forwards every call to it.
pub struct DefaultRateDecoder<E: Engine>(InnerDecoder<E>);
241
242impl<E: Engine> RateDecoder<E> for DefaultRateDecoder<E> {
243 type Rate = DefaultRate<E>;
244
245 fn add_original_shard<T: AsRef<[u8]>>(
246 &mut self,
247 index: usize,
248 original_shard: T,
249 ) -> Result<(), Error> {
250 match &mut self.0 {
251 InnerDecoder::High(high) => high.add_original_shard(index, original_shard),
252 InnerDecoder::Low(low) => low.add_original_shard(index, original_shard),
253 InnerDecoder::None => unreachable!(),
254 }
255 }
256
257 fn add_recovery_shard<T: AsRef<[u8]>>(
258 &mut self,
259 index: usize,
260 recovery_shard: T,
261 ) -> Result<(), Error> {
262 match &mut self.0 {
263 InnerDecoder::High(high) => high.add_recovery_shard(index, recovery_shard),
264 InnerDecoder::Low(low) => low.add_recovery_shard(index, recovery_shard),
265 InnerDecoder::None => unreachable!(),
266 }
267 }
268
269 fn decode(&mut self) -> Result<DecoderResult, Error> {
270 match &mut self.0 {
271 InnerDecoder::High(high) => high.decode(),
272 InnerDecoder::Low(low) => low.decode(),
273 InnerDecoder::None => unreachable!(),
274 }
275 }
276
277 fn into_parts(self) -> (E, DecoderWork) {
278 match self.0 {
279 InnerDecoder::High(high) => high.into_parts(),
280 InnerDecoder::Low(low) => low.into_parts(),
281 InnerDecoder::None => unreachable!(),
282 }
283 }
284
285 fn new(
286 original_count: usize,
287 recovery_count: usize,
288 shard_bytes: usize,
289 engine: E,
290 work: Option<DecoderWork>,
291 ) -> Result<Self, Error> {
292 let inner = if use_high_rate(original_count, recovery_count)? {
293 InnerDecoder::High(HighRateDecoder::new(
294 original_count,
295 recovery_count,
296 shard_bytes,
297 engine,
298 work,
299 )?)
300 } else {
301 InnerDecoder::Low(LowRateDecoder::new(
302 original_count,
303 recovery_count,
304 shard_bytes,
305 engine,
306 work,
307 )?)
308 };
309
310 Ok(Self(inner))
311 }
312
313 fn reset(
314 &mut self,
315 original_count: usize,
316 recovery_count: usize,
317 shard_bytes: usize,
318 ) -> Result<(), Error> {
319 let new_rate_is_high = use_high_rate(original_count, recovery_count)?;
320
321 self.0 = match std::mem::take(&mut self.0) {
322 InnerDecoder::High(mut high) => {
323 if new_rate_is_high {
324 high.reset(original_count, recovery_count, shard_bytes)?;
325 InnerDecoder::High(high)
326 } else {
327 let (engine, work) = high.into_parts();
328 InnerDecoder::Low(LowRateDecoder::new(
329 original_count,
330 recovery_count,
331 shard_bytes,
332 engine,
333 Some(work),
334 )?)
335 }
336 }
337
338 InnerDecoder::Low(mut low) => {
339 if new_rate_is_high {
340 let (engine, work) = low.into_parts();
341 InnerDecoder::High(HighRateDecoder::new(
342 original_count,
343 recovery_count,
344 shard_bytes,
345 engine,
346 Some(work),
347 )?)
348 } else {
349 low.reset(original_count, recovery_count, shard_bytes)?;
350 InnerDecoder::Low(low)
351 }
352 }
353
354 InnerDecoder::None => unreachable!(),
355 };
356
357 Ok(())
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util;

    // Round-trips every configuration in `test_util::DEFAULT_TINY` through
    // `DefaultRate` with 1024-byte shards, passing the table's expected
    // `recovery_hash`. The two range arguments presumably select which
    // original/recovery shards are handed to the decoder — confirm against the
    // `roundtrip_single!` macro definition.
    #[test]
    fn roundtrips_tiny() {
        for (original_count, recovery_count, seed, recovery_hash) in test_util::DEFAULT_TINY {
            roundtrip_single!(
                DefaultRate,
                *original_count,
                *recovery_count,
                1024,
                recovery_hash,
                &[*recovery_count..*original_count],
                &[0..std::cmp::min(*original_count, *recovery_count)],
                *seed,
            );
        }
    }

    // Two rounds with the same 2/3 (low-rate) configuration. The `false` flag
    // presumably means no explicit `reset` between rounds (hence "implicit
    // reset") — confirm against `roundtrip_two_rounds!`.
    #[test]
    fn two_rounds_implicit_reset() {
        roundtrip_two_rounds!(
            DefaultRate,
            false,
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 2], 123),
            (2, 3, 1024, test_util::LOW_2_3_223, &[0], &[1], 223),
        );
    }

    // 3/2 then 5/3: `use_high_rate` picks high rate for both, so `reset`
    // stays on the high-rate encoder/decoder.
    #[test]
    fn two_rounds_reset_high_to_high() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (3, 2, 1024, test_util::HIGH_3_2, &[1], &[0, 1], 132),
            (5, 3, 1024, test_util::HIGH_5_3, &[1, 3], &[0, 1, 2], 153),
        );
    }

    // 3/2 (high rate) then 2/3 (low rate): exercises the rate-switching path of `reset`.
    #[test]
    fn two_rounds_reset_high_to_low() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (3, 2, 1024, test_util::HIGH_3_2, &[1], &[0, 1], 132),
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 2], 123),
        );
    }

    // 2/3 (low rate) then 3/2 (high rate): exercises the rate-switching path of `reset`.
    #[test]
    fn two_rounds_reset_low_to_high() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 1], 123),
            (3, 2, 1024, test_util::HIGH_3_2, &[1], &[0, 1], 132),
        );
    }

    // 2/3 then 3/5: `use_high_rate` picks low rate for both, so `reset`
    // stays on the low-rate encoder/decoder.
    #[test]
    fn two_rounds_reset_low_to_low() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 2], 123),
            (3, 5, 1024, test_util::LOW_3_5, &[], &[0, 2, 4], 135),
        );
    }

    // Table test for the rate-selection helper: zero counts and the GF_ORDER
    // limits are rejected, differing padded (power-of-two) sizes follow the
    // larger side, and equal padded sizes go high rate only when
    // `original_count <= recovery_count`.
    #[test]
    fn use_high_rate() {
        fn err(original_count: usize, recovery_count: usize) -> Result<bool, Error> {
            Err(Error::UnsupportedShardCount {
                original_count,
                recovery_count,
            })
        }

        for (original_count, recovery_count, expected) in [
            (0, 1, err(0, 1)),
            (1, 0, err(1, 0)),
            // Equal padded sizes (4 and 4): high rate while originals <= recoveries.
            (3, 3, Ok(true)),
            (3, 4, Ok(true)),
            (3, 5, Ok(false)),
            (4, 3, Ok(false)),
            (5, 3, Ok(true)),
            // Boundary: smaller padded count + larger count == GF_ORDER is accepted.
            (4096, 61440, Ok(false)),
            (4096, 61441, err(4096, 61441)),
            (4097, 61440, err(4097, 61440)),
            (61440, 4096, Ok(true)),
            (61440, 4097, err(61440, 4097)),
            (61441, 4096, err(61441, 4096)),
            // Must be rejected before `next_power_of_two` could overflow.
            (usize::MAX, usize::MAX, err(usize::MAX, usize::MAX)),
        ] {
            assert_eq!(
                super::use_high_rate(original_count, recovery_count),
                expected
            );
        }
    }
}