// curve25519_dalek/backend/vector/scalar_mul/straus.rs
#![allow(non_snake_case)]

14#[curve25519_dalek_derive::unsafe_target_feature_specialize(
15 "avx2",
16 conditional("avx512ifma,avx512vl", nightly)
17)]
18pub mod spec {
19
20 use alloc::vec::Vec;
21
22 use core::borrow::Borrow;
23 use core::cmp::Ordering;
24
25 #[cfg(feature = "zeroize")]
26 use zeroize::Zeroizing;
27
28 #[for_target_feature("avx2")]
29 use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
30
31 #[for_target_feature("avx512ifma")]
32 use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
33
34 use crate::edwards::EdwardsPoint;
35 use crate::scalar::Scalar;
36 use crate::traits::{Identity, MultiscalarMul, VartimeMultiscalarMul};
37 use crate::window::{LookupTable, NafLookupTable5};
38
39 pub struct Straus {}
48
49 impl MultiscalarMul for Straus {
50 type Point = EdwardsPoint;
51
52 fn multiscalar_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
53 where
54 I: IntoIterator,
55 I::Item: Borrow<Scalar>,
56 J: IntoIterator,
57 J::Item: Borrow<EdwardsPoint>,
58 {
59 let lookup_tables: Vec<_> = points
62 .into_iter()
63 .map(|point| LookupTable::<CachedPoint>::from(point.borrow()))
64 .collect();
65
66 let scalar_digits_vec: Vec<_> = scalars
67 .into_iter()
68 .map(|s| s.borrow().as_radix_16())
69 .collect();
70 #[cfg(feature = "zeroize")]
72 let scalar_digits_vec = Zeroizing::new(scalar_digits_vec);
73
74 let mut Q = ExtendedPoint::identity();
75 for j in (0..64).rev() {
76 Q = Q.mul_by_pow_2(4);
77 let it = scalar_digits_vec.iter().zip(lookup_tables.iter());
78 for (s_i, lookup_table_i) in it {
79 Q = &Q + &lookup_table_i.select(s_i[j]);
81 }
82 }
83 Q.into()
84 }
85 }
86
87 impl VartimeMultiscalarMul for Straus {
88 type Point = EdwardsPoint;
89
90 fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
91 where
92 I: IntoIterator,
93 I::Item: Borrow<Scalar>,
94 J: IntoIterator<Item = Option<EdwardsPoint>>,
95 {
96 let nafs: Vec<_> = scalars
97 .into_iter()
98 .map(|c| c.borrow().non_adjacent_form(5))
99 .collect();
100 let lookup_tables: Vec<_> = points
101 .into_iter()
102 .map(|P_opt| P_opt.map(|P| NafLookupTable5::<CachedPoint>::from(&P)))
103 .collect::<Option<Vec<_>>>()?;
104
105 let mut Q = ExtendedPoint::identity();
106
107 for i in (0..256).rev() {
108 Q = Q.double();
109
110 for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) {
111 match naf[i].cmp(&0) {
112 Ordering::Greater => {
113 Q = &Q + &lookup_table.select(naf[i] as usize);
114 }
115 Ordering::Less => {
116 Q = &Q - &lookup_table.select(-naf[i] as usize);
117 }
118 Ordering::Equal => {}
119 }
120 }
121 }
122
123 Some(Q.into())
124 }
125 }
126}