curve25519_dalek/backend/vector/scalar_mul/
pippenger.rs

1// -*- mode: rust; -*-
2//
3// This file is part of curve25519-dalek.
4// Copyright (c) 2019 Oleg Andreev
5// See LICENSE for licensing information.
6//
7// Authors:
8// - Oleg Andreev <oleganza@gmail.com>
9
10#![allow(non_snake_case)]
11
12#[curve25519_dalek_derive::unsafe_target_feature_specialize(
13    "avx2",
14    conditional("avx512ifma,avx512vl", nightly)
15)]
16pub mod spec {
17
18    use alloc::vec::Vec;
19
20    use core::borrow::Borrow;
21    use core::cmp::Ordering;
22
23    #[for_target_feature("avx2")]
24    use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
25
26    #[for_target_feature("avx512ifma")]
27    use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
28
29    use crate::edwards::EdwardsPoint;
30    use crate::scalar::Scalar;
31    use crate::traits::{Identity, VartimeMultiscalarMul};
32
33    /// Implements a version of Pippenger's algorithm.
34    ///
35    /// See the documentation in the serial `scalar_mul::pippenger` module for details.
36    pub struct Pippenger;
37
38    impl VartimeMultiscalarMul for Pippenger {
39        type Point = EdwardsPoint;
40
41        fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
42        where
43            I: IntoIterator,
44            I::Item: Borrow<Scalar>,
45            J: IntoIterator<Item = Option<EdwardsPoint>>,
46        {
47            let mut scalars = scalars.into_iter();
48            let size = scalars.by_ref().size_hint().0;
49            let w = if size < 500 {
50                6
51            } else if size < 800 {
52                7
53            } else {
54                8
55            };
56
57            let max_digit: usize = 1 << w;
58            let digits_count: usize = Scalar::to_radix_2w_size_hint(w);
59            let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket
60
61            // Collect optimized scalars and points in a buffer for repeated access
62            // (scanning the whole collection per each digit position).
63            let scalars = scalars.map(|s| s.borrow().as_radix_2w(w));
64
65            let points = points
66                .into_iter()
67                .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P))));
68
69            let scalars_points = scalars
70                .zip(points)
71                .map(|(s, maybe_p)| maybe_p.map(|p| (s, p)))
72                .collect::<Option<Vec<_>>>()?;
73
74            // Prepare 2^w/2 buckets.
75            // buckets[i] corresponds to a multiplication factor (i+1).
76            let mut buckets: Vec<ExtendedPoint> = (0..buckets_count)
77                .map(|_| ExtendedPoint::identity())
78                .collect();
79
80            let mut columns = (0..digits_count).rev().map(|digit_index| {
81                // Clear the buckets when processing another digit.
82                for bucket in &mut buckets {
83                    *bucket = ExtendedPoint::identity();
84                }
85
86                // Iterate over pairs of (point, scalar)
87                // and add/sub the point to the corresponding bucket.
88                // Note: if we add support for precomputed lookup tables,
89                // we'll be adding/subtractiong point premultiplied by `digits[i]` to buckets[0].
90                for (digits, pt) in scalars_points.iter() {
91                    // Widen digit so that we don't run into edge cases when w=8.
92                    let digit = digits[digit_index] as i16;
93                    match digit.cmp(&0) {
94                        Ordering::Greater => {
95                            let b = (digit - 1) as usize;
96                            buckets[b] = &buckets[b] + pt;
97                        }
98                        Ordering::Less => {
99                            let b = (-digit - 1) as usize;
100                            buckets[b] = &buckets[b] - pt;
101                        }
102                        Ordering::Equal => {}
103                    }
104                }
105
106                // Add the buckets applying the multiplication factor to each bucket.
107                // The most efficient way to do that is to have a single sum with two running sums:
108                // an intermediate sum from last bucket to the first, and a sum of intermediate sums.
109                //
110                // For example, to add buckets 1*A, 2*B, 3*C we need to add these points:
111                //   C
112                //   C B
113                //   C B A   Sum = C + (C+B) + (C+B+A)
114                let mut buckets_intermediate_sum = buckets[buckets_count - 1];
115                let mut buckets_sum = buckets[buckets_count - 1];
116                for i in (0..(buckets_count - 1)).rev() {
117                    buckets_intermediate_sum =
118                        &buckets_intermediate_sum + &CachedPoint::from(buckets[i]);
119                    buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum);
120                }
121
122                buckets_sum
123            });
124
125            // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`.
126            let hi_column = columns.next().expect("should have more than zero digits");
127
128            Some(
129                columns
130                    .fold(hi_column, |total, p| {
131                        &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p)
132                    })
133                    .into(),
134            )
135        }
136    }
137
138    #[cfg(test)]
139    mod test {
140        #[test]
141        fn test_vartime_pippenger() {
142            use super::*;
143            use crate::constants;
144            use crate::scalar::Scalar;
145
146            // Reuse points across different tests
147            let mut n = 512;
148            let x = Scalar::from(2128506u64).invert();
149            let y = Scalar::from(4443282u64).invert();
150            let points: Vec<_> = (0..n)
151                .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64))
152                .collect();
153            let scalars: Vec<_> = (0..n)
154                .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars
155                .collect();
156
157            let premultiplied: Vec<EdwardsPoint> = scalars
158                .iter()
159                .zip(points.iter())
160                .map(|(sc, pt)| sc * pt)
161                .collect();
162
163            while n > 0 {
164                let scalars = &scalars[0..n].to_vec();
165                let points = &points[0..n].to_vec();
166                let control: EdwardsPoint = premultiplied[0..n].iter().sum();
167
168                let subject = Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone());
169
170                assert_eq!(subject.compress(), control.compress());
171
172                n = n / 2;
173            }
174        }
175    }
176}