curve25519_dalek/backend/vector/scalar_mul/
pippenger.rs1#![allow(non_snake_case)]
11
12#[curve25519_dalek_derive::unsafe_target_feature_specialize(
13 "avx2",
14 conditional("avx512ifma,avx512vl", nightly)
15)]
16pub mod spec {
17
18 use alloc::vec::Vec;
19
20 use core::borrow::Borrow;
21 use core::cmp::Ordering;
22
23 #[for_target_feature("avx2")]
24 use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
25
26 #[for_target_feature("avx512ifma")]
27 use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
28
29 use crate::edwards::EdwardsPoint;
30 use crate::scalar::Scalar;
31 use crate::traits::{Identity, VartimeMultiscalarMul};
32
33 pub struct Pippenger;
37
38 impl VartimeMultiscalarMul for Pippenger {
39 type Point = EdwardsPoint;
40
41 fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
42 where
43 I: IntoIterator,
44 I::Item: Borrow<Scalar>,
45 J: IntoIterator<Item = Option<EdwardsPoint>>,
46 {
47 let mut scalars = scalars.into_iter();
48 let size = scalars.by_ref().size_hint().0;
49 let w = if size < 500 {
50 6
51 } else if size < 800 {
52 7
53 } else {
54 8
55 };
56
57 let max_digit: usize = 1 << w;
58 let digits_count: usize = Scalar::to_radix_2w_size_hint(w);
59 let buckets_count: usize = max_digit / 2; let scalars = scalars.map(|s| s.borrow().as_radix_2w(w));
64
65 let points = points
66 .into_iter()
67 .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P))));
68
69 let scalars_points = scalars
70 .zip(points)
71 .map(|(s, maybe_p)| maybe_p.map(|p| (s, p)))
72 .collect::<Option<Vec<_>>>()?;
73
74 let mut buckets: Vec<ExtendedPoint> = (0..buckets_count)
77 .map(|_| ExtendedPoint::identity())
78 .collect();
79
80 let mut columns = (0..digits_count).rev().map(|digit_index| {
81 for bucket in &mut buckets {
83 *bucket = ExtendedPoint::identity();
84 }
85
86 for (digits, pt) in scalars_points.iter() {
91 let digit = digits[digit_index] as i16;
93 match digit.cmp(&0) {
94 Ordering::Greater => {
95 let b = (digit - 1) as usize;
96 buckets[b] = &buckets[b] + pt;
97 }
98 Ordering::Less => {
99 let b = (-digit - 1) as usize;
100 buckets[b] = &buckets[b] - pt;
101 }
102 Ordering::Equal => {}
103 }
104 }
105
106 let mut buckets_intermediate_sum = buckets[buckets_count - 1];
115 let mut buckets_sum = buckets[buckets_count - 1];
116 for i in (0..(buckets_count - 1)).rev() {
117 buckets_intermediate_sum =
118 &buckets_intermediate_sum + &CachedPoint::from(buckets[i]);
119 buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum);
120 }
121
122 buckets_sum
123 });
124
125 let hi_column = columns.next().expect("should have more than zero digits");
127
128 Some(
129 columns
130 .fold(hi_column, |total, p| {
131 &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p)
132 })
133 .into(),
134 )
135 }
136 }
137
138 #[cfg(test)]
139 mod test {
140 #[test]
141 fn test_vartime_pippenger() {
142 use super::*;
143 use crate::constants;
144 use crate::scalar::Scalar;
145
146 let mut n = 512;
148 let x = Scalar::from(2128506u64).invert();
149 let y = Scalar::from(4443282u64).invert();
150 let points: Vec<_> = (0..n)
151 .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64))
152 .collect();
153 let scalars: Vec<_> = (0..n)
154 .map(|i| x + (Scalar::from(i as u64) * y)) .collect();
156
157 let premultiplied: Vec<EdwardsPoint> = scalars
158 .iter()
159 .zip(points.iter())
160 .map(|(sc, pt)| sc * pt)
161 .collect();
162
163 while n > 0 {
164 let scalars = &scalars[0..n].to_vec();
165 let points = &points[0..n].to_vec();
166 let control: EdwardsPoint = premultiplied[0..n].iter().sum();
167
168 let subject = Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone());
169
170 assert_eq!(subject.compress(), control.compress());
171
172 n = n / 2;
173 }
174 }
175 }
176}