solana_curve25519/
ristretto.rs

1use bytemuck_derive::{Pod, Zeroable};
2pub use target_arch::*;
3
/// Plain-old-data wrapper around the 32 raw bytes of a Ristretto point.
///
/// The bytes are not checked on construction. The host-side conversion to
/// `RistrettoPoint` in this file treats them as a compressed Ristretto
/// encoding and fails when they do not decompress.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Pod, Zeroable)]
#[repr(transparent)]
pub struct PodRistrettoPoint(pub [u8; 32]);
7
8#[cfg(not(target_os = "solana"))]
9mod target_arch {
10    use {
11        super::*,
12        crate::{
13            curve_syscall_traits::{GroupOperations, MultiScalarMultiplication, PointValidation},
14            errors::Curve25519Error,
15            scalar::PodScalar,
16        },
17        curve25519_dalek::{
18            ristretto::{CompressedRistretto, RistrettoPoint},
19            scalar::Scalar,
20            traits::VartimeMultiscalarMul,
21        },
22    };
23
24    pub fn validate_ristretto(point: &PodRistrettoPoint) -> bool {
25        point.validate_point()
26    }
27
28    pub fn add_ristretto(
29        left_point: &PodRistrettoPoint,
30        right_point: &PodRistrettoPoint,
31    ) -> Option<PodRistrettoPoint> {
32        PodRistrettoPoint::add(left_point, right_point)
33    }
34
35    pub fn subtract_ristretto(
36        left_point: &PodRistrettoPoint,
37        right_point: &PodRistrettoPoint,
38    ) -> Option<PodRistrettoPoint> {
39        PodRistrettoPoint::subtract(left_point, right_point)
40    }
41
42    pub fn multiply_ristretto(
43        scalar: &PodScalar,
44        point: &PodRistrettoPoint,
45    ) -> Option<PodRistrettoPoint> {
46        PodRistrettoPoint::multiply(scalar, point)
47    }
48
49    pub fn multiscalar_multiply_ristretto(
50        scalars: &[PodScalar],
51        points: &[PodRistrettoPoint],
52    ) -> Option<PodRistrettoPoint> {
53        PodRistrettoPoint::multiscalar_multiply(scalars, points)
54    }
55
56    impl From<&RistrettoPoint> for PodRistrettoPoint {
57        fn from(point: &RistrettoPoint) -> Self {
58            Self(point.compress().to_bytes())
59        }
60    }
61
62    impl TryFrom<&PodRistrettoPoint> for RistrettoPoint {
63        type Error = Curve25519Error;
64
65        fn try_from(pod: &PodRistrettoPoint) -> Result<Self, Self::Error> {
66            let Ok(compressed_ristretto) = CompressedRistretto::from_slice(&pod.0) else {
67                return Err(Curve25519Error::PodConversion);
68            };
69            compressed_ristretto
70                .decompress()
71                .ok_or(Curve25519Error::PodConversion)
72        }
73    }
74
75    impl PointValidation for PodRistrettoPoint {
76        type Point = Self;
77
78        fn validate_point(&self) -> bool {
79            let Ok(compressed_ristretto) = CompressedRistretto::from_slice(&self.0) else {
80                return false;
81            };
82            compressed_ristretto.decompress().is_some()
83        }
84    }
85
86    impl GroupOperations for PodRistrettoPoint {
87        type Scalar = PodScalar;
88        type Point = Self;
89
90        fn add(left_point: &Self, right_point: &Self) -> Option<Self> {
91            let left_point: RistrettoPoint = left_point.try_into().ok()?;
92            let right_point: RistrettoPoint = right_point.try_into().ok()?;
93
94            let result = &left_point + &right_point;
95            Some((&result).into())
96        }
97
98        fn subtract(left_point: &Self, right_point: &Self) -> Option<Self> {
99            let left_point: RistrettoPoint = left_point.try_into().ok()?;
100            let right_point: RistrettoPoint = right_point.try_into().ok()?;
101
102            let result = &left_point - &right_point;
103            Some((&result).into())
104        }
105
106        #[cfg(not(target_os = "solana"))]
107        fn multiply(scalar: &PodScalar, point: &Self) -> Option<Self> {
108            let scalar: Scalar = scalar.try_into().ok()?;
109            let point: RistrettoPoint = point.try_into().ok()?;
110
111            let result = &scalar * &point;
112            Some((&result).into())
113        }
114    }
115
116    impl MultiScalarMultiplication for PodRistrettoPoint {
117        type Scalar = PodScalar;
118        type Point = Self;
119
120        fn multiscalar_multiply(scalars: &[PodScalar], points: &[Self]) -> Option<Self> {
121            let scalars = scalars
122                .iter()
123                .map(|scalar| Scalar::try_from(scalar).ok())
124                .collect::<Option<Vec<_>>>()?;
125
126            RistrettoPoint::optional_multiscalar_mul(
127                scalars,
128                points
129                    .iter()
130                    .map(|point| RistrettoPoint::try_from(point).ok()),
131            )
132            .map(|result| PodRistrettoPoint::from(&result))
133        }
134    }
135}
136
137#[cfg(target_os = "solana")]
138#[allow(unused_variables)]
139mod target_arch {
140    use {
141        super::*,
142        crate::{
143            curve_syscall_traits::{ADD, CURVE25519_RISTRETTO, MUL, SUB},
144            scalar::PodScalar,
145        },
146        bytemuck::Zeroable,
147    };
148
149    pub fn validate_ristretto(point: &PodRistrettoPoint) -> bool {
150        let mut validate_result = 0u8;
151        let result = unsafe {
152            solana_define_syscall::definitions::sol_curve_validate_point(
153                CURVE25519_RISTRETTO,
154                &point.0 as *const u8,
155                &mut validate_result,
156            )
157        };
158
159        result == 0
160    }
161
162    pub fn add_ristretto(
163        left_point: &PodRistrettoPoint,
164        right_point: &PodRistrettoPoint,
165    ) -> Option<PodRistrettoPoint> {
166        let mut result_point = PodRistrettoPoint::zeroed();
167        let result = unsafe {
168            solana_define_syscall::definitions::sol_curve_group_op(
169                CURVE25519_RISTRETTO,
170                ADD,
171                &left_point.0 as *const u8,
172                &right_point.0 as *const u8,
173                &mut result_point.0 as *mut u8,
174            )
175        };
176
177        if result == 0 {
178            Some(result_point)
179        } else {
180            None
181        }
182    }
183
184    pub fn subtract_ristretto(
185        left_point: &PodRistrettoPoint,
186        right_point: &PodRistrettoPoint,
187    ) -> Option<PodRistrettoPoint> {
188        let mut result_point = PodRistrettoPoint::zeroed();
189        let result = unsafe {
190            solana_define_syscall::definitions::sol_curve_group_op(
191                CURVE25519_RISTRETTO,
192                SUB,
193                &left_point.0 as *const u8,
194                &right_point.0 as *const u8,
195                &mut result_point.0 as *mut u8,
196            )
197        };
198
199        if result == 0 {
200            Some(result_point)
201        } else {
202            None
203        }
204    }
205
206    pub fn multiply_ristretto(
207        scalar: &PodScalar,
208        point: &PodRistrettoPoint,
209    ) -> Option<PodRistrettoPoint> {
210        let mut result_point = PodRistrettoPoint::zeroed();
211        let result = unsafe {
212            solana_define_syscall::definitions::sol_curve_group_op(
213                CURVE25519_RISTRETTO,
214                MUL,
215                &scalar.0 as *const u8,
216                &point.0 as *const u8,
217                &mut result_point.0 as *mut u8,
218            )
219        };
220
221        if result == 0 {
222            Some(result_point)
223        } else {
224            None
225        }
226    }
227
228    pub fn multiscalar_multiply_ristretto(
229        scalars: &[PodScalar],
230        points: &[PodRistrettoPoint],
231    ) -> Option<PodRistrettoPoint> {
232        let mut result_point = PodRistrettoPoint::zeroed();
233        let result = unsafe {
234            solana_define_syscall::definitions::sol_curve_multiscalar_mul(
235                CURVE25519_RISTRETTO,
236                scalars.as_ptr() as *const u8,
237                points.as_ptr() as *const u8,
238                points.len() as u64,
239                &mut result_point.0 as *mut u8,
240            )
241        };
242
243        if result == 0 {
244            Some(result_point)
245        } else {
246            None
247        }
248    }
249}
250
251#[cfg(test)]
252mod tests {
253    use {
254        super::*,
255        crate::scalar::PodScalar,
256        curve25519_dalek::{
257            constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, traits::Identity,
258        },
259    };
260
261    #[test]
262    fn test_validate_ristretto() {
263        let pod = PodRistrettoPoint(G.compress().to_bytes());
264        assert!(validate_ristretto(&pod));
265
266        let invalid_bytes = [
267            120, 140, 152, 233, 41, 227, 203, 27, 87, 115, 25, 251, 219, 5, 84, 148, 117, 38, 84,
268            60, 87, 144, 161, 146, 42, 34, 91, 155, 158, 189, 121, 79,
269        ];
270
271        assert!(!validate_ristretto(&PodRistrettoPoint(invalid_bytes)));
272    }
273
274    #[test]
275    fn test_add_subtract_ristretto() {
276        // identity
277        let identity = PodRistrettoPoint(RistrettoPoint::identity().compress().to_bytes());
278        let point = PodRistrettoPoint([
279            210, 174, 124, 127, 67, 77, 11, 114, 71, 63, 168, 136, 113, 20, 141, 228, 195, 254,
280            232, 229, 220, 249, 213, 232, 61, 238, 152, 249, 83, 225, 206, 16,
281        ]);
282
283        assert_eq!(add_ristretto(&point, &identity).unwrap(), point);
284        assert_eq!(subtract_ristretto(&point, &identity).unwrap(), point);
285
286        // associativity
287        let point_a = PodRistrettoPoint([
288            208, 165, 125, 204, 2, 100, 218, 17, 170, 194, 23, 9, 102, 156, 134, 136, 217, 190, 98,
289            34, 183, 194, 228, 153, 92, 11, 108, 103, 28, 57, 88, 15,
290        ]);
291        let point_b = PodRistrettoPoint([
292            208, 241, 72, 163, 73, 53, 32, 174, 54, 194, 71, 8, 70, 181, 244, 199, 93, 147, 99,
293            231, 162, 127, 25, 40, 39, 19, 140, 132, 112, 212, 145, 108,
294        ]);
295        let point_c = PodRistrettoPoint([
296            250, 61, 200, 25, 195, 15, 144, 179, 24, 17, 252, 167, 247, 44, 47, 41, 104, 237, 49,
297            137, 231, 173, 86, 106, 121, 249, 245, 247, 70, 188, 31, 49,
298        ]);
299
300        assert_eq!(
301            add_ristretto(&add_ristretto(&point_a, &point_b).unwrap(), &point_c),
302            add_ristretto(&point_a, &add_ristretto(&point_b, &point_c).unwrap()),
303        );
304
305        assert_eq!(
306            subtract_ristretto(&subtract_ristretto(&point_a, &point_b).unwrap(), &point_c),
307            subtract_ristretto(&point_a, &add_ristretto(&point_b, &point_c).unwrap()),
308        );
309
310        // commutativity
311        assert_eq!(
312            add_ristretto(&point_a, &point_b).unwrap(),
313            add_ristretto(&point_b, &point_a).unwrap(),
314        );
315
316        // subtraction
317        let point = PodRistrettoPoint(G.compress().to_bytes());
318        let point_negated = PodRistrettoPoint((-G).compress().to_bytes());
319
320        assert_eq!(
321            point_negated,
322            subtract_ristretto(&identity, &point).unwrap(),
323        )
324    }
325
326    #[test]
327    fn test_multiply_ristretto() {
328        let scalar_x = PodScalar([
329            254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250,
330            78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6,
331        ]);
332        let point_a = PodRistrettoPoint([
333            68, 80, 232, 181, 241, 77, 60, 81, 154, 51, 173, 35, 98, 234, 149, 37, 1, 39, 191, 201,
334            193, 48, 88, 189, 97, 126, 63, 35, 144, 145, 203, 31,
335        ]);
336        let point_b = PodRistrettoPoint([
337            200, 236, 1, 12, 244, 130, 226, 214, 28, 125, 43, 163, 222, 234, 81, 213, 201, 156, 31,
338            4, 167, 132, 240, 76, 164, 18, 45, 20, 48, 85, 206, 121,
339        ]);
340
341        let ax = multiply_ristretto(&scalar_x, &point_a).unwrap();
342        let bx = multiply_ristretto(&scalar_x, &point_b).unwrap();
343
344        assert_eq!(
345            add_ristretto(&ax, &bx),
346            multiply_ristretto(&scalar_x, &add_ristretto(&point_a, &point_b).unwrap()),
347        );
348    }
349
350    #[test]
351    fn test_multiscalar_multiplication_ristretto() {
352        let scalar = PodScalar([
353            123, 108, 109, 66, 154, 185, 88, 122, 178, 43, 17, 154, 201, 223, 31, 238, 59, 215, 71,
354            154, 215, 143, 177, 158, 9, 136, 32, 223, 139, 13, 133, 5,
355        ]);
356        let point = PodRistrettoPoint([
357            158, 2, 130, 90, 148, 36, 172, 155, 86, 196, 74, 139, 30, 98, 44, 225, 155, 207, 135,
358            111, 238, 167, 235, 67, 234, 125, 0, 227, 146, 31, 24, 113,
359        ]);
360
361        let basic_product = multiply_ristretto(&scalar, &point).unwrap();
362        let msm_product = multiscalar_multiply_ristretto(&[scalar], &[point]).unwrap();
363
364        assert_eq!(basic_product, msm_product);
365
366        let scalar_a = PodScalar([
367            8, 161, 219, 155, 192, 137, 153, 26, 27, 40, 30, 17, 124, 194, 26, 41, 32, 7, 161, 45,
368            212, 198, 212, 81, 133, 185, 164, 85, 95, 232, 106, 10,
369        ]);
370        let scalar_b = PodScalar([
371            135, 207, 106, 208, 107, 127, 46, 82, 66, 22, 136, 125, 105, 62, 69, 34, 213, 210, 17,
372            196, 120, 114, 238, 237, 149, 170, 5, 243, 54, 77, 172, 12,
373        ]);
374        let point_x = PodRistrettoPoint([
375            130, 35, 97, 25, 18, 199, 33, 239, 85, 143, 119, 111, 49, 51, 224, 40, 167, 185, 240,
376            179, 25, 194, 213, 41, 14, 155, 104, 18, 181, 197, 15, 112,
377        ]);
378        let point_y = PodRistrettoPoint([
379            152, 156, 155, 197, 152, 232, 92, 206, 219, 159, 193, 134, 121, 128, 139, 36, 56, 191,
380            51, 143, 72, 204, 87, 76, 110, 124, 101, 96, 238, 158, 42, 108,
381        ]);
382
383        let ax = multiply_ristretto(&scalar_a, &point_x).unwrap();
384        let by = multiply_ristretto(&scalar_b, &point_y).unwrap();
385        let basic_product = add_ristretto(&ax, &by).unwrap();
386        let msm_product =
387            multiscalar_multiply_ristretto(&[scalar_a, scalar_b], &[point_x, point_y]).unwrap();
388
389        assert_eq!(basic_product, msm_product);
390    }
391}