reed_solomon_simd/rate/
rate_default.rs

1use std::{cmp::Ordering, marker::PhantomData};
2
3use crate::{
4    engine::{Engine, GF_ORDER},
5    rate::{
6        DecoderWork, EncoderWork, HighRateDecoder, HighRateEncoder, LowRateDecoder, LowRateEncoder,
7        Rate, RateDecoder, RateEncoder,
8    },
9    DecoderResult, EncoderResult, Error,
10};
11
12// ======================================================================
13// FUNCTIONS - PRIVATE
14
15fn use_high_rate(original_count: usize, recovery_count: usize) -> Result<bool, Error> {
16    if original_count > GF_ORDER || recovery_count > GF_ORDER {
17        return Err(Error::UnsupportedShardCount {
18            original_count,
19            recovery_count,
20        });
21    }
22
23    let original_count_pow2 = original_count.next_power_of_two();
24    let recovery_count_pow2 = recovery_count.next_power_of_two();
25
26    let smaller_pow2 = std::cmp::min(original_count_pow2, recovery_count_pow2);
27    let larger = std::cmp::max(original_count, recovery_count);
28
29    if original_count == 0 || recovery_count == 0 || smaller_pow2 + larger > GF_ORDER {
30        return Err(Error::UnsupportedShardCount {
31            original_count,
32            recovery_count,
33        });
34    }
35
36    match original_count_pow2.cmp(&recovery_count_pow2) {
37        Ordering::Less => {
38            // The "correct" rate is generally faster here,
39            // and also must be used if `recovery_count > 32768`.
40
41            Ok(false)
42        }
43
44        Ordering::Greater => {
45            // The "correct" rate is generally faster here,
46            // and also must be used if `original_count > 32768`.
47
48            Ok(true)
49        }
50
51        Ordering::Equal => {
52            // Here counter-intuitively the "wrong" rate is generally faster
53            // in decoding if `original_count` and `recovery_count` differ a lot.
54
55            if original_count <= recovery_count {
56                // Using the "wrong" rate on purpose.
57                Ok(true)
58            } else {
59                // Using the "wrong" rate on purpose.
60                Ok(false)
61            }
62        }
63    }
64}
65
66// ======================================================================
67// DefaultRate - PUBLIC
68
69/// Reed-Solomon encoder/decoder generator using high or low rate as appropriate.
70pub struct DefaultRate<E: Engine>(PhantomData<E>);
71
72impl<E: Engine> Rate<E> for DefaultRate<E> {
73    type RateEncoder = DefaultRateEncoder<E>;
74    type RateDecoder = DefaultRateDecoder<E>;
75
76    fn supports(original_count: usize, recovery_count: usize) -> bool {
77        use_high_rate(original_count, recovery_count).is_ok()
78    }
79}
80
81// ======================================================================
82// InnerEncoder - PRIVATE
83
84#[derive(Default)]
85enum InnerEncoder<E: Engine> {
86    High(HighRateEncoder<E>),
87    Low(LowRateEncoder<E>),
88
89    // This is only used temporarily during `reset`, never anywhere else.
90    #[default]
91    None,
92}
93
94// ======================================================================
95// DefaultRateEncoder - PUBLIC
96
97/// Reed-Solomon encoder using high or low rate as appropriate.
98///
99/// This is basically same as [`ReedSolomonEncoder`]
100/// except with slightly different API which allows
101/// specifying [`Engine`] and [`EncoderWork`].
102///
103/// [`ReedSolomonEncoder`]: crate::ReedSolomonEncoder
104pub struct DefaultRateEncoder<E: Engine>(InnerEncoder<E>);
105
106impl<E: Engine> RateEncoder<E> for DefaultRateEncoder<E> {
107    type Rate = DefaultRate<E>;
108
109    fn add_original_shard<T: AsRef<[u8]>>(&mut self, original_shard: T) -> Result<(), Error> {
110        match &mut self.0 {
111            InnerEncoder::High(high) => high.add_original_shard(original_shard),
112            InnerEncoder::Low(low) => low.add_original_shard(original_shard),
113            InnerEncoder::None => unreachable!(),
114        }
115    }
116
117    fn encode(&mut self) -> Result<EncoderResult, Error> {
118        match &mut self.0 {
119            InnerEncoder::High(high) => high.encode(),
120            InnerEncoder::Low(low) => low.encode(),
121            InnerEncoder::None => unreachable!(),
122        }
123    }
124
125    fn into_parts(self) -> (E, EncoderWork) {
126        match self.0 {
127            InnerEncoder::High(high) => high.into_parts(),
128            InnerEncoder::Low(low) => low.into_parts(),
129            InnerEncoder::None => unreachable!(),
130        }
131    }
132
133    fn new(
134        original_count: usize,
135        recovery_count: usize,
136        shard_bytes: usize,
137        engine: E,
138        work: Option<EncoderWork>,
139    ) -> Result<Self, Error> {
140        let inner = if use_high_rate(original_count, recovery_count)? {
141            InnerEncoder::High(HighRateEncoder::new(
142                original_count,
143                recovery_count,
144                shard_bytes,
145                engine,
146                work,
147            )?)
148        } else {
149            InnerEncoder::Low(LowRateEncoder::new(
150                original_count,
151                recovery_count,
152                shard_bytes,
153                engine,
154                work,
155            )?)
156        };
157
158        Ok(Self(inner))
159    }
160
161    fn reset(
162        &mut self,
163        original_count: usize,
164        recovery_count: usize,
165        shard_bytes: usize,
166    ) -> Result<(), Error> {
167        let new_rate_is_high = use_high_rate(original_count, recovery_count)?;
168
169        self.0 = match std::mem::take(&mut self.0) {
170            InnerEncoder::High(mut high) => {
171                if new_rate_is_high {
172                    high.reset(original_count, recovery_count, shard_bytes)?;
173                    InnerEncoder::High(high)
174                } else {
175                    let (engine, work) = high.into_parts();
176                    InnerEncoder::Low(LowRateEncoder::new(
177                        original_count,
178                        recovery_count,
179                        shard_bytes,
180                        engine,
181                        Some(work),
182                    )?)
183                }
184            }
185
186            InnerEncoder::Low(mut low) => {
187                if new_rate_is_high {
188                    let (engine, work) = low.into_parts();
189                    InnerEncoder::High(HighRateEncoder::new(
190                        original_count,
191                        recovery_count,
192                        shard_bytes,
193                        engine,
194                        Some(work),
195                    )?)
196                } else {
197                    low.reset(original_count, recovery_count, shard_bytes)?;
198                    InnerEncoder::Low(low)
199                }
200            }
201
202            InnerEncoder::None => unreachable!(),
203        };
204
205        Ok(())
206    }
207}
208
209// ======================================================================
210// InnerDecoder - PRIVATE
211
212#[derive(Default)]
213enum InnerDecoder<E: Engine> {
214    High(HighRateDecoder<E>),
215    Low(LowRateDecoder<E>),
216
217    // This is only used temporarily during `reset`, never anywhere else.
218    #[default]
219    None,
220}
221
222// ======================================================================
223// DefaultRateDecoder - PUBLIC
224
225/// Reed-Solomon decoder using high or low rate as appropriate.
226///
227/// This is basically same as [`ReedSolomonDecoder`]
228/// except with slightly different API which allows
229/// specifying [`Engine`] and [`DecoderWork`].
230///
231/// [`ReedSolomonDecoder`]: crate::ReedSolomonDecoder
232pub struct DefaultRateDecoder<E: Engine>(InnerDecoder<E>);
233
234impl<E: Engine> RateDecoder<E> for DefaultRateDecoder<E> {
235    type Rate = DefaultRate<E>;
236
237    fn add_original_shard<T: AsRef<[u8]>>(
238        &mut self,
239        index: usize,
240        original_shard: T,
241    ) -> Result<(), Error> {
242        match &mut self.0 {
243            InnerDecoder::High(high) => high.add_original_shard(index, original_shard),
244            InnerDecoder::Low(low) => low.add_original_shard(index, original_shard),
245            InnerDecoder::None => unreachable!(),
246        }
247    }
248
249    fn add_recovery_shard<T: AsRef<[u8]>>(
250        &mut self,
251        index: usize,
252        recovery_shard: T,
253    ) -> Result<(), Error> {
254        match &mut self.0 {
255            InnerDecoder::High(high) => high.add_recovery_shard(index, recovery_shard),
256            InnerDecoder::Low(low) => low.add_recovery_shard(index, recovery_shard),
257            InnerDecoder::None => unreachable!(),
258        }
259    }
260
261    fn decode(&mut self) -> Result<DecoderResult, Error> {
262        match &mut self.0 {
263            InnerDecoder::High(high) => high.decode(),
264            InnerDecoder::Low(low) => low.decode(),
265            InnerDecoder::None => unreachable!(),
266        }
267    }
268
269    fn into_parts(self) -> (E, DecoderWork) {
270        match self.0 {
271            InnerDecoder::High(high) => high.into_parts(),
272            InnerDecoder::Low(low) => low.into_parts(),
273            InnerDecoder::None => unreachable!(),
274        }
275    }
276
277    fn new(
278        original_count: usize,
279        recovery_count: usize,
280        shard_bytes: usize,
281        engine: E,
282        work: Option<DecoderWork>,
283    ) -> Result<Self, Error> {
284        let inner = if use_high_rate(original_count, recovery_count)? {
285            InnerDecoder::High(HighRateDecoder::new(
286                original_count,
287                recovery_count,
288                shard_bytes,
289                engine,
290                work,
291            )?)
292        } else {
293            InnerDecoder::Low(LowRateDecoder::new(
294                original_count,
295                recovery_count,
296                shard_bytes,
297                engine,
298                work,
299            )?)
300        };
301
302        Ok(Self(inner))
303    }
304
305    fn reset(
306        &mut self,
307        original_count: usize,
308        recovery_count: usize,
309        shard_bytes: usize,
310    ) -> Result<(), Error> {
311        let new_rate_is_high = use_high_rate(original_count, recovery_count)?;
312
313        self.0 = match std::mem::take(&mut self.0) {
314            InnerDecoder::High(mut high) => {
315                if new_rate_is_high {
316                    high.reset(original_count, recovery_count, shard_bytes)?;
317                    InnerDecoder::High(high)
318                } else {
319                    let (engine, work) = high.into_parts();
320                    InnerDecoder::Low(LowRateDecoder::new(
321                        original_count,
322                        recovery_count,
323                        shard_bytes,
324                        engine,
325                        Some(work),
326                    )?)
327                }
328            }
329
330            InnerDecoder::Low(mut low) => {
331                if new_rate_is_high {
332                    let (engine, work) = low.into_parts();
333                    InnerDecoder::High(HighRateDecoder::new(
334                        original_count,
335                        recovery_count,
336                        shard_bytes,
337                        engine,
338                        Some(work),
339                    )?)
340                } else {
341                    low.reset(original_count, recovery_count, shard_bytes)?;
342                    InnerDecoder::Low(low)
343                }
344            }
345
346            InnerDecoder::None => unreachable!(),
347        };
348
349        Ok(())
350    }
351}
352
353// ======================================================================
354// TESTS
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util;

    // ============================================================
    // ROUNDTRIPS - SINGLE ROUND

    // One roundtrip per `test_util::DEFAULT_TINY` entry, destructured as
    // (original_count, recovery_count, seed, expected recovery hash), using 1024-byte shards.
    // NOTE(review): the two range-slice arguments presumably select which original
    // and recovery shards are handed to the decoder; confirm against `roundtrip_single!`.
    #[test]
    fn roundtrips_tiny() {
        for (original_count, recovery_count, seed, recovery_hash) in test_util::DEFAULT_TINY {
            roundtrip_single!(
                DefaultRate,
                *original_count,
                *recovery_count,
                1024,
                recovery_hash,
                &[*recovery_count..*original_count],
                &[0..std::cmp::min(*original_count, *recovery_count)],
                *seed,
            );
        }
    }

    // ============================================================
    // ROUNDTRIPS - TWO ROUNDS
    //
    // Shard counts are picked so that `use_high_rate` selects the rate named in
    // each test: (3, 2) and (5, 3) -> high rate, (2, 3) and (3, 5) -> low rate.
    // NOTE(review): judging by the test names, the boolean argument appears to
    // request an explicit `reset` between rounds (`true`) versus an implicit one
    // (`false`); confirm against `roundtrip_two_rounds!`.

    #[test]
    fn two_rounds_implicit_reset() {
        roundtrip_two_rounds!(
            DefaultRate,
            false,
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 2], 123),
            (2, 3, 1024, test_util::LOW_2_3_223, &[0], &[1], 223),
        );
    }

    #[test]
    fn two_rounds_reset_high_to_high() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (3, 2, 1024, test_util::HIGH_3_2, &[1], &[0, 1], 132),
            (5, 3, 1024, test_util::HIGH_5_3, &[1, 3], &[0, 1, 2], 153),
        );
    }

    #[test]
    fn two_rounds_reset_high_to_low() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (3, 2, 1024, test_util::HIGH_3_2, &[1], &[0, 1], 132),
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 2], 123),
        );
    }

    #[test]
    fn two_rounds_reset_low_to_high() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 1], 123),
            (3, 2, 1024, test_util::HIGH_3_2, &[1], &[0, 1], 132),
        );
    }

    #[test]
    fn two_rounds_reset_low_to_low() {
        roundtrip_two_rounds!(
            DefaultRate,
            true,
            (2, 3, 1024, test_util::LOW_2_3, &[], &[0, 2], 123),
            (3, 5, 1024, test_util::LOW_3_5, &[], &[0, 2, 4], 135),
        );
    }

    // ============================================================
    // use_high_rate

    // This test shares its name with the function under test, hence the
    // explicit `super::use_high_rate` call below.
    #[test]
    fn use_high_rate() {
        fn err(original_count: usize, recovery_count: usize) -> Result<bool, Error> {
            Err(Error::UnsupportedShardCount {
                original_count,
                recovery_count,
            })
        }

        for (original_count, recovery_count, expected) in [
            // A zero count on either side is rejected.
            (0, 1, err(0, 1)),
            (1, 0, err(1, 0)),
            // CORRECT/WRONG RATE
            // Equal powers of two with original <= recovery deliberately pick high rate.
            (3, 3, Ok(true)),
            (3, 4, Ok(true)),
            (3, 5, Ok(false)),
            // Equal powers of two with original > recovery deliberately pick low rate.
            (4, 3, Ok(false)),
            (5, 3, Ok(true)),
            // LOW RATE LIMIT
            // 4096 + 61440 == 65536 sits exactly on the
            // `smaller_pow2 + larger <= GF_ORDER` boundary; one more shard on either side fails.
            (4096, 61440, Ok(false)),
            (4096, 61441, err(4096, 61441)),
            (4097, 61440, err(4097, 61440)),
            // HIGH RATE LIMIT
            // Mirror image of the low-rate boundary.
            (61440, 4096, Ok(true)),
            (61440, 4097, err(61440, 4097)),
            (61441, 4096, err(61441, 4096)),
            // OVERFLOW CHECK
            // Must be rejected before `next_power_of_two` could overflow.
            (usize::MAX, usize::MAX, err(usize::MAX, usize::MAX)),
        ] {
            assert_eq!(
                super::use_high_rate(original_count, recovery_count),
                expected
            );
        }
    }
}