// reed_solomon_simd/engine/shards.rs
use std::ops::{Bound, Index, IndexMut, Range, RangeBounds};

/// Owned, contiguous storage for a set of equally sized shards.
///
/// Each shard occupies `shard_len_64` consecutive 64-byte chunks of `data`,
/// so shard `i` is `data[i * shard_len_64..(i + 1) * shard_len_64]`.
#[derive(Debug, Default)]
pub(crate) struct Shards {
    /// Number of shards the buffer is currently sized for.
    shard_count: usize,
    /// Length of every shard, measured in 64-byte chunks.
    shard_len_64: usize,

    /// Flat backing buffer; sized to `shard_count * shard_len_64` chunks by
    /// `resize`.
    data: Vec<[u8; 64]>,
}
15impl Shards {
16 pub(crate) fn as_ref_mut(&mut self) -> ShardsRefMut {
17 ShardsRefMut::new(self.shard_count, self.shard_len_64, self.data.as_mut())
18 }
19
20 pub(crate) fn new() -> Self {
21 Self {
22 shard_count: 0,
23 shard_len_64: 0,
24 data: Vec::new(),
25 }
26 }
27
28 pub(crate) fn resize(&mut self, shard_count: usize, shard_len_64: usize) {
29 self.shard_count = shard_count;
30 self.shard_len_64 = shard_len_64;
31
32 self.data
33 .resize(self.shard_count * self.shard_len_64, [0; 64]);
34 }
35
36 pub(crate) fn insert(&mut self, index: usize, shard: &[u8]) {
37 debug_assert_eq!(shard.len() % 2, 0);
38
39 let whole_chunk_count = shard.len() / 64;
40 let tail_len = shard.len() % 64;
41
42 let (src_chunks, src_tail) = shard.split_at(shard.len() - tail_len);
43
44 let dst = &mut self[index];
45 dst[..whole_chunk_count]
46 .as_flattened_mut()
47 .copy_from_slice(src_chunks);
48
49 if tail_len > 0 {
52 let (src_lo, src_hi) = src_tail.split_at(tail_len / 2);
53 let (dst_lo, dst_hi) = dst[whole_chunk_count].split_at_mut(32);
54 dst_lo[..src_lo.len()].copy_from_slice(src_lo);
55 dst_hi[..src_hi.len()].copy_from_slice(src_hi);
56 }
57 }
58
59 pub(crate) fn undo_last_chunk_encoding(&mut self, shard_bytes: usize, range: Range<usize>) {
61 let whole_chunk_count = shard_bytes / 64;
62 let tail_len = shard_bytes % 64;
63
64 if tail_len == 0 {
65 return;
66 };
67
68 for idx in range {
69 let last_chunk = &mut self[idx][whole_chunk_count];
70 last_chunk.copy_within(32..32 + tail_len / 2, tail_len / 2);
71 }
72 }
73}
75impl Index<usize> for Shards {
79 type Output = [[u8; 64]];
80 fn index(&self, index: usize) -> &Self::Output {
81 &self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
82 }
83}
85impl IndexMut<usize> for Shards {
89 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
90 &mut self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
91 }
92}
/// Mutable, borrowed view over a flat buffer of equally sized shards.
///
/// Shard `i` is `data[i * shard_len_64..(i + 1) * shard_len_64]`, each
/// element of `data` being one 64-byte chunk.
#[derive(Debug)]
pub struct ShardsRefMut<'a> {
    /// Number of shards covered by this view.
    shard_count: usize,
    /// Length of every shard, measured in 64-byte chunks.
    shard_len_64: usize,

    /// Borrowed chunk storage backing the view.
    data: &'a mut [[u8; 64]],
}
105impl<'a> ShardsRefMut<'a> {
106 pub fn dist2_mut(
116 &mut self,
117 mut pos: usize,
118 mut dist: usize,
119 ) -> (&mut [[u8; 64]], &mut [[u8; 64]]) {
120 pos *= self.shard_len_64;
121 dist *= self.shard_len_64;
122
123 let (a, b) = self.data[pos..].split_at_mut(dist);
124 (&mut a[..self.shard_len_64], &mut b[..self.shard_len_64])
125 }
126
127 #[allow(clippy::type_complexity)]
139 pub fn dist4_mut(
140 &mut self,
141 mut pos: usize,
142 mut dist: usize,
143 ) -> (
144 &mut [[u8; 64]],
145 &mut [[u8; 64]],
146 &mut [[u8; 64]],
147 &mut [[u8; 64]],
148 ) {
149 pos *= self.shard_len_64;
150 dist *= self.shard_len_64;
151
152 let (ab, cd) = self.data[pos..].split_at_mut(dist * 2);
153 let (a, b) = ab.split_at_mut(dist);
154 let (c, d) = cd.split_at_mut(dist);
155
156 (
157 &mut a[..self.shard_len_64],
158 &mut b[..self.shard_len_64],
159 &mut c[..self.shard_len_64],
160 &mut d[..self.shard_len_64],
161 )
162 }
163
164 pub fn is_empty(&self) -> bool {
166 self.shard_count == 0
167 }
168
169 pub fn len(&self) -> usize {
171 self.shard_count
172 }
173
174 pub fn new(shard_count: usize, shard_len_64: usize, data: &'a mut [[u8; 64]]) -> Self {
180 assert!(data.len() >= shard_count * shard_len_64);
181
182 Self {
183 shard_count,
184 shard_len_64,
185 data: &mut data[..shard_count * shard_len_64],
186 }
187 }
188
189 pub fn split_at_mut(&mut self, mid: usize) -> (ShardsRefMut, ShardsRefMut) {
192 let (a, b) = self.data.split_at_mut(mid * self.shard_len_64);
193
194 (
195 ShardsRefMut::new(mid, self.shard_len_64, a),
196 ShardsRefMut::new(self.shard_count - mid, self.shard_len_64, b),
197 )
198 }
199
200 pub fn zero<R: RangeBounds<usize>>(&mut self, range: R) {
202 let start = match range.start_bound() {
203 Bound::Included(start) => start * self.shard_len_64,
204 Bound::Excluded(start) => (start + 1) * self.shard_len_64,
205 Bound::Unbounded => 0,
206 };
207
208 let end = match range.end_bound() {
209 Bound::Included(end) => (end + 1) * self.shard_len_64,
210 Bound::Excluded(end) => end * self.shard_len_64,
211 Bound::Unbounded => self.shard_count * self.shard_len_64,
212 };
213
214 self.data[start..end].fill([0; 64]);
215 }
216}
218impl Index<usize> for ShardsRefMut<'_> {
222 type Output = [[u8; 64]];
223 fn index(&self, index: usize) -> &Self::Output {
224 &self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
225 }
226}
228impl IndexMut<usize> for ShardsRefMut<'_> {
232 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
233 &mut self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
234 }
235}
237impl ShardsRefMut<'_> {
241 pub(crate) fn copy_within(&mut self, mut src: usize, mut dest: usize, mut count: usize) {
242 src *= self.shard_len_64;
243 dest *= self.shard_len_64;
244 count *= self.shard_len_64;
245
246 self.data.copy_within(src..src + count, dest);
247 }
248
249 pub(crate) fn flat2_mut(
254 &mut self,
255 mut x: usize,
256 mut y: usize,
257 mut count: usize,
258 ) -> (&mut [[u8; 64]], &mut [[u8; 64]]) {
259 x *= self.shard_len_64;
260 y *= self.shard_len_64;
261 count *= self.shard_len_64;
262
263 if x < y {
264 let (head, tail) = self.data.split_at_mut(y);
265 (&mut head[x..x + count], &mut tail[..count])
266 } else {
267 let (head, tail) = self.data.split_at_mut(x);
268 (&mut tail[..count], &mut head[y..y + count])
269 }
270 }
271}