solana_curve25519/
edwards.rs

1use bytemuck_derive::{Pod, Zeroable};
2pub use target_arch::*;
3
4#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Pod, Zeroable)]
5#[repr(transparent)]
6pub struct PodEdwardsPoint(pub [u8; 32]);
7
8#[cfg(not(target_os = "solana"))]
9mod target_arch {
10    use {
11        super::*,
12        crate::{
13            curve_syscall_traits::{GroupOperations, MultiScalarMultiplication, PointValidation},
14            errors::Curve25519Error,
15            scalar::PodScalar,
16        },
17        curve25519_dalek::{
18            edwards::{CompressedEdwardsY, EdwardsPoint},
19            scalar::Scalar,
20            traits::VartimeMultiscalarMul,
21        },
22    };
23
24    pub fn validate_edwards(point: &PodEdwardsPoint) -> bool {
25        point.validate_point()
26    }
27
28    pub fn add_edwards(
29        left_point: &PodEdwardsPoint,
30        right_point: &PodEdwardsPoint,
31    ) -> Option<PodEdwardsPoint> {
32        PodEdwardsPoint::add(left_point, right_point)
33    }
34
35    pub fn subtract_edwards(
36        left_point: &PodEdwardsPoint,
37        right_point: &PodEdwardsPoint,
38    ) -> Option<PodEdwardsPoint> {
39        PodEdwardsPoint::subtract(left_point, right_point)
40    }
41
42    pub fn multiply_edwards(
43        scalar: &PodScalar,
44        point: &PodEdwardsPoint,
45    ) -> Option<PodEdwardsPoint> {
46        PodEdwardsPoint::multiply(scalar, point)
47    }
48
49    pub fn multiscalar_multiply_edwards(
50        scalars: &[PodScalar],
51        points: &[PodEdwardsPoint],
52    ) -> Option<PodEdwardsPoint> {
53        PodEdwardsPoint::multiscalar_multiply(scalars, points)
54    }
55
56    impl From<&EdwardsPoint> for PodEdwardsPoint {
57        fn from(point: &EdwardsPoint) -> Self {
58            Self(point.compress().to_bytes())
59        }
60    }
61
62    impl TryFrom<&PodEdwardsPoint> for EdwardsPoint {
63        type Error = Curve25519Error;
64
65        fn try_from(pod: &PodEdwardsPoint) -> Result<Self, Self::Error> {
66            let Ok(compressed_edwards_y) = CompressedEdwardsY::from_slice(&pod.0) else {
67                return Err(Curve25519Error::PodConversion);
68            };
69            compressed_edwards_y
70                .decompress()
71                .ok_or(Curve25519Error::PodConversion)
72        }
73    }
74
75    impl PointValidation for PodEdwardsPoint {
76        type Point = Self;
77
78        fn validate_point(&self) -> bool {
79            let Ok(compressed_edwards_y) = CompressedEdwardsY::from_slice(&self.0) else {
80                return false;
81            };
82            compressed_edwards_y.decompress().is_some()
83        }
84    }
85
86    impl GroupOperations for PodEdwardsPoint {
87        type Scalar = PodScalar;
88        type Point = Self;
89
90        fn add(left_point: &Self, right_point: &Self) -> Option<Self> {
91            let left_point: EdwardsPoint = left_point.try_into().ok()?;
92            let right_point: EdwardsPoint = right_point.try_into().ok()?;
93
94            let result = &left_point + &right_point;
95            Some((&result).into())
96        }
97
98        fn subtract(left_point: &Self, right_point: &Self) -> Option<Self> {
99            let left_point: EdwardsPoint = left_point.try_into().ok()?;
100            let right_point: EdwardsPoint = right_point.try_into().ok()?;
101
102            let result = &left_point - &right_point;
103            Some((&result).into())
104        }
105
106        #[cfg(not(target_os = "solana"))]
107        fn multiply(scalar: &PodScalar, point: &Self) -> Option<Self> {
108            let scalar: Scalar = scalar.try_into().ok()?;
109            let point: EdwardsPoint = point.try_into().ok()?;
110
111            let result = &scalar * &point;
112            Some((&result).into())
113        }
114    }
115
116    impl MultiScalarMultiplication for PodEdwardsPoint {
117        type Scalar = PodScalar;
118        type Point = Self;
119
120        fn multiscalar_multiply(scalars: &[PodScalar], points: &[Self]) -> Option<Self> {
121            let scalars = scalars
122                .iter()
123                .map(|scalar| Scalar::try_from(scalar).ok())
124                .collect::<Option<Vec<_>>>()?;
125
126            EdwardsPoint::optional_multiscalar_mul(
127                scalars,
128                points
129                    .iter()
130                    .map(|point| EdwardsPoint::try_from(point).ok()),
131            )
132            .map(|result| PodEdwardsPoint::from(&result))
133        }
134    }
135}
136
#[cfg(target_os = "solana")]
mod target_arch {
    use {
        super::*,
        crate::{
            curve_syscall_traits::{ADD, CURVE25519_EDWARDS, MUL, SUB},
            scalar::PodScalar,
        },
        bytemuck::Zeroable,
    };

