curve25519_dalek/backend/vector/scalar_mul/straus.rs

1// -*- mode: rust; -*-
2//
3// This file is part of curve25519-dalek.
4// Copyright (c) 2016-2021 isis lovecruft
5// Copyright (c) 2016-2019 Henry de Valence
6// See LICENSE for licensing information.
7//
8// Authors:
9// - isis agora lovecruft <isis@patternsinthevoid.net>
10// - Henry de Valence <hdevalence@hdevalence.ca>
11
12#![allow(non_snake_case)]
13
14#[curve25519_dalek_derive::unsafe_target_feature_specialize(
15    "avx2",
16    conditional("avx512ifma,avx512vl", nightly)
17)]
18pub mod spec {
19
20    use alloc::vec::Vec;
21
22    use core::borrow::Borrow;
23    use core::cmp::Ordering;
24
25    #[cfg(feature = "zeroize")]
26    use zeroize::Zeroizing;
27
28    #[for_target_feature("avx2")]
29    use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
30
31    #[for_target_feature("avx512ifma")]
32    use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
33
34    use crate::edwards::EdwardsPoint;
35    use crate::scalar::Scalar;
36    use crate::traits::{Identity, MultiscalarMul, VartimeMultiscalarMul};
37    use crate::window::{LookupTable, NafLookupTable5};
38
39    /// Multiscalar multiplication using interleaved window / Straus'
40    /// method.  See the `Straus` struct in the serial backend for more
41    /// details.
42    ///
43    /// This exists as a seperate implementation from that one because the
44    /// AVX2 code uses different curve models (it does not pass between
45    /// multiple models during scalar mul), and it has to convert the
46    /// point representation on the fly.
47    pub struct Straus {}
48
49    impl MultiscalarMul for Straus {
50        type Point = EdwardsPoint;
51
52        fn multiscalar_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
53        where
54            I: IntoIterator,
55            I::Item: Borrow<Scalar>,
56            J: IntoIterator,
57            J::Item: Borrow<EdwardsPoint>,
58        {
59            // Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P]
60            // for each input point P
61            let lookup_tables: Vec<_> = points
62                .into_iter()
63                .map(|point| LookupTable::<CachedPoint>::from(point.borrow()))
64                .collect();
65
66            let scalar_digits_vec: Vec<_> = scalars
67                .into_iter()
68                .map(|s| s.borrow().as_radix_16())
69                .collect();
70            // Pass ownership to a `Zeroizing` wrapper
71            #[cfg(feature = "zeroize")]
72            let scalar_digits_vec = Zeroizing::new(scalar_digits_vec);
73
74            let mut Q = ExtendedPoint::identity();
75            for j in (0..64).rev() {
76                Q = Q.mul_by_pow_2(4);
77                let it = scalar_digits_vec.iter().zip(lookup_tables.iter());
78                for (s_i, lookup_table_i) in it {
79                    // Q = Q + s_{i,j} * P_i
80                    Q = &Q + &lookup_table_i.select(s_i[j]);
81                }
82            }
83            Q.into()
84        }
85    }
86
87    impl VartimeMultiscalarMul for Straus {
88        type Point = EdwardsPoint;
89
90        fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
91        where
92            I: IntoIterator,
93            I::Item: Borrow<Scalar>,
94            J: IntoIterator<Item = Option<EdwardsPoint>>,
95        {
96            let nafs: Vec<_> = scalars
97                .into_iter()
98                .map(|c| c.borrow().non_adjacent_form(5))
99                .collect();
100            let lookup_tables: Vec<_> = points
101                .into_iter()
102                .map(|P_opt| P_opt.map(|P| NafLookupTable5::<CachedPoint>::from(&P)))
103                .collect::<Option<Vec<_>>>()?;
104
105            let mut Q = ExtendedPoint::identity();
106
107            for i in (0..256).rev() {
108                Q = Q.double();
109
110                for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) {
111                    match naf[i].cmp(&0) {
112                        Ordering::Greater => {
113                            Q = &Q + &lookup_table.select(naf[i] as usize);
114                        }
115                        Ordering::Less => {
116                            Q = &Q - &lookup_table.select(-naf[i] as usize);
117                        }
118                        Ordering::Equal => {}
119                    }
120                }
121            }
122
123            Some(Q.into())
124        }
125    }
126}