curve25519_dalek/backend/vector/ifma/field.rs

// -*- mode: rust; -*-
//
// This file is part of curve25519-dalek.
// Copyright (c) 2016-2021 isis lovecruft
// Copyright (c) 2016-2019 Henry de Valence
// See LICENSE for licensing information.
//
// Authors:
// - isis agora lovecruft <isis@patternsinthevoid.net>
// - Henry de Valence <hdevalence@hdevalence.ca>

#![allow(non_snake_case)]

use crate::backend::vector::packed_simd::u64x4;
use core::ops::{Add, Mul, Neg};

use crate::backend::serial::u64::field::FieldElement51;

use curve25519_dalek_derive::unsafe_target_feature;

/// A wrapper around `vpmadd52luq` that works on `u64x4`.
#[unsafe_target_feature("avx512ifma,avx512vl")]
#[inline]
unsafe fn madd52lo(z: u64x4, x: u64x4, y: u64x4) -> u64x4 {
    use core::arch::x86_64::_mm256_madd52lo_epu64;
    _mm256_madd52lo_epu64(z.into(), x.into(), y.into()).into()
}

/// A wrapper around `vpmadd52huq` that works on `u64x4`.
#[unsafe_target_feature("avx512ifma,avx512vl")]
#[inline]
unsafe fn madd52hi(z: u64x4, x: u64x4, y: u64x4) -> u64x4 {
    use core::arch::x86_64::_mm256_madd52hi_epu64;
    _mm256_madd52hi_epu64(z.into(), x.into(), y.into()).into()
}
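
// Both intrinsics read only the low 52 bits of each 64-bit lane of `x` and
// `y`, form the full 104-bit product, and accumulate either its low 52 bits
// (`madd52lo`) or its high 52 bits (`madd52hi`) into `z`. As a small worked
// example: with `x = y = u64x4::splat(1 << 26)`, each lane's product is
// exactly 2^52, so `madd52lo(z, x, y)` leaves `z` unchanged (the low 52 bits
// of the product are zero), while `madd52hi(z, x, y)` adds 1 to each lane.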

/// A vector of four field elements in radix 2^51, with unreduced coefficients.
#[derive(Copy, Clone, Debug)]
pub struct F51x4Unreduced(pub(crate) [u64x4; 5]);

/// A vector of four field elements in radix 2^51, with reduced coefficients.
#[derive(Copy, Clone, Debug)]
pub struct F51x4Reduced(pub(crate) [u64x4; 5]);

#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone)]
pub enum Shuffle {
    AAAA,
    BBBB,
    BADC,
    BACD,
    ADDA,
    CBCB,
    ABDC,
    ABAB,
    DBBD,
    CACA,
}

#[unsafe_target_feature("avx512ifma,avx512vl")]
#[inline(always)]
fn shuffle_lanes(x: u64x4, control: Shuffle) -> u64x4 {
    unsafe {
        use core::arch::x86_64::_mm256_permute4x64_epi64 as perm;

        match control {
            Shuffle::AAAA => perm(x.into(), 0b00_00_00_00).into(),
            Shuffle::BBBB => perm(x.into(), 0b01_01_01_01).into(),
            Shuffle::BADC => perm(x.into(), 0b10_11_00_01).into(),
            Shuffle::BACD => perm(x.into(), 0b11_10_00_01).into(),
            Shuffle::ADDA => perm(x.into(), 0b00_11_11_00).into(),
            Shuffle::CBCB => perm(x.into(), 0b01_10_01_10).into(),
            Shuffle::ABDC => perm(x.into(), 0b10_11_01_00).into(),
            Shuffle::ABAB => perm(x.into(), 0b01_00_01_00).into(),
            Shuffle::DBBD => perm(x.into(), 0b11_01_01_11).into(),
            Shuffle::CACA => perm(x.into(), 0b00_10_00_10).into(),
        }
    }
}
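
// The control byte for `_mm256_permute4x64_epi64` packs one 2-bit source
// index per destination lane, with lane A (lane 0) in the two low bits.
// Writing lanes (A, B, C, D) as indices (0, 1, 2, 3) and reading
// `0b10_11_00_01` for `Shuffle::BADC` right to left: lane A takes index 1
// (B), lane B takes index 0 (A), lane C takes index 3 (D), and lane D takes
// index 2 (C), giving the output (B, A, D, C).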

#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone)]
pub enum Lanes {
    D,
    C,
    AB,
    AC,
    AD,
    BCD,
}

#[unsafe_target_feature("avx512ifma,avx512vl")]
#[inline]
fn blend_lanes(x: u64x4, y: u64x4, control: Lanes) -> u64x4 {
    unsafe {
        use core::arch::x86_64::_mm256_blend_epi32 as blend;

        match control {
            Lanes::D => blend(x.into(), y.into(), 0b11_00_00_00).into(),
            Lanes::C => blend(x.into(), y.into(), 0b00_11_00_00).into(),
            Lanes::AB => blend(x.into(), y.into(), 0b00_00_11_11).into(),
            Lanes::AC => blend(x.into(), y.into(), 0b00_11_00_11).into(),
            Lanes::AD => blend(x.into(), y.into(), 0b11_00_00_11).into(),
            Lanes::BCD => blend(x.into(), y.into(), 0b11_11_11_00).into(),
        }
    }
}
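
// `_mm256_blend_epi32` selects 32-bit lanes, so each 64-bit lane of `x` and
// `y` corresponds to two adjacent mask bits, again with lane A in the low
// bits; a 64-bit lane is taken from `y` exactly when both of its bits are
// set. For example, `Lanes::AB` uses `0b00_00_11_11` to take lanes A and B
// from `y` and lanes C and D from `x`.

// Both `F51x4Unreduced` and `F51x4Reduced` hold four field elements in
// structure-of-arrays layout: the i-th `u64x4` stores limb i of all four
// elements (lanes A, B, C, D), so limb-wise arithmetic on the five vectors
// processes four field elements at once.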

#[unsafe_target_feature("avx512ifma,avx512vl")]
impl F51x4Unreduced {
    pub const ZERO: F51x4Unreduced = F51x4Unreduced([u64x4::splat_const::<0>(); 5]);

    pub fn new(
        x0: &FieldElement51,
        x1: &FieldElement51,
        x2: &FieldElement51,
        x3: &FieldElement51,
    ) -> F51x4Unreduced {
        F51x4Unreduced([
            u64x4::new(x0.0[0], x1.0[0], x2.0[0], x3.0[0]),
            u64x4::new(x0.0[1], x1.0[1], x2.0[1], x3.0[1]),
            u64x4::new(x0.0[2], x1.0[2], x2.0[2], x3.0[2]),
            u64x4::new(x0.0[3], x1.0[3], x2.0[3], x3.0[3]),
            u64x4::new(x0.0[4], x1.0[4], x2.0[4], x3.0[4]),
        ])
    }

    pub fn split(&self) -> [FieldElement51; 4] {
        let x = &self.0;
        [
            FieldElement51([
                x[0].extract::<0>(),
                x[1].extract::<0>(),
                x[2].extract::<0>(),
                x[3].extract::<0>(),
                x[4].extract::<0>(),
            ]),
            FieldElement51([
                x[0].extract::<1>(),
                x[1].extract::<1>(),
                x[2].extract::<1>(),
                x[3].extract::<1>(),
                x[4].extract::<1>(),
            ]),
            FieldElement51([
                x[0].extract::<2>(),
                x[1].extract::<2>(),
                x[2].extract::<2>(),
                x[3].extract::<2>(),
                x[4].extract::<2>(),
            ]),
            FieldElement51([
                x[0].extract::<3>(),
                x[1].extract::<3>(),
                x[2].extract::<3>(),
                x[3].extract::<3>(),
                x[4].extract::<3>(),
            ]),
        ]
    }

    #[inline]
    pub fn diff_sum(&self) -> F51x4Unreduced {
        // tmp1 = (B, A, D, C)
        let tmp1 = self.shuffle(Shuffle::BADC);
        // tmp2 = (-A, B, -C, D)
        let tmp2 = self.blend(&self.negate_lazy(), Lanes::AC);
        // (B - A, B + A, D - C, D + C)
        tmp1 + tmp2
    }

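    // The constants below are the limbs of 16*p in radix 2^51: the low limb
    // is 16*(2^51 - 19) = 36028797018963664 and the high limbs are
    // 16*(2^51 - 1) = 36028797018963952. Each limb of 16*p is nearly 2^55,
    // so the subtractions below cannot underflow for inputs with limbs
    // bounded well below that, and the result is congruent to -self (mod p)
    // without any reduction being performed (hence "lazy").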
    #[inline]
    pub fn negate_lazy(&self) -> F51x4Unreduced {
        let lo = u64x4::splat(36028797018963664u64);
        let hi = u64x4::splat(36028797018963952u64);
        F51x4Unreduced([
            lo - self.0[0],
            hi - self.0[1],
            hi - self.0[2],
            hi - self.0[3],
            hi - self.0[4],
        ])
    }

    #[inline]
    pub fn shuffle(&self, control: Shuffle) -> F51x4Unreduced {
        F51x4Unreduced([
            shuffle_lanes(self.0[0], control),
            shuffle_lanes(self.0[1], control),
            shuffle_lanes(self.0[2], control),
            shuffle_lanes(self.0[3], control),
            shuffle_lanes(self.0[4], control),
        ])
    }

    #[inline]
    pub fn blend(&self, other: &F51x4Unreduced, control: Lanes) -> F51x4Unreduced {
        F51x4Unreduced([
            blend_lanes(self.0[0], other.0[0], control),
            blend_lanes(self.0[1], other.0[1], control),
            blend_lanes(self.0[2], other.0[2], control),
            blend_lanes(self.0[3], other.0[3], control),
            blend_lanes(self.0[4], other.0[4], control),
        ])
    }
}

#[unsafe_target_feature("avx512ifma,avx512vl")]
impl Neg for F51x4Reduced {
    type Output = F51x4Reduced;

    fn neg(self) -> F51x4Reduced {
        F51x4Unreduced::from(self).negate_lazy().into()
    }
}

use subtle::Choice;
use subtle::ConditionallySelectable;

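// Constant-time selection via masking: `mask` is all-ones when `choice` is 1
// and all-zeros when it is 0, so `a ^ (mask & (a ^ b))` evaluates to `b` or
// to `a` respectively, with no branch on secret data.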
#[unsafe_target_feature("avx512ifma,avx512vl")]
impl ConditionallySelectable for F51x4Reduced {
    #[inline]
    fn conditional_select(a: &F51x4Reduced, b: &F51x4Reduced, choice: Choice) -> F51x4Reduced {
        let mask = (-(choice.unwrap_u8() as i64)) as u64;
        let mask_vec = u64x4::splat(mask);
        F51x4Reduced([
            a.0[0] ^ (mask_vec & (a.0[0] ^ b.0[0])),
            a.0[1] ^ (mask_vec & (a.0[1] ^ b.0[1])),
            a.0[2] ^ (mask_vec & (a.0[2] ^ b.0[2])),
            a.0[3] ^ (mask_vec & (a.0[3] ^ b.0[3])),
            a.0[4] ^ (mask_vec & (a.0[4] ^ b.0[4])),
        ])
    }

    #[inline]
    fn conditional_assign(&mut self, other: &F51x4Reduced, choice: Choice) {
        let mask = (-(choice.unwrap_u8() as i64)) as u64;
        let mask_vec = u64x4::splat(mask);
        self.0[0] ^= mask_vec & (self.0[0] ^ other.0[0]);
        self.0[1] ^= mask_vec & (self.0[1] ^ other.0[1]);
        self.0[2] ^= mask_vec & (self.0[2] ^ other.0[2]);
        self.0[3] ^= mask_vec & (self.0[3] ^ other.0[3]);
        self.0[4] ^= mask_vec & (self.0[4] ^ other.0[4]);
    }
}

#[unsafe_target_feature("avx512ifma,avx512vl")]
impl F51x4Reduced {
    #[inline]
    pub fn shuffle(&self, control: Shuffle) -> F51x4Reduced {
        F51x4Reduced([
            shuffle_lanes(self.0[0], control),
            shuffle_lanes(self.0[1], control),
            shuffle_lanes(self.0[2], control),
            shuffle_lanes(self.0[3], control),
            shuffle_lanes(self.0[4], control),
        ])
    }

    #[inline]
    pub fn blend(&self, other: &F51x4Reduced, control: Lanes) -> F51x4Reduced {
        F51x4Reduced([
            blend_lanes(self.0[0], other.0[0], control),
            blend_lanes(self.0[1], other.0[1], control),
            blend_lanes(self.0[2], other.0[2], control),
            blend_lanes(self.0[3], other.0[3], control),
            blend_lanes(self.0[4], other.0[4], control),
        ])
    }

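    // Schoolbook squaring, using the symmetry x_i * x_j == x_j * x_i: each
    // cross product is needed twice, and each high half carries an extra
    // factor of 2 because 2^52 == 2 * 2^51 in radix 2^51. The suffix of
    // each z accumulator below records its coefficient (_1, _2, or _4); the
    // coefficients are applied at the end via shifts and the final
    // z_1 + z_2 + z_2 sums. Limbs z5..z9 are folded back into z0..z4 using
    // 2^255 == 19 (mod p).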
    #[inline]
    pub fn square(&self) -> F51x4Unreduced {
        unsafe {
            let x = &self.0;

            // Represent values with coeff. 2
            let mut z0_2 = u64x4::splat(0);
            let mut z1_2 = u64x4::splat(0);
            let mut z2_2 = u64x4::splat(0);
            let mut z3_2 = u64x4::splat(0);
            let mut z4_2 = u64x4::splat(0);
            let mut z5_2 = u64x4::splat(0);
            let mut z6_2 = u64x4::splat(0);
            let mut z7_2 = u64x4::splat(0);
            let mut z9_2 = u64x4::splat(0);

            // Represent values with coeff. 4
            let mut z2_4 = u64x4::splat(0);
            let mut z3_4 = u64x4::splat(0);
            let mut z4_4 = u64x4::splat(0);
            let mut z5_4 = u64x4::splat(0);
            let mut z6_4 = u64x4::splat(0);
            let mut z7_4 = u64x4::splat(0);
            let mut z8_4 = u64x4::splat(0);

            let mut z0_1 = u64x4::splat(0);
            z0_1 = madd52lo(z0_1, x[0], x[0]);

            let mut z1_1 = u64x4::splat(0);
            z1_2 = madd52lo(z1_2, x[0], x[1]);
            z1_2 = madd52hi(z1_2, x[0], x[0]);

            z2_4 = madd52hi(z2_4, x[0], x[1]);
            let mut z2_1 = z2_4.shl::<2>();
            z2_2 = madd52lo(z2_2, x[0], x[2]);
            z2_1 = madd52lo(z2_1, x[1], x[1]);

            z3_4 = madd52hi(z3_4, x[0], x[2]);
            let mut z3_1 = z3_4.shl::<2>();
            z3_2 = madd52lo(z3_2, x[1], x[2]);
            z3_2 = madd52lo(z3_2, x[0], x[3]);
            z3_2 = madd52hi(z3_2, x[1], x[1]);

            z4_4 = madd52hi(z4_4, x[1], x[2]);
            z4_4 = madd52hi(z4_4, x[0], x[3]);
            let mut z4_1 = z4_4.shl::<2>();
            z4_2 = madd52lo(z4_2, x[1], x[3]);
            z4_2 = madd52lo(z4_2, x[0], x[4]);
            z4_1 = madd52lo(z4_1, x[2], x[2]);

            z5_4 = madd52hi(z5_4, x[1], x[3]);
            z5_4 = madd52hi(z5_4, x[0], x[4]);
            let mut z5_1 = z5_4.shl::<2>();
            z5_2 = madd52lo(z5_2, x[2], x[3]);
            z5_2 = madd52lo(z5_2, x[1], x[4]);
            z5_2 = madd52hi(z5_2, x[2], x[2]);

            z6_4 = madd52hi(z6_4, x[2], x[3]);
            z6_4 = madd52hi(z6_4, x[1], x[4]);
            let mut z6_1 = z6_4.shl::<2>();
            z6_2 = madd52lo(z6_2, x[2], x[4]);
            z6_1 = madd52lo(z6_1, x[3], x[3]);

            z7_4 = madd52hi(z7_4, x[2], x[4]);
            let mut z7_1 = z7_4.shl::<2>();
            z7_2 = madd52lo(z7_2, x[3], x[4]);
            z7_2 = madd52hi(z7_2, x[3], x[3]);

            z8_4 = madd52hi(z8_4, x[3], x[4]);
            let mut z8_1 = z8_4.shl::<2>();
            z8_1 = madd52lo(z8_1, x[4], x[4]);

            let mut z9_1 = u64x4::splat(0);
            z9_2 = madd52hi(z9_2, x[4], x[4]);

            z5_1 += z5_2.shl::<1>();
            z6_1 += z6_2.shl::<1>();
            z7_1 += z7_2.shl::<1>();
            z9_1 += z9_2.shl::<1>();

            let mut t0 = u64x4::splat(0);
            let mut t1 = u64x4::splat(0);
            let r19 = u64x4::splat(19);

            t0 = madd52hi(t0, r19, z9_1);
            t1 = madd52lo(t1, r19, z9_1.shr::<52>());

            z4_2 = madd52lo(z4_2, r19, z8_1.shr::<52>());
            z3_2 = madd52lo(z3_2, r19, z7_1.shr::<52>());
            z2_2 = madd52lo(z2_2, r19, z6_1.shr::<52>());
            z1_2 = madd52lo(z1_2, r19, z5_1.shr::<52>());

            z0_2 = madd52lo(z0_2, r19, t0 + t1);
            z1_2 = madd52hi(z1_2, r19, z5_1);
            z2_2 = madd52hi(z2_2, r19, z6_1);
            z3_2 = madd52hi(z3_2, r19, z7_1);
            z4_2 = madd52hi(z4_2, r19, z8_1);

            z0_1 = madd52lo(z0_1, r19, z5_1);
            z1_1 = madd52lo(z1_1, r19, z6_1);
            z2_1 = madd52lo(z2_1, r19, z7_1);
            z3_1 = madd52lo(z3_1, r19, z8_1);
            z4_1 = madd52lo(z4_1, r19, z9_1);

            F51x4Unreduced([
                z0_1 + z0_2 + z0_2,
                z1_1 + z1_2 + z1_2,
                z2_1 + z2_2 + z2_2,
                z3_1 + z3_2 + z3_2,
                z4_1 + z4_2 + z4_2,
            ])
        }
    }
}

#[unsafe_target_feature("avx512ifma,avx512vl")]
impl From<F51x4Reduced> for F51x4Unreduced {
    #[inline]
    fn from(x: F51x4Reduced) -> F51x4Unreduced {
        F51x4Unreduced(x.0)
    }
}

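// Reduction is a single parallel carry pass: the bits of each limb above
// 2^51 are carried into the next limb, and the top carry c4 wraps around to
// the lowest limb multiplied by 19, since 2^255 == 19 (mod p). The
// wrap-around 19 * c4 + (x.0[0] & mask) is computed with one IFMA madd52lo.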
#[unsafe_target_feature("avx512ifma,avx512vl")]
impl From<F51x4Unreduced> for F51x4Reduced {
    #[inline]
    fn from(x: F51x4Unreduced) -> F51x4Reduced {
        let mask = u64x4::splat((1 << 51) - 1);
        let r19 = u64x4::splat(19);

        // Compute carryouts in parallel
        let c0 = x.0[0].shr::<51>();
        let c1 = x.0[1].shr::<51>();
        let c2 = x.0[2].shr::<51>();
        let c3 = x.0[3].shr::<51>();
        let c4 = x.0[4].shr::<51>();

        unsafe {
            F51x4Reduced([
                madd52lo(x.0[0] & mask, c4, r19),
                (x.0[1] & mask) + c0,
                (x.0[2] & mask) + c1,
                (x.0[3] & mask) + c2,
                (x.0[4] & mask) + c3,
            ])
        }
    }
}

#[unsafe_target_feature("avx512ifma,avx512vl")]
impl Add<F51x4Unreduced> for F51x4Unreduced {
    type Output = F51x4Unreduced;
    #[inline]
    fn add(self, rhs: F51x4Unreduced) -> F51x4Unreduced {
        F51x4Unreduced([
            self.0[0] + rhs.0[0],
            self.0[1] + rhs.0[1],
            self.0[2] + rhs.0[2],
            self.0[3] + rhs.0[3],
            self.0[4] + rhs.0[4],
        ])
    }
}

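// Multiplication by four u32 scalars, one per lane. Each scalar is below
// 2^32, so every product y * x[i] fits in 84 bits: only the high half of the
// top product (z5_2) has to be folded back into limb 0 via 2^255 == 19
// (mod p); the other high halves go straight into coeff-2 accumulators.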
#[unsafe_target_feature("avx512ifma,avx512vl")]
impl<'a> Mul<(u32, u32, u32, u32)> for &'a F51x4Reduced {
    type Output = F51x4Unreduced;
    #[inline]
    fn mul(self, scalars: (u32, u32, u32, u32)) -> F51x4Unreduced {
        unsafe {
            let x = &self.0;
            let y = u64x4::new(
                scalars.0 as u64,
                scalars.1 as u64,
                scalars.2 as u64,
                scalars.3 as u64,
            );
            let r19 = u64x4::splat(19);

            let mut z0_1 = u64x4::splat(0);
            let mut z1_1 = u64x4::splat(0);
            let mut z2_1 = u64x4::splat(0);
            let mut z3_1 = u64x4::splat(0);
            let mut z4_1 = u64x4::splat(0);
            let mut z1_2 = u64x4::splat(0);
            let mut z2_2 = u64x4::splat(0);
            let mut z3_2 = u64x4::splat(0);
            let mut z4_2 = u64x4::splat(0);
            let mut z5_2 = u64x4::splat(0);

            // Wave 0
            z4_2 = madd52hi(z4_2, y, x[3]);
            z5_2 = madd52hi(z5_2, y, x[4]);
            z4_1 = madd52lo(z4_1, y, x[4]);
            z0_1 = madd52lo(z0_1, y, x[0]);
            z3_1 = madd52lo(z3_1, y, x[3]);
            z2_1 = madd52lo(z2_1, y, x[2]);
            z1_1 = madd52lo(z1_1, y, x[1]);
            z3_2 = madd52hi(z3_2, y, x[2]);

            // Wave 1
            z2_2 = madd52hi(z2_2, y, x[1]);
            z1_2 = madd52hi(z1_2, y, x[0]);
            z0_1 = madd52lo(z0_1, z5_2 + z5_2, r19);

            F51x4Unreduced([
                z0_1,
                z1_1 + z1_2 + z1_2,
                z2_1 + z2_2 + z2_2,
                z3_1 + z3_2 + z3_2,
                z4_1 + z4_2 + z4_2,
            ])
        }
    }
}

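// Schoolbook 5x5-limb multiplication: product limb z_k collects every
// x[i] * y[j] with i + j == k, split by IFMA into low and high halves; a
// high half counts double (2^52 == 2 * 2^51) and so lands in a coeff-2
// accumulator. Limbs z5..z9 overflow the five-limb result and are folded
// back into z0..z4 using 2^255 == 19 (mod p), as in `square`.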
#[unsafe_target_feature("avx512ifma,avx512vl")]
impl<'a, 'b> Mul<&'b F51x4Reduced> for &'a F51x4Reduced {
    type Output = F51x4Unreduced;
    #[inline]
    fn mul(self, rhs: &'b F51x4Reduced) -> F51x4Unreduced {
        unsafe {
            // Inputs
            let x = &self.0;
            let y = &rhs.0;

            // Accumulators for terms with coeff 1
            let mut z0_1 = u64x4::splat(0);
            let mut z1_1 = u64x4::splat(0);
            let mut z2_1 = u64x4::splat(0);
            let mut z3_1 = u64x4::splat(0);
            let mut z4_1 = u64x4::splat(0);
            let mut z5_1 = u64x4::splat(0);
            let mut z6_1 = u64x4::splat(0);
            let mut z7_1 = u64x4::splat(0);
            let mut z8_1 = u64x4::splat(0);

            // Accumulators for terms with coeff 2
            let mut z0_2 = u64x4::splat(0);
            let mut z1_2 = u64x4::splat(0);
            let mut z2_2 = u64x4::splat(0);
            let mut z3_2 = u64x4::splat(0);
            let mut z4_2 = u64x4::splat(0);
            let mut z5_2 = u64x4::splat(0);
            let mut z6_2 = u64x4::splat(0);
            let mut z7_2 = u64x4::splat(0);
            let mut z8_2 = u64x4::splat(0);
            let mut z9_2 = u64x4::splat(0);

            // LLVM doesn't seem to do much work reordering IFMA
            // instructions, so try to organize them into "waves" of 8
            // independent operations (4c latency, 0.5c throughput
            // means 8 in flight)

            // Wave 0
            z4_1 = madd52lo(z4_1, x[2], y[2]);
            z5_2 = madd52hi(z5_2, x[2], y[2]);
            z5_1 = madd52lo(z5_1, x[4], y[1]);
            z6_2 = madd52hi(z6_2, x[4], y[1]);
            z6_1 = madd52lo(z6_1, x[4], y[2]);
            z7_2 = madd52hi(z7_2, x[4], y[2]);
            z7_1 = madd52lo(z7_1, x[4], y[3]);
            z8_2 = madd52hi(z8_2, x[4], y[3]);

            // Wave 1
            z4_1 = madd52lo(z4_1, x[3], y[1]);
            z5_2 = madd52hi(z5_2, x[3], y[1]);
            z5_1 = madd52lo(z5_1, x[3], y[2]);
            z6_2 = madd52hi(z6_2, x[3], y[2]);
            z6_1 = madd52lo(z6_1, x[3], y[3]);
            z7_2 = madd52hi(z7_2, x[3], y[3]);
            z7_1 = madd52lo(z7_1, x[3], y[4]);
            z8_2 = madd52hi(z8_2, x[3], y[4]);

            // Wave 2
            z8_1 = madd52lo(z8_1, x[4], y[4]);
            z9_2 = madd52hi(z9_2, x[4], y[4]);
            z4_1 = madd52lo(z4_1, x[4], y[0]);
            z5_2 = madd52hi(z5_2, x[4], y[0]);
            z5_1 = madd52lo(z5_1, x[2], y[3]);
            z6_2 = madd52hi(z6_2, x[2], y[3]);
            z6_1 = madd52lo(z6_1, x[2], y[4]);
            z7_2 = madd52hi(z7_2, x[2], y[4]);

            let z8 = z8_1 + z8_2 + z8_2;
            let z9 = z9_2 + z9_2;

            // Wave 3
            z3_1 = madd52lo(z3_1, x[3], y[0]);
            z4_2 = madd52hi(z4_2, x[3], y[0]);
            z4_1 = madd52lo(z4_1, x[1], y[3]);
            z5_2 = madd52hi(z5_2, x[1], y[3]);
            z5_1 = madd52lo(z5_1, x[1], y[4]);
            z6_2 = madd52hi(z6_2, x[1], y[4]);
            z2_1 = madd52lo(z2_1, x[2], y[0]);
            z3_2 = madd52hi(z3_2, x[2], y[0]);

            let z6 = z6_1 + z6_2 + z6_2;
            let z7 = z7_1 + z7_2 + z7_2;

            // Wave 4
            z3_1 = madd52lo(z3_1, x[2], y[1]);
            z4_2 = madd52hi(z4_2, x[2], y[1]);
            z4_1 = madd52lo(z4_1, x[0], y[4]);
            z5_2 = madd52hi(z5_2, x[0], y[4]);
            z1_1 = madd52lo(z1_1, x[1], y[0]);
            z2_2 = madd52hi(z2_2, x[1], y[0]);
            z2_1 = madd52lo(z2_1, x[1], y[1]);
            z3_2 = madd52hi(z3_2, x[1], y[1]);

            let z5 = z5_1 + z5_2 + z5_2;

            // Wave 5
            z3_1 = madd52lo(z3_1, x[1], y[2]);
            z4_2 = madd52hi(z4_2, x[1], y[2]);
            z0_1 = madd52lo(z0_1, x[0], y[0]);
            z1_2 = madd52hi(z1_2, x[0], y[0]);
            z1_1 = madd52lo(z1_1, x[0], y[1]);
            z2_1 = madd52lo(z2_1, x[0], y[2]);
            z2_2 = madd52hi(z2_2, x[0], y[1]);
            z3_2 = madd52hi(z3_2, x[0], y[2]);

            let mut t0 = u64x4::splat(0);
            let mut t1 = u64x4::splat(0);
            let r19 = u64x4::splat(19);

            // Wave 6
            t0 = madd52hi(t0, r19, z9);
            t1 = madd52lo(t1, r19, z9.shr::<52>());
            z3_1 = madd52lo(z3_1, x[0], y[3]);
            z4_2 = madd52hi(z4_2, x[0], y[3]);
            z1_2 = madd52lo(z1_2, r19, z5.shr::<52>());
            z2_2 = madd52lo(z2_2, r19, z6.shr::<52>());
            z3_2 = madd52lo(z3_2, r19, z7.shr::<52>());
            z0_1 = madd52lo(z0_1, r19, z5);

            // Wave 7
            z4_1 = madd52lo(z4_1, r19, z9);
            z1_1 = madd52lo(z1_1, r19, z6);
            z0_2 = madd52lo(z0_2, r19, t0 + t1);
            z4_2 = madd52hi(z4_2, r19, z8);
            z2_1 = madd52lo(z2_1, r19, z7);
            z1_2 = madd52hi(z1_2, r19, z5);
            z2_2 = madd52hi(z2_2, r19, z6);
            z3_2 = madd52hi(z3_2, r19, z7);

            // Wave 8
            z3_1 = madd52lo(z3_1, r19, z8);
            z4_2 = madd52lo(z4_2, r19, z8.shr::<52>());

            F51x4Unreduced([
                z0_1 + z0_2 + z0_2,
                z1_1 + z1_2 + z1_2,
                z2_1 + z2_2 + z2_2,
                z3_1 + z3_2 + z3_2,
                z4_1 + z4_2 + z4_2,
            ])
        }
    }
}

#[cfg(all(target_feature = "avx512ifma", target_feature = "avx512vl"))]
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn vpmadd52luq() {
        let x = u64x4::splat(2);
        let y = u64x4::splat(3);
        let mut z = u64x4::splat(5);

        z = unsafe { madd52lo(z, x, y) };

        assert_eq!(z, u64x4::splat(5 + 2 * 3));
    }

    #[test]
    fn new_split_round_trip_on_reduced_input() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();

        let ax4 = F51x4Unreduced::new(&a, &a, &a, &a);
        let splits = ax4.split();

        for i in 0..4 {
            assert_eq!(a, splits[i]);
        }
    }

    #[test]
    fn new_split_round_trip_on_unreduced_input() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        // ... but now multiply it by 16 without reducing coeffs
        let a16 = FieldElement51([
            a.0[0] << 4,
            a.0[1] << 4,
            a.0[2] << 4,
            a.0[3] << 4,
            a.0[4] << 4,
        ]);

        let a16x4 = F51x4Unreduced::new(&a16, &a16, &a16, &a16);
        let splits = a16x4.split();

        for i in 0..4 {
            assert_eq!(a16, splits[i]);
        }
    }

    #[test]
    fn test_reduction() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        // ... but now multiply it by 16 without reducing coeffs
        let abig = FieldElement51([
            a.0[0] << 4,
            a.0[1] << 4,
            a.0[2] << 4,
            a.0[3] << 4,
            a.0[4] << 4,
        ]);

        let abigx4: F51x4Reduced = F51x4Unreduced::new(&abig, &abig, &abig, &abig).into();

        let splits = F51x4Unreduced::from(abigx4).split();
        let c = &a * &FieldElement51([(1 << 4), 0, 0, 0, 0]);

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn mul_matches_serial() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let b = FieldElement51([98098, 87987897, 0, 1, 0]).invert();
        let c = &a * &b;

        let ax4: F51x4Reduced = F51x4Unreduced::new(&a, &a, &a, &a).into();
        let bx4: F51x4Reduced = F51x4Unreduced::new(&b, &b, &b, &b).into();
        let cx4 = &ax4 * &bx4;

        let splits = cx4.split();

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn iterated_mul_matches_serial() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let b = FieldElement51([98098, 87987897, 0, 1, 0]).invert();
        let mut c = &a * &b;
        for _i in 0..1024 {
            c = &a * &c;
            c = &b * &c;
        }

        let ax4: F51x4Reduced = F51x4Unreduced::new(&a, &a, &a, &a).into();
        let bx4: F51x4Reduced = F51x4Unreduced::new(&b, &b, &b, &b).into();
        let mut cx4 = &ax4 * &bx4;
        for _i in 0..1024 {
            cx4 = &ax4 * &F51x4Reduced::from(cx4);
            cx4 = &bx4 * &F51x4Reduced::from(cx4);
        }

        let splits = cx4.split();

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

    #[test]
    fn square_matches_mul() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();

        let ax4: F51x4Reduced = F51x4Unreduced::new(&a, &a, &a, &a).into();
        let cx4 = &ax4 * &ax4;
        let cx4_sq = ax4.square();

        let splits = cx4.split();
        let splits_sq = cx4_sq.split();

        for i in 0..4 {
            assert_eq!(splits_sq[i], splits[i]);
        }
    }

    #[test]
    fn iterated_square_matches_serial() {
        // Invert a small field element to get a big one
        let mut a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let mut ax4 = F51x4Unreduced::new(&a, &a, &a, &a);
        for _j in 0..1024 {
            a = a.square();
            ax4 = F51x4Reduced::from(ax4).square();

            let splits = ax4.split();
            for i in 0..4 {
                assert_eq!(a, splits[i]);
            }
        }
    }

    #[test]
    fn iterated_u32_mul_matches_serial() {
        // Invert a small field element to get a big one
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();
        let b = FieldElement51([121665, 0, 0, 0, 0]);
        let mut c = &a * &b;
        for _i in 0..1024 {
            c = &b * &c;
        }

        let ax4 = F51x4Unreduced::new(&a, &a, &a, &a);
        let bx4 = (121665u32, 121665u32, 121665u32, 121665u32);
        let mut cx4 = &F51x4Reduced::from(ax4) * bx4;
        for _i in 0..1024 {
            cx4 = &F51x4Reduced::from(cx4) * bx4;
        }

        let splits = cx4.split();

        for i in 0..4 {
            assert_eq!(c, splits[i]);
        }
    }

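    // A sketch exercising diff_sum against the serial backend: diff_sum
    // maps lanes (A, B, C, D) to (B - A, B + A, D - C, D + C), so each lane
    // can be checked with FieldElement51's own Add/Sub impls (assumed
    // available here, as elsewhere in this crate).
    #[test]
    fn diff_sum_matches_serial() {
        let x0 = FieldElement51::from_bytes(&[0x10; 32]);
        let x1 = FieldElement51::from_bytes(&[0x11; 32]);
        let x2 = FieldElement51::from_bytes(&[0x12; 32]);
        let x3 = FieldElement51::from_bytes(&[0x13; 32]);

        let x = F51x4Unreduced::new(&x0, &x1, &x2, &x3);
        let splits = x.diff_sum().split();

        assert_eq!(splits[0], &x1 - &x0);
        assert_eq!(splits[1], &x1 + &x0);
        assert_eq!(splits[2], &x3 - &x2);
        assert_eq!(splits[3], &x3 + &x2);
    }
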
    #[test]
    fn shuffle_AAAA() {
        let x0 = FieldElement51::from_bytes(&[0x10; 32]);
        let x1 = FieldElement51::from_bytes(&[0x11; 32]);
        let x2 = FieldElement51::from_bytes(&[0x12; 32]);
        let x3 = FieldElement51::from_bytes(&[0x13; 32]);

        let x = F51x4Unreduced::new(&x0, &x1, &x2, &x3);

        let y = x.shuffle(Shuffle::AAAA);
        let splits = y.split();

        assert_eq!(splits[0], x0);
        assert_eq!(splits[1], x0);
        assert_eq!(splits[2], x0);
        assert_eq!(splits[3], x0);
    }

    #[test]
    fn blend_AB() {
        let x0 = FieldElement51::from_bytes(&[0x10; 32]);
        let x1 = FieldElement51::from_bytes(&[0x11; 32]);
        let x2 = FieldElement51::from_bytes(&[0x12; 32]);
        let x3 = FieldElement51::from_bytes(&[0x13; 32]);

        let x = F51x4Unreduced::new(&x0, &x1, &x2, &x3);
        let z = F51x4Unreduced::new(&x3, &x2, &x1, &x0);

        let y = x.blend(&z, Lanes::AB);
        let splits = y.split();

        assert_eq!(splits[0], x3);
        assert_eq!(splits[1], x2);
        assert_eq!(splits[2], x2);
        assert_eq!(splits[3], x3);
    }
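
    // A sketch checking negate_lazy: it returns 16*p - self limb-wise, so
    // adding it back to self should encode the zero field element in every
    // lane.
    #[test]
    fn negate_lazy_gives_additive_inverse() {
        let zero = FieldElement51([0, 0, 0, 0, 0]);
        let a = FieldElement51([2438, 24, 243, 0, 0]).invert();

        let ax4 = F51x4Unreduced::new(&a, &a, &a, &a);
        let splits = (ax4 + ax4.negate_lazy()).split();

        for i in 0..4 {
            assert_eq!(zero, splits[i]);
        }
    }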
}