    /// Returns `true` if the runtime's `sol_curve_validate_point` syscall reports success
    /// (returns 0) for `point`.
    pub fn validate_edwards(point: &PodEdwardsPoint) -> bool {
        let mut validate_result = 0u8;
        // SAFETY: `point.0` is borrowed from a live `PodEdwardsPoint`, and
        // `validate_result` is a live, writable local. Both outlive the syscall.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_validate_point(
                CURVE25519_EDWARDS,
                &point.0 as *const u8,
                &mut validate_result,
            )
        };
        result == 0
    }

    /// Computes `left_point + right_point` via the curve group-op syscall.
    ///
    /// Returns `None` if the syscall reports failure (non-zero return).
    pub fn add_edwards(
        left_point: &PodEdwardsPoint,
        right_point: &PodEdwardsPoint,
    ) -> Option<PodEdwardsPoint> {
        let mut result_point = PodEdwardsPoint::zeroed();
        // SAFETY: both inputs are borrowed from live `PodEdwardsPoint`s, and
        // `result_point.0` is an exclusively borrowed, writable buffer of the same type.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_group_op(
                CURVE25519_EDWARDS,
                ADD,
                &left_point.0 as *const u8,
                &right_point.0 as *const u8,
                &mut result_point.0 as *mut u8,
            )
        };
        (result == 0).then_some(result_point)
    }

    /// Computes `left_point - right_point` via the curve group-op syscall.
    ///
    /// Returns `None` if the syscall reports failure (non-zero return).
    pub fn subtract_edwards(
        left_point: &PodEdwardsPoint,
        right_point: &PodEdwardsPoint,
    ) -> Option<PodEdwardsPoint> {
        let mut result_point = PodEdwardsPoint::zeroed();
        // SAFETY: both inputs are borrowed from live `PodEdwardsPoint`s, and
        // `result_point.0` is an exclusively borrowed, writable buffer of the same type.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_group_op(
                CURVE25519_EDWARDS,
                SUB,
                &left_point.0 as *const u8,
                &right_point.0 as *const u8,
                &mut result_point.0 as *mut u8,
            )
        };
        (result == 0).then_some(result_point)
    }

    /// Computes `scalar * point` via the curve group-op syscall.
    ///
    /// The scalar is passed as the left operand and the point as the right operand.
    /// Returns `None` if the syscall reports failure (non-zero return).
    pub fn multiply_edwards(
        scalar: &PodScalar,
        point: &PodEdwardsPoint,
    ) -> Option<PodEdwardsPoint> {
        let mut result_point = PodEdwardsPoint::zeroed();
        // SAFETY: `scalar.0` and `point.0` are borrowed from live values, and
        // `result_point.0` is an exclusively borrowed, writable point buffer.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_group_op(
                CURVE25519_EDWARDS,
                MUL,
                &scalar.0 as *const u8,
                &point.0 as *const u8,
                &mut result_point.0 as *mut u8,
            )
        };
        (result == 0).then_some(result_point)
    }

    /// Computes `scalars[0] * points[0] + ... + scalars[n - 1] * points[n - 1]` via the
    /// multiscalar-multiplication syscall.
    ///
    /// Returns `None` if the slices differ in length or the syscall reports failure.
    pub fn multiscalar_multiply_edwards(
        scalars: &[PodScalar],
        points: &[PodEdwardsPoint],
    ) -> Option<PodEdwardsPoint> {
        // The syscall receives a single element count (`points.len()`) and reads that many
        // scalars from `scalars.as_ptr()`. A shorter `scalars` slice would make it read past
        // the end of the slice, so reject mismatched lengths before the unsafe call.
        if scalars.len() != points.len() {
            return None;
        }

        let mut result_point = PodEdwardsPoint::zeroed();
        // SAFETY: `scalars` and `points` are live slices that both contain exactly
        // `points.len()` elements (checked above), and `result_point.0` is an exclusively
        // borrowed, writable point buffer.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_multiscalar_mul(
                CURVE25519_EDWARDS,
                scalars.as_ptr() as *const u8,
                points.as_ptr() as *const u8,
                points.len() as u64,
                &mut result_point.0 as *mut u8,
            )
        };
        (result == 0).then_some(result_point)
    }
}
248
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::scalar::PodScalar,
        curve25519_dalek::{
            constants::ED25519_BASEPOINT_POINT as G, edwards::EdwardsPoint, traits::Identity,
        },
    };

    /// Compresses a dalek point into the `PodEdwardsPoint` wire encoding.
    fn encode(point: &EdwardsPoint) -> PodEdwardsPoint {
        PodEdwardsPoint(point.compress().to_bytes())
    }

    #[test]
    fn test_validate_edwards() {
        // The basepoint must be accepted.
        assert!(validate_edwards(&encode(&G)));

        // These bytes are expected to be rejected as a point encoding.
        let rejected = PodEdwardsPoint([
            120, 140, 152, 233, 41, 227, 203, 27, 87, 115, 25, 251, 219, 5, 84, 148, 117, 38, 84,
            60, 87, 144, 161, 146, 42, 34, 91, 155, 158, 189, 121, 79,
        ]);
        assert!(!validate_edwards(&rejected));
    }

    #[test]
    fn test_edwards_add_subtract() {
        // Adding or subtracting the identity leaves a point unchanged.
        let identity = encode(&EdwardsPoint::identity());
        let p = PodEdwardsPoint([
            201, 179, 241, 122, 180, 185, 239, 50, 183, 52, 221, 0, 153, 195, 43, 18, 22, 38, 187,
            206, 179, 192, 210, 58, 53, 45, 150, 98, 89, 17, 158, 11,
        ]);
        assert_eq!(add_edwards(&p, &identity), Some(p));
        assert_eq!(subtract_edwards(&p, &identity), Some(p));

        let a = PodEdwardsPoint([
            33, 124, 71, 170, 117, 69, 151, 247, 59, 12, 95, 125, 133, 166, 64, 5, 2, 27, 90, 27,
            200, 167, 59, 164, 52, 54, 52, 200, 29, 13, 34, 213,
        ]);
        let b = PodEdwardsPoint([
            70, 222, 137, 221, 253, 204, 71, 51, 78, 8, 124, 1, 67, 200, 102, 225, 122, 228, 111,
            183, 129, 14, 131, 210, 212, 95, 109, 246, 55, 10, 159, 91,
        ]);
        let c = PodEdwardsPoint([
            72, 60, 66, 143, 59, 197, 111, 36, 181, 137, 25, 97, 157, 201, 247, 215, 123, 83, 220,
            250, 154, 150, 180, 192, 196, 28, 215, 137, 34, 247, 39, 129,
        ]);

        let a_plus_b = add_edwards(&a, &b).unwrap();
        let b_plus_c = add_edwards(&b, &c).unwrap();

        // Associativity: (a + b) + c == a + (b + c).
        assert_eq!(add_edwards(&a_plus_b, &c), add_edwards(&a, &b_plus_c));

        // Subtracting in sequence equals subtracting the sum: (a - b) - c == a - (b + c).
        let a_minus_b = subtract_edwards(&a, &b).unwrap();
        assert_eq!(
            subtract_edwards(&a_minus_b, &c),
            subtract_edwards(&a, &b_plus_c)
        );

        // Commutativity: a + b == b + a.
        assert_eq!(a_plus_b, add_edwards(&b, &a).unwrap());

        // Negation: identity - G == -G.
        let g = encode(&G);
        let neg_g = encode(&-G);
        assert_eq!(neg_g, subtract_edwards(&identity, &g).unwrap());
    }

    #[test]
    fn test_edwards_mul() {
        // Scalar multiplication distributes over addition: s * (x + y) == s * x + s * y.
        let s = PodScalar([
            72, 191, 131, 55, 85, 86, 54, 60, 116, 10, 39, 130, 180, 3, 90, 227, 47, 228, 252, 99,
            151, 71, 118, 29, 34, 102, 117, 114, 120, 50, 57, 8,
        ]);
        let x = PodEdwardsPoint([
            176, 121, 6, 191, 108, 161, 206, 141, 73, 14, 235, 97, 49, 68, 48, 112, 98, 215, 145,
            208, 44, 188, 70, 10, 180, 124, 230, 15, 98, 165, 104, 85,
        ]);
        let y = PodEdwardsPoint([
            174, 86, 89, 208, 236, 123, 223, 128, 75, 54, 228, 232, 220, 100, 205, 108, 237, 97,
            105, 79, 74, 192, 67, 224, 185, 23, 157, 116, 216, 151, 223, 81,
        ]);

        let sx = multiply_edwards(&s, &x).unwrap();
        let sy = multiply_edwards(&s, &y).unwrap();
        let x_plus_y = add_edwards(&x, &y).unwrap();

        assert_eq!(add_edwards(&sx, &sy), multiply_edwards(&s, &x_plus_y));
    }

    #[test]
    fn test_multiscalar_multiplication_edwards() {
        // A single-term MSM must agree with plain scalar multiplication.
        let scalar = PodScalar([
            205, 73, 127, 173, 83, 80, 190, 66, 202, 3, 237, 77, 52, 223, 238, 70, 80, 242, 24, 87,
            111, 84, 49, 63, 194, 76, 202, 108, 62, 240, 83, 15,
        ]);
        let point = PodEdwardsPoint([
            222, 174, 184, 139, 143, 122, 253, 96, 0, 207, 120, 157, 112, 38, 54, 189, 91, 144, 78,
            111, 111, 122, 140, 183, 65, 250, 191, 133, 6, 42, 212, 93,
        ]);
        let single_product = multiply_edwards(&scalar, &point).unwrap();
        let single_msm = multiscalar_multiply_edwards(&[scalar], &[point]).unwrap();
        assert_eq!(single_product, single_msm);

        // A two-term MSM must agree with a * x + b * y computed term by term.
        let scalar_a = PodScalar([
            246, 154, 34, 110, 31, 185, 50, 1, 252, 194, 163, 56, 211, 18, 101, 192, 57, 225, 207,
            69, 19, 84, 231, 118, 137, 175, 148, 218, 106, 212, 69, 9,
        ]);
        let scalar_b = PodScalar([
            27, 58, 126, 136, 253, 178, 176, 245, 246, 55, 15, 202, 35, 183, 66, 199, 134, 187,
            169, 154, 66, 120, 169, 193, 75, 4, 33, 241, 126, 227, 59, 3,
        ]);
        let point_x = PodEdwardsPoint([
            252, 31, 230, 46, 173, 95, 144, 148, 158, 157, 63, 10, 8, 68, 58, 176, 142, 192, 168,
            53, 61, 105, 194, 166, 43, 56, 246, 236, 28, 146, 114, 133,
        ]);
        let point_y = PodEdwardsPoint([
            10, 111, 8, 236, 97, 189, 124, 69, 89, 176, 222, 39, 199, 253, 111, 11, 248, 186, 128,
            90, 120, 128, 248, 210, 232, 183, 93, 104, 111, 150, 7, 241,
        ]);

        let ax = multiply_edwards(&scalar_a, &point_x).unwrap();
        let by = multiply_edwards(&scalar_b, &point_y).unwrap();
        let expected = add_edwards(&ax, &by).unwrap();
        let actual =
            multiscalar_multiply_edwards(&[scalar_a, scalar_b], &[point_x, point_y]).unwrap();

        assert_eq!(expected, actual);
    }
}