1use std::{cmp::Ordering, marker::PhantomData};
2
3use crate::{
4 engine::{Engine, GF_ORDER},
5 rate::{
6 DecoderWork, EncoderWork, HighRateDecoder, HighRateEncoder, LowRateDecoder, LowRateEncoder,
7 Rate, RateDecoder, RateEncoder,
8 },
9 DecoderResult, EncoderResult, Error,
10};
11
12fn use_high_rate(original_count: usize, recovery_count: usize) -> Result<bool, Error> {
16 if original_count > GF_ORDER || recovery_count > GF_ORDER {
17 return Err(Error::UnsupportedShardCount {
18 original_count,
19 recovery_count,
20 });
21 }
22
23 let original_count_pow2 = original_count.next_power_of_two();
24 let recovery_count_pow2 = recovery_count.next_power_of_two();
25
26 let smaller_pow2 = std::cmp::min(original_count_pow2, recovery_count_pow2);
27 let larger = std::cmp::max(original_count, recovery_count);
28
29 if original_count == 0 || recovery_count == 0 || smaller_pow2 + larger > GF_ORDER {
30 return Err(Error::UnsupportedShardCount {
31 original_count,
32 recovery_count,
33 });
34 }
35
36 match original_count_pow2.cmp(&recovery_count_pow2) {
37 Ordering::Less => {
38 Ok(false)
42 }
43
44 Ordering::Greater => {
45 Ok(true)
49 }
50
51 Ordering::Equal => {
52 if original_count <= recovery_count {
56 Ok(true)
58 } else {
59 Ok(false)
61 }
62 }
63 }
64}
65
66pub struct DefaultRate<E: Engine>(PhantomData<E>);
71
72impl<E: Engine> Rate<E> for DefaultRate<E> {
73 type RateEncoder = DefaultRateEncoder<E>;
74 type RateDecoder = DefaultRateDecoder<E>;
75
76 fn supports(original_count: usize, recovery_count: usize) -> bool {
77 use_high_rate(original_count, recovery_count).is_ok()
78 }
79}
80
81#[derive(Default)]
85enum InnerEncoder<E: Engine> {
86 High(HighRateEncoder<E>),
87 Low(LowRateEncoder<E>),
88
89 #[default]
91 None,
92}
93
94pub struct DefaultRateEncoder<E: Engine>(InnerEncoder<E>);
105
106impl<E: Engine> RateEncoder<E> for DefaultRateEncoder<E> {
107 type Rate = DefaultRate<E>;
108
109 fn add_original_shard<T: AsRef<[u8]>>(&mut self, original_shard: T) -> Result<(), Error> {
110 match &mut self.0 {
111 InnerEncoder::High(high) => high.add_original_shard(original_shard),
112 InnerEncoder::Low(low) => low.add_original_shard(original_shard),
113 InnerEncoder::None => unreachable!(),
114 }
115 }
116
117 fn encode(&mut self) -> Result<EncoderResult, Error> {
118 match &mut self.0 {
119 InnerEncoder::High(high) => high.encode(),
120 InnerEncoder::Low(low) => low.encode(),
121 InnerEncoder::None => unreachable!(),
122 }
123 }
124
125 fn into_parts(self) -> (E, EncoderWork) {
126 match self.0 {
127 InnerEncoder::High(high) => high.into_parts(),
128 InnerEncoder::Low(low) => low.into_parts(),
129 InnerEncoder::None => unreachable!(),
130 }
131 }
132
133 fn new(
134 original_count: usize,
135 recovery_count: usize,
136 shard_bytes: usize,
137 engine: E,
138 work: Option<EncoderWork>,
139 ) -> Result<Self, Error> {
140 let inner = if use_high_rate(original_count, recovery_count)? {
141 InnerEncoder::High(HighRateEncoder::new(
142 original_count,
143 recovery_count,
144 shard_bytes,
145 engine,
146 work,
147 )?)
148 } else {
149 InnerEncoder::Low(LowRateEncoder::new(
150 original_count,
151 recovery_count,
152 shard_bytes,
153 engine,
154 work,
155 )?)
156 };
157
158 Ok(Self(inner))
159 }
160
161 fn reset(
162 &mut self,
163 original_count: usize,
164 recovery_count: usize,
165 shard_bytes: usize,
166 ) -> Result<(), Error> {
167 let new_rate_is_high = use_high_rate(original_count, recovery_count)?;
168
169 self.0 = match std::mem::take(&mut self.0) {
170 InnerEncoder::High(mut high) => {
171 if new_rate_is_high {
172 high.reset(original_count, recovery_count, shard_bytes)?;
173 InnerEncoder::High(high)
174 } else {
175 let (engine, work) = high.into_parts();
176 InnerEncoder::Low(LowRateEncoder::new(
177 original_count,
178 recovery_count,
179 shard_bytes,
180 engine,
181 Some(work),
182 )?)
183 }
184 }
185
186 InnerEncoder::Low(mut low) => {
187 if new_rate_is_high {
188 let (engine, work) = low.into_parts();
189 InnerEncoder::High(HighRateEncoder::new(
190 original_count,
191 recovery_count,
192 shard_bytes,
193 engine,
194 Some(work),
195 )?)
196 } else {
197 low.reset(original_count, recovery_count, shard_bytes)?;
198 InnerEncoder::Low(low)
199 }
200 }
201
202 InnerEncoder::None => unreachable!(),
203 };
204
205 Ok(())
206 }
207}
208
209#[derive(Default)]
213enum InnerDecoder<E: Engine> {
214 High(HighRateDecoder<E>),
215 Low(LowRateDecoder<E>),
216
217 #[default]
219 None,
220}
221
222pub struct DefaultRateDecoder<E: Engine>(InnerDecoder<E>);
233
234impl<E: Engine> RateDecoder<E> for DefaultRateDecoder<E> {
235 type Rate = DefaultRate<E>;
236
237 fn add_original_shard<T: AsRef<[u8]>>(
238 &mut self,
239 index: usize,
240 original_shard: T,
241 ) -> Result<(), Error> {
242 match &mut self.0 {
243 InnerDecoder::High(high) => high.add_original_shard(index, original_shard),
244 InnerDecoder::Low(low) => low.add_original_shard(index, original_shard),
245 InnerDecoder::None => unreachable!(),
246 }
247 }
248
249 fn add_recovery_shard<T: AsRef<[u8]>>(
250 &mut self,
251 index: usize,
252 recovery_shard: T,
253 ) -> Result<(), Error> {
254 match &mut self.0 {
255 InnerDecoder::High(high) => high.add_recovery_shard(index, recovery_shard),
256 InnerDecoder::Low(low) => low.add_recovery_shard(index, recovery_shard),
257 InnerDecoder::None => unreachable!(),
258 }
259 }
260
261 fn decode(&mut self) -> Result<DecoderResult, Error> {
262 match &mut self.0 {
263 InnerDecoder::High(high) => high.decode(),
264 InnerDecoder::Low(low) => low.decode(),
265 InnerDecoder::None => unreachable!(),
266 }
267 }
268
269 fn into_parts(self) -> (E, DecoderWork) {
270 match self.0 {
271 InnerDecoder::High(high) => high.into_parts(),
272 InnerDecoder::Low(low) => low.into_parts(),
273 InnerDecoder::None => unreachable!(),
274 }
275 }
276
277 fn new(
278 original_count: usize,
279 recovery_count: usize,
280 shard_bytes: usize,
281 engine: E,
282 work: Option<DecoderWork>,
283 ) -> Result<Self, Error> {
284 let inner = if use_high_rate(original_count, recovery_count)? {
285 InnerDecoder::High(HighRateDecoder::new(
286 original_count,
287 recovery_count,
288 shard_bytes,
289 engine,
290 work,
291 )?)
292 } else {
293 InnerDecoder::Low(LowRateDecoder::new(
294 original_count,
295 recovery_count,
296 shard_bytes,
297 engine,
298 work,
299 )?)
300 };
301
302 Ok(Self(inner))
303 }
304
305 fn reset(
306 &mut self,
307 original_count: usize,
308 recovery_count: usize,
309 shard_bytes: usize,
310 ) -> Result<(), Error> {
311 let new_rate_is_high = use_high_rate(original_count, recovery_count)?;
312
313 self.0 = match std::mem::take(&mut self.0) {
314 InnerDecoder::High(mut high) => {
315 if new_rate_is_high {
316 high.reset(original_count, recovery_count, shard_bytes)?;
317 InnerDecoder::High(high)
318 } else {
319 let (engine, work) = high.into_parts();
320 InnerDecoder::Low(LowRateDecoder::new(
321 original_count,
322 recovery_count,
323 shard_bytes,
324 engine,
325 Some(work),
326 )?)
327 }
328 }
329
330 InnerDecoder::Low(mut low) => {
331 if new_rate_is_high {
332 let (engine, work) = low.into_parts();
333 InnerDecoder::High(HighRateDecoder::new(
334 original_count,
335 recovery_count,
336 shard_bytes,
337 engine,
338 Some(work),
339 )?)
340 } else {
341 low.reset(original_count, recovery_count, shard_bytes)?;
342 InnerDecoder::Low(low)
343 }
344 }
345
346 InnerDecoder::None => unreachable!(),
347 };
348
349 Ok(())
350 }
351}
352
// Unit tests for `DefaultRate`. The `roundtrip_single!` and
// `roundtrip_two_rounds!` macros and the `test_util` fixtures are not
// defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util;

    /// Round-trips every `test_util::DEFAULT_TINY` entry
    /// `(original_count, recovery_count, seed, recovery_hash)` using
    /// 1024-byte shards.
    // NOTE(review): the two range arguments look like the original and
    // recovery shard indices supplied to the decoder — confirm against
    // `roundtrip_single!`.
    #[test]
    fn roundtrips_tiny() {
        for (original_count, recovery_count, seed, recovery_hash) in test_util::DEFAULT_TINY {
            roundtrip_single!(
                DefaultRate,
                *original_count,
                *recovery_count,
                1024,
                recovery_hash,
                &[*recovery_count..*original_count],
                &[0..std::cmp::min(*original_count, *recovery_count)],
                *seed,
            );
        }
    }

    /// Two rounds with the same low-rate shard counts (2 original, 3 recovery).
    // NOTE(review): judging by the test names, the boolean argument appears
    // to select explicit `reset` (`true`) vs. implicit reset (`false`) —
    // confirm against `roundtrip_two_rounds!`.
    #[test]
    fn two_rounds_implicit_reset() {
        roundtrip_two_rounds!(
            DefaultRate,
            false,
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 2], 123),
            (2, 3, 1024, test_util::LOW_2_3_223, &[0], &[1], 223),
        );
    }

    /// Reset from high rate (3, 2) to high rate (5, 3).
    #[test]
    fn two_rounds_reset_high_to_high() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (3, 2, 1024, test_util::HIGH_3_2, &[1], &[0, 1], 132),
            (5, 3, 1024, test_util::HIGH_5_3, &[1, 3], &[0, 1, 2], 153),
        );
    }

    /// Reset from high rate (3, 2) to low rate (2, 3), exercising the
    /// rate-switch path of `reset`.
    #[test]
    fn two_rounds_reset_high_to_low() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (3, 2, 1024, test_util::HIGH_3_2, &[1], &[0, 1], 132),
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 2], 123),
        );
    }

    /// Reset from low rate (2, 3) to high rate (3, 2), exercising the
    /// rate-switch path of `reset`.
    #[test]
    fn two_rounds_reset_low_to_high() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 1], 123),
            (3, 2, 1024, test_util::HIGH_3_2, &[1], &[0, 1], 132),
        );
    }

    /// Reset from low rate (2, 3) to low rate (3, 5).
    #[test]
    fn two_rounds_reset_low_to_low() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 2], 123),
            (3, 5, 1024, test_util::LOW_3_5, &[], &[0, 2, 4], 135),
        );
    }

    /// Table test for `super::use_high_rate`: `Ok(true)` = high rate,
    /// `Ok(false)` = low rate.
    #[test]
    fn use_high_rate() {
        fn err(original_count: usize, recovery_count: usize) -> Result<bool, Error> {
            Err(Error::UnsupportedShardCount {
                original_count,
                recovery_count,
            })
        }

        for (original_count, recovery_count, expected) in [
            // Zero counts are rejected.
            (0, 1, err(0, 1)),
            (1, 0, err(1, 0)),
            // Both round up to 4 and original_count <= recovery_count: high rate.
            (3, 3, Ok(true)),
            (3, 4, Ok(true)),
            // 4 < 8: low rate.
            (3, 5, Ok(false)),
            // Both round up to 4 but original_count > recovery_count: low rate.
            (4, 3, Ok(false)),
            // 8 > 4: high rate.
            (5, 3, Ok(true)),
            // 4096 + 61440 = 65536 is accepted; one more in either
            // count pushes the sum over the limit.
            (4096, 61440, Ok(false)),
            (4096, 61441, err(4096, 61441)),
            (4097, 61440, err(4097, 61440)),
            // Mirror image of the boundary cases above, on the high-rate side.
            (61440, 4096, Ok(true)),
            (61440, 4097, err(61440, 4097)),
            (61441, 4096, err(61441, 4096)),
            // Rejected by the up-front GF_ORDER check, before next_power_of_two.
            (usize::MAX, usize::MAX, err(usize::MAX, usize::MAX)),
        ] {
            assert_eq!(
                super::use_high_rate(original_count, recovery_count),
                expected
            );
        }
    }
}