reed_solomon_simd/engine/
engine_avx2.rs

use std::iter::zip;

#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

use crate::engine::{
    tables::{self, Mul128, Multiply128lutT, Skew},
    utils, Engine, GfElement, ShardsRefMut, GF_MODULUS, GF_ORDER,
};

// ======================================================================
// Avx2 - PUBLIC

/// Optimized [`Engine`] using AVX2 instructions.
///
/// [`Avx2`] follows the same algorithm as [`NoSimd`] but takes advantage
/// of the x86 AVX2 SIMD instructions.
///
/// [`NoSimd`]: crate::engine::NoSimd
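///
/// # Examples
///
/// A minimal usage sketch (assuming the host CPU supports AVX2; callers
/// would typically verify this with `is_x86_feature_detected!("avx2")`
/// before constructing this engine):
///
/// ```no_run
/// use reed_solomon_simd::engine::{Avx2, Engine};
///
/// let engine = Avx2::new();
///
/// // Multiply one 64-byte chunk in place by the GF(2^16) element
/// // whose logarithm is 1234.
/// let mut chunk = [[1u8; 64]; 1];
/// engine.mul(&mut chunk, 1234);
/// ```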
#[derive(Clone, Copy)]
pub struct Avx2 {
    mul128: &'static Mul128,
    skew: &'static Skew,
}

impl Avx2 {
    /// Creates a new [`Avx2`], initializing all [tables]
    /// needed for encoding or decoding.
    ///
    /// Currently the only difference between encoding and decoding is
    /// [`LogWalsh`] (128 kiB), which is only needed for decoding.
    ///
    /// [`LogWalsh`]: crate::engine::tables::LogWalsh
    pub fn new() -> Self {
        let mul128 = &*tables::MUL128;
        let skew = &*tables::SKEW;

        Self { mul128, skew }
    }
}

impl Engine for Avx2 {
    fn fft(
        &self,
        data: &mut ShardsRefMut,
        pos: usize,
        size: usize,
        truncated_size: usize,
        skew_delta: usize,
    ) {
        unsafe {
            self.fft_private_avx2(data, pos, size, truncated_size, skew_delta);
        }
    }

    fn ifft(
        &self,
        data: &mut ShardsRefMut,
        pos: usize,
        size: usize,
        truncated_size: usize,
        skew_delta: usize,
    ) {
        unsafe {
            self.ifft_private_avx2(data, pos, size, truncated_size, skew_delta);
        }
    }

    fn mul(&self, x: &mut [[u8; 64]], log_m: GfElement) {
        unsafe {
            self.mul_avx2(x, log_m);
        }
    }

    fn eval_poly(erasures: &mut [GfElement; GF_ORDER], truncated_size: usize) {
        unsafe { Self::eval_poly_avx2(erasures, truncated_size) }
    }
}

// ======================================================================
// Avx2 - IMPL Default

impl Default for Avx2 {
    fn default() -> Self {
        Self::new()
    }
}

// ======================================================================
// Avx2 - PRIVATE
//
//

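// Per-multiplier lookup tables for `_mm256_shuffle_epi8`, with each 16-byte
// table from `Multiply128lutT` broadcast into both 128-bit lanes of a 256-bit
// register so one shuffle covers 32 bytes at a time.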
#[derive(Copy, Clone)]
struct LutAvx2 {
    t0_lo: __m256i,
    t1_lo: __m256i,
    t2_lo: __m256i,
    t3_lo: __m256i,
    t0_hi: __m256i,
    t1_hi: __m256i,
    t2_hi: __m256i,
    t3_hi: __m256i,
}

impl From<&Multiply128lutT> for LutAvx2 {
    #[inline(always)]
    fn from(lut: &Multiply128lutT) -> Self {
        unsafe {
            Self {
                t0_lo: _mm256_broadcastsi128_si256(_mm_loadu_si128(
                    std::ptr::from_ref::<u128>(&lut.lo[0]).cast::<__m128i>(),
                )),
                t1_lo: _mm256_broadcastsi128_si256(_mm_loadu_si128(
                    std::ptr::from_ref::<u128>(&lut.lo[1]).cast::<__m128i>(),
                )),
                t2_lo: _mm256_broadcastsi128_si256(_mm_loadu_si128(
                    std::ptr::from_ref::<u128>(&lut.lo[2]).cast::<__m128i>(),
                )),
                t3_lo: _mm256_broadcastsi128_si256(_mm_loadu_si128(
                    std::ptr::from_ref::<u128>(&lut.lo[3]).cast::<__m128i>(),
                )),
                t0_hi: _mm256_broadcastsi128_si256(_mm_loadu_si128(
                    std::ptr::from_ref::<u128>(&lut.hi[0]).cast::<__m128i>(),
                )),
                t1_hi: _mm256_broadcastsi128_si256(_mm_loadu_si128(
                    std::ptr::from_ref::<u128>(&lut.hi[1]).cast::<__m128i>(),
                )),
                t2_hi: _mm256_broadcastsi128_si256(_mm_loadu_si128(
                    std::ptr::from_ref::<u128>(&lut.hi[2]).cast::<__m128i>(),
                )),
                t3_hi: _mm256_broadcastsi128_si256(_mm_loadu_si128(
                    std::ptr::from_ref::<u128>(&lut.hi[3]).cast::<__m128i>(),
                )),
            }
        }
    }
}

impl Avx2 {
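    // Multiplies every 64-byte chunk of `x` in place by the GF(2^16) element
    // whose logarithm is `log_m`, loading each chunk as two 256-bit halves.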
    #[target_feature(enable = "avx2")]
    unsafe fn mul_avx2(&self, x: &mut [[u8; 64]], log_m: GfElement) {
        let lut = &self.mul128[log_m as usize];
        let lut_avx2 = LutAvx2::from(lut);

        for chunk in x.iter_mut() {
            let x_ptr = chunk.as_mut_ptr().cast::<__m256i>();
            unsafe {
                let x_lo = _mm256_loadu_si256(x_ptr);
                let x_hi = _mm256_loadu_si256(x_ptr.add(1));
                let (prod_lo, prod_hi) = Self::mul_256(x_lo, x_hi, lut_avx2);
                _mm256_storeu_si256(x_ptr, prod_lo);
                _mm256_storeu_si256(x_ptr.add(1), prod_hi);
            }
        }
    }

    // Implementation of LEO_MUL_256
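    //
    // Each input element is split across `value_lo` (low bytes) and `value_hi`
    // (high bytes); every byte is further split into two 4-bit nibbles, and
    // each nibble selects a byte from one of the four 16-entry shuffle tables
    // (`t0`..`t3`). XOR-ing the four partial products together gives the low
    // and high bytes of the GF(2^16) product.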
    #[inline(always)]
    fn mul_256(value_lo: __m256i, value_hi: __m256i, lut_avx2: LutAvx2) -> (__m256i, __m256i) {
        let mut prod_lo: __m256i;
        let mut prod_hi: __m256i;

        unsafe {
            let clr_mask = _mm256_set1_epi8(0x0f);

            let data_0 = _mm256_and_si256(value_lo, clr_mask);
            prod_lo = _mm256_shuffle_epi8(lut_avx2.t0_lo, data_0);
            prod_hi = _mm256_shuffle_epi8(lut_avx2.t0_hi, data_0);

            let data_1 = _mm256_and_si256(_mm256_srli_epi64(value_lo, 4), clr_mask);
            prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(lut_avx2.t1_lo, data_1));
            prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(lut_avx2.t1_hi, data_1));

            let data_0 = _mm256_and_si256(value_hi, clr_mask);
            prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(lut_avx2.t2_lo, data_0));
            prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(lut_avx2.t2_hi, data_0));

            let data_1 = _mm256_and_si256(_mm256_srli_epi64(value_hi, 4), clr_mask);
            prod_lo = _mm256_xor_si256(prod_lo, _mm256_shuffle_epi8(lut_avx2.t3_lo, data_1));
            prod_hi = _mm256_xor_si256(prod_hi, _mm256_shuffle_epi8(lut_avx2.t3_hi, data_1));
        }

        (prod_lo, prod_hi)
    }

    // {x_lo, x_hi} ^= {y_lo, y_hi} * log_m
    // Implementation of LEO_MULADD_256
    #[inline(always)]
    fn muladd_256(
        mut x_lo: __m256i,
        mut x_hi: __m256i,
        y_lo: __m256i,
        y_hi: __m256i,
        lut_avx2: LutAvx2,
    ) -> (__m256i, __m256i) {
        let (prod_lo, prod_hi) = Self::mul_256(y_lo, y_hi, lut_avx2);
        unsafe {
            x_lo = _mm256_xor_si256(x_lo, prod_lo);
            x_hi = _mm256_xor_si256(x_hi, prod_hi);
        }
        (x_lo, x_hi)
    }
}

// ======================================================================
// Avx2 - PRIVATE - FFT (fast Fourier transform)

impl Avx2 {
    // Implementation of LEO_FFTB_256
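    //
    // Forward butterfly: x' = x ^ (y * m), then y' = y ^ x'.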
    #[inline(always)]
    fn fftb_256(x: &mut [u8; 64], y: &mut [u8; 64], lut_avx2: LutAvx2) {
        let x_ptr = x.as_mut_ptr().cast::<__m256i>();
        let y_ptr = y.as_mut_ptr().cast::<__m256i>();

        unsafe {
            let mut x_lo = _mm256_loadu_si256(x_ptr);
            let mut x_hi = _mm256_loadu_si256(x_ptr.add(1));

            let mut y_lo = _mm256_loadu_si256(y_ptr);
            let mut y_hi = _mm256_loadu_si256(y_ptr.add(1));

            (x_lo, x_hi) = Self::muladd_256(x_lo, x_hi, y_lo, y_hi, lut_avx2);

            _mm256_storeu_si256(x_ptr, x_lo);
            _mm256_storeu_si256(x_ptr.add(1), x_hi);

            y_lo = _mm256_xor_si256(y_lo, x_lo);
            y_hi = _mm256_xor_si256(y_hi, x_hi);

            _mm256_storeu_si256(y_ptr, y_lo);
            _mm256_storeu_si256(y_ptr.add(1), y_hi);
        }
    }

    // Partial butterfly; the caller must do the `GF_MODULUS` check with `xor`.
    #[inline(always)]
    fn fft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
        let lut = &self.mul128[log_m as usize];
        let lut_avx2 = LutAvx2::from(lut);

        for (x_chunk, y_chunk) in zip(x.iter_mut(), y.iter_mut()) {
            Self::fftb_256(x_chunk, y_chunk, lut_avx2);
        }
    }

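    // Applies two FFT layers over four shards spaced `dist` apart: the first
    // layer pairs (s0, s2) and (s1, s3) using `log_m02`, the second pairs
    // (s0, s1) with `log_m01` and (s2, s3) with `log_m23`. A skew equal to
    // `GF_MODULUS` means the multiply is skipped and only the XOR half of the
    // butterfly is applied.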
    #[inline(always)]
    fn fft_butterfly_two_layers(
        &self,
        data: &mut ShardsRefMut,
        pos: usize,
        dist: usize,
        log_m01: GfElement,
        log_m23: GfElement,
        log_m02: GfElement,
    ) {
        let (s0, s1, s2, s3) = data.dist4_mut(pos, dist);

        // FIRST LAYER

        if log_m02 == GF_MODULUS {
            utils::xor(s2, s0);
            utils::xor(s3, s1);
        } else {
            self.fft_butterfly_partial(s0, s2, log_m02);
            self.fft_butterfly_partial(s1, s3, log_m02);
        }

        // SECOND LAYER

        if log_m01 == GF_MODULUS {
            utils::xor(s1, s0);
        } else {
            self.fft_butterfly_partial(s0, s1, log_m01);
        }

        if log_m23 == GF_MODULUS {
            utils::xor(s3, s2);
        } else {
            self.fft_butterfly_partial(s2, s3, log_m23);
        }
    }

    #[target_feature(enable = "avx2")]
    unsafe fn fft_private_avx2(
        &self,
        data: &mut ShardsRefMut,
        pos: usize,
        size: usize,
        truncated_size: usize,
        skew_delta: usize,
    ) {
        // Drop unsafe privileges
        self.fft_private(data, pos, size, truncated_size, skew_delta);
    }

    #[inline(always)]
    fn fft_private(
        &self,
        data: &mut ShardsRefMut,
        pos: usize,
        size: usize,
        truncated_size: usize,
        skew_delta: usize,
    ) {
        // TWO LAYERS AT A TIME
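        //
        // The stride starts at `size / 4` and is divided by 4 each pass, so
        // every pass of the outer loop applies two butterfly layers at once
        // over groups of `dist4 = 4 * dist` shards. If the total number of
        // layers is odd, a single radix-2 layer remains and is handled below.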

        let mut dist4 = size;
        let mut dist = size >> 2;
        while dist != 0 {
            let mut r = 0;
            while r < truncated_size {
                let base = r + dist + skew_delta - 1;

                let log_m01 = self.skew[base];
                let log_m02 = self.skew[base + dist];
                let log_m23 = self.skew[base + dist * 2];

                for i in r..r + dist {
                    self.fft_butterfly_two_layers(data, pos + i, dist, log_m01, log_m23, log_m02);
                }

                r += dist4;
            }
            dist4 = dist;
            dist >>= 2;
        }

        // FINAL ODD LAYER

        if dist4 == 2 {
            let mut r = 0;
            while r < truncated_size {
                let log_m = self.skew[r + skew_delta];

                let (x, y) = data.dist2_mut(pos + r, 1);

                if log_m == GF_MODULUS {
                    utils::xor(y, x);
                } else {
                    self.fft_butterfly_partial(x, y, log_m);
                }

                r += 2;
            }
        }
    }
}

// ======================================================================
// Avx2 - PRIVATE - IFFT (inverse fast Fourier transform)

impl Avx2 {
    // Implementation of LEO_IFFTB_256
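    //
    // Inverse butterfly: y' = y ^ x, then x' = x ^ (y' * m). This is the
    // inverse of the forward butterfly in `fftb_256`.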
    #[inline(always)]
    fn ifftb_256(x: &mut [u8; 64], y: &mut [u8; 64], lut_avx2: LutAvx2) {
        let x_ptr = x.as_mut_ptr().cast::<__m256i>();
        let y_ptr = y.as_mut_ptr().cast::<__m256i>();

        unsafe {
            let mut x_lo = _mm256_loadu_si256(x_ptr);
            let mut x_hi = _mm256_loadu_si256(x_ptr.add(1));

            let mut y_lo = _mm256_loadu_si256(y_ptr);
            let mut y_hi = _mm256_loadu_si256(y_ptr.add(1));

            y_lo = _mm256_xor_si256(y_lo, x_lo);
            y_hi = _mm256_xor_si256(y_hi, x_hi);

            _mm256_storeu_si256(y_ptr, y_lo);
            _mm256_storeu_si256(y_ptr.add(1), y_hi);

            (x_lo, x_hi) = Self::muladd_256(x_lo, x_hi, y_lo, y_hi, lut_avx2);

            _mm256_storeu_si256(x_ptr, x_lo);
            _mm256_storeu_si256(x_ptr.add(1), x_hi);
        }
    }

    #[inline(always)]
    fn ifft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
        let lut = &self.mul128[log_m as usize];
        let lut_avx2 = LutAvx2::from(lut);

        for (x_chunk, y_chunk) in zip(x.iter_mut(), y.iter_mut()) {
            Self::ifftb_256(x_chunk, y_chunk, lut_avx2);
        }
    }

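    // Applies two IFFT layers over four shards spaced `dist` apart, mirroring
    // `fft_butterfly_two_layers` with the layer order reversed: first (s0, s1)
    // with `log_m01` and (s2, s3) with `log_m23`, then (s0, s2) and (s1, s3)
    // with `log_m02`. A skew equal to `GF_MODULUS` again reduces the butterfly
    // to its XOR half.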
    #[inline(always)]
    fn ifft_butterfly_two_layers(
        &self,
        data: &mut ShardsRefMut,
        pos: usize,
        dist: usize,
        log_m01: GfElement,
        log_m23: GfElement,
        log_m02: GfElement,
    ) {
        let (s0, s1, s2, s3) = data.dist4_mut(pos, dist);

        // FIRST LAYER

        if log_m01 == GF_MODULUS {
            utils::xor(s1, s0);
        } else {
            self.ifft_butterfly_partial(s0, s1, log_m01);
        }

        if log_m23 == GF_MODULUS {
            utils::xor(s3, s2);
        } else {
            self.ifft_butterfly_partial(s2, s3, log_m23);
        }

        // SECOND LAYER

        if log_m02 == GF_MODULUS {
            utils::xor(s2, s0);
            utils::xor(s3, s1);
        } else {
            self.ifft_butterfly_partial(s0, s2, log_m02);
            self.ifft_butterfly_partial(s1, s3, log_m02);
        }
    }

    #[target_feature(enable = "avx2")]
    unsafe fn ifft_private_avx2(
        &self,
        data: &mut ShardsRefMut,
        pos: usize,
        size: usize,
        truncated_size: usize,
        skew_delta: usize,
    ) {
        // Drop unsafe privileges
        self.ifft_private(data, pos, size, truncated_size, skew_delta);
    }

    #[inline(always)]
    fn ifft_private(
        &self,
        data: &mut ShardsRefMut,
        pos: usize,
        size: usize,
        truncated_size: usize,
        skew_delta: usize,
    ) {
        // TWO LAYERS AT A TIME
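        //
        // The stride starts at 1 and is multiplied by 4 each pass, applying
        // two butterfly layers at once over groups of `dist4 = 4 * dist`
        // shards. If one layer is left over (`dist < size`), it is applied
        // below across the two halves of the block.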

        let mut dist = 1;
        let mut dist4 = 4;
        while dist4 <= size {
            let mut r = 0;
            while r < truncated_size {
                let base = r + dist + skew_delta - 1;

                let log_m01 = self.skew[base];
                let log_m02 = self.skew[base + dist];
                let log_m23 = self.skew[base + dist * 2];

                for i in r..r + dist {
                    self.ifft_butterfly_two_layers(data, pos + i, dist, log_m01, log_m23, log_m02);
                }

                r += dist4;
            }
            dist = dist4;
            dist4 <<= 2;
        }

        // FINAL ODD LAYER

        if dist < size {
            let log_m = self.skew[dist + skew_delta - 1];
            if log_m == GF_MODULUS {
                utils::xor_within(data, pos + dist, pos, dist);
            } else {
                let (mut a, mut b) = data.split_at_mut(pos + dist);
                for i in 0..dist {
                    self.ifft_butterfly_partial(
                        &mut a[pos + i], // data[pos + i]
                        &mut b[i],       // data[pos + i + dist]
                        log_m,
                    );
                }
            }
        }
    }
}

// ======================================================================
// Avx2 - PRIVATE - Evaluate polynomial

impl Avx2 {
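    // Delegates to the shared scalar implementation; wrapping it in an
    // `avx2`-enabled function presumably lets the compiler auto-vectorize
    // the inlined helper with AVX2 instructions.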
    #[target_feature(enable = "avx2")]
    unsafe fn eval_poly_avx2(erasures: &mut [GfElement; GF_ORDER], truncated_size: usize) {
        utils::eval_poly(erasures, truncated_size);
    }
}

// ======================================================================
// TESTS

// Engines are tested indirectly via roundtrip tests of HighRate and LowRate.