reed_solomon_simd/engine/
shards.rs

1use std::ops::{Bound, Index, IndexMut, Range, RangeBounds};
2
3// ======================================================================
4// Shards - CRATE
5
/// Owned storage for `shard_count` shards laid out back to back in one buffer.
///
/// Every shard occupies exactly `shard_len_64` consecutive 64-byte chunks, so
/// shard `i` lives at chunks `i * shard_len_64 .. (i + 1) * shard_len_64`.
pub(crate) struct Shards {
    /// Number of shards stored.
    shard_count: usize,
    /// Length of every shard, measured in 64-byte chunks.
    shard_len_64: usize,

    /// Flat storage of `shard_count * shard_len_64` chunks of 64 bytes each.
    data: Vec<[u8; 64]>,
}
14
15impl Shards {
16    pub(crate) fn as_ref_mut(&mut self) -> ShardsRefMut {
17        ShardsRefMut::new(self.shard_count, self.shard_len_64, self.data.as_mut())
18    }
19
20    pub(crate) fn new() -> Self {
21        Self {
22            shard_count: 0,
23            shard_len_64: 0,
24            data: Vec::new(),
25        }
26    }
27
28    pub(crate) fn resize(&mut self, shard_count: usize, shard_len_64: usize) {
29        self.shard_count = shard_count;
30        self.shard_len_64 = shard_len_64;
31
32        self.data
33            .resize(self.shard_count * self.shard_len_64, [0; 64]);
34    }
35
36    pub(crate) fn insert(&mut self, index: usize, shard: &[u8]) {
37        debug_assert_eq!(shard.len() % 2, 0);
38
39        let whole_chunk_count = shard.len() / 64;
40        let tail_len = shard.len() % 64;
41
42        let (src_chunks, src_tail) = shard.split_at(shard.len() - tail_len);
43
44        let dst = &mut self[index];
45        dst[..whole_chunk_count]
46            .as_flattened_mut()
47            .copy_from_slice(src_chunks);
48
49        // Last chunk is special if shard.len() % 64 != 0.
50        // See src/algorithm.md for an explanation.
51        if tail_len > 0 {
52            let (src_lo, src_hi) = src_tail.split_at(tail_len / 2);
53            let (dst_lo, dst_hi) = dst[whole_chunk_count].split_at_mut(32);
54            dst_lo[..src_lo.len()].copy_from_slice(src_lo);
55            dst_hi[..src_hi.len()].copy_from_slice(src_hi);
56        }
57    }
58
59    // Undoes the encoding of the last chunk for the given range of shards
60    pub(crate) fn undo_last_chunk_encoding(&mut self, shard_bytes: usize, range: Range<usize>) {
61        let whole_chunk_count = shard_bytes / 64;
62        let tail_len = shard_bytes % 64;
63
64        if tail_len == 0 {
65            return;
66        };
67
68        for idx in range {
69            let last_chunk = &mut self[idx][whole_chunk_count];
70            last_chunk.copy_within(32..32 + tail_len / 2, tail_len / 2);
71        }
72    }
73}
74
75// ======================================================================
76// Shards - IMPL Index
77
78impl Index<usize> for Shards {
79    type Output = [[u8; 64]];
80    fn index(&self, index: usize) -> &Self::Output {
81        &self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
82    }
83}
84
85// ======================================================================
86// Shards - IMPL IndexMut
87
88impl IndexMut<usize> for Shards {
89    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
90        &mut self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
91    }
92}
93
94// ======================================================================
95// ShardsRefMut - PUBLIC
96
/// Mutable view of a contiguous array of shards.
///
/// Covers `shard_count` shards, each `shard_len_64` 64-byte chunks long,
/// stored back to back in a single borrowed slice.
pub struct ShardsRefMut<'a> {
    /// Number of shards in the view.
    shard_count: usize,
    /// Length of every shard, in 64-byte chunks.
    shard_len_64: usize,

    /// Borrowed chunk storage; shard `i` occupies chunks
    /// `i * shard_len_64 .. (i + 1) * shard_len_64`.
    data: &'a mut [[u8; 64]],
}
104
105impl<'a> ShardsRefMut<'a> {
106    /// Returns mutable references to shards at `pos` and `pos + dist`.
107    ///
108    /// See source code of [`Naive::fft`] for an example.
109    ///
110    /// # Panics
111    ///
112    /// If `dist` is `0`.
113    ///
114    /// [`Naive::fft`]: crate::engine::Naive#method.fft
115    pub fn dist2_mut(
116        &mut self,
117        mut pos: usize,
118        mut dist: usize,
119    ) -> (&mut [[u8; 64]], &mut [[u8; 64]]) {
120        pos *= self.shard_len_64;
121        dist *= self.shard_len_64;
122
123        let (a, b) = self.data[pos..].split_at_mut(dist);
124        (&mut a[..self.shard_len_64], &mut b[..self.shard_len_64])
125    }
126
127    /// Returns mutable references to shards at
128    /// `pos`, `pos + dist`, `pos + dist * 2` and `pos + dist * 3`.
129    ///
130    /// See source code of [`NoSimd::fft`] for an example
131    /// (specifically the private method `fft_butterfly_two_layers`).
132    ///
133    /// # Panics
134    ///
135    /// If `dist` is `0`.
136    ///
137    /// [`NoSimd::fft`]: crate::engine::NoSimd#method.fft
138    #[allow(clippy::type_complexity)]
139    pub fn dist4_mut(
140        &mut self,
141        mut pos: usize,
142        mut dist: usize,
143    ) -> (
144        &mut [[u8; 64]],
145        &mut [[u8; 64]],
146        &mut [[u8; 64]],
147        &mut [[u8; 64]],
148    ) {
149        pos *= self.shard_len_64;
150        dist *= self.shard_len_64;
151
152        let (ab, cd) = self.data[pos..].split_at_mut(dist * 2);
153        let (a, b) = ab.split_at_mut(dist);
154        let (c, d) = cd.split_at_mut(dist);
155
156        (
157            &mut a[..self.shard_len_64],
158            &mut b[..self.shard_len_64],
159            &mut c[..self.shard_len_64],
160            &mut d[..self.shard_len_64],
161        )
162    }
163
164    /// Returns `true` if this contains no shards.
165    pub fn is_empty(&self) -> bool {
166        self.shard_count == 0
167    }
168
169    /// Returns number of shards.
170    pub fn len(&self) -> usize {
171        self.shard_count
172    }
173
174    /// Creates new [`ShardsRefMut`] that references given `data`.
175    ///
176    /// # Panics
177    ///
178    /// If `data.len() < shard_count * shard_len_64`.
179    pub fn new(shard_count: usize, shard_len_64: usize, data: &'a mut [[u8; 64]]) -> Self {
180        assert!(data.len() >= shard_count * shard_len_64);
181
182        Self {
183            shard_count,
184            shard_len_64,
185            data: &mut data[..shard_count * shard_len_64],
186        }
187    }
188
189    /// Splits this [`ShardsRefMut`] into two so that
190    /// first includes shards `0..mid` and second includes shards `mid..`.
191    pub fn split_at_mut(&mut self, mid: usize) -> (ShardsRefMut, ShardsRefMut) {
192        let (a, b) = self.data.split_at_mut(mid * self.shard_len_64);
193
194        (
195            ShardsRefMut::new(mid, self.shard_len_64, a),
196            ShardsRefMut::new(self.shard_count - mid, self.shard_len_64, b),
197        )
198    }
199
200    /// Fills the given shard-range with `0u8`:s.
201    pub fn zero<R: RangeBounds<usize>>(&mut self, range: R) {
202        let start = match range.start_bound() {
203            Bound::Included(start) => start * self.shard_len_64,
204            Bound::Excluded(start) => (start + 1) * self.shard_len_64,
205            Bound::Unbounded => 0,
206        };
207
208        let end = match range.end_bound() {
209            Bound::Included(end) => (end + 1) * self.shard_len_64,
210            Bound::Excluded(end) => end * self.shard_len_64,
211            Bound::Unbounded => self.shard_count * self.shard_len_64,
212        };
213
214        self.data[start..end].fill([0; 64]);
215    }
216}
217
218// ======================================================================
219// ShardsRefMut - IMPL Index
220
221impl Index<usize> for ShardsRefMut<'_> {
222    type Output = [[u8; 64]];
223    fn index(&self, index: usize) -> &Self::Output {
224        &self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
225    }
226}
227
228// ======================================================================
229// ShardsRefMut - IMPL IndexMut
230
231impl IndexMut<usize> for ShardsRefMut<'_> {
232    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
233        &mut self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
234    }
235}
236
237// ======================================================================
238// ShardsRefMut - CRATE
239
240impl ShardsRefMut<'_> {
241    pub(crate) fn copy_within(&mut self, mut src: usize, mut dest: usize, mut count: usize) {
242        src *= self.shard_len_64;
243        dest *= self.shard_len_64;
244        count *= self.shard_len_64;
245
246        self.data.copy_within(src..src + count, dest);
247    }
248
249    // Returns mutable references to flat-arrays of shard-ranges
250    // `x .. x + count` and `y .. y + count`.
251    //
252    // Ranges must not overlap.
253    pub(crate) fn flat2_mut(
254        &mut self,
255        mut x: usize,
256        mut y: usize,
257        mut count: usize,
258    ) -> (&mut [[u8; 64]], &mut [[u8; 64]]) {
259        x *= self.shard_len_64;
260        y *= self.shard_len_64;
261        count *= self.shard_len_64;
262
263        if x < y {
264            let (head, tail) = self.data.split_at_mut(y);
265            (&mut head[x..x + count], &mut tail[..count])
266        } else {
267            let (head, tail) = self.data.split_at_mut(x);
268            (&mut tail[..count], &mut head[y..y + count])
269        }
270    }
271}