1use bytemuck_derive::{Pod, Zeroable};
2pub use target_arch::*;
3
/// A Ristretto point held as its raw 32-byte compressed encoding.
///
/// Constructing this type performs no validation: any 32 bytes are accepted. Whether the bytes
/// encode a valid point is only determined when the value is validated or fed to one of the
/// curve operations in this module, which then return `false` / `None` for invalid input.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Pod, Zeroable)]
#[repr(transparent)]
pub struct PodRistrettoPoint(pub [u8; 32]);
7
8#[cfg(not(target_os = "solana"))]
9mod target_arch {
10 use {
11 super::*,
12 crate::{
13 curve_syscall_traits::{GroupOperations, MultiScalarMultiplication, PointValidation},
14 errors::Curve25519Error,
15 scalar::PodScalar,
16 },
17 curve25519_dalek::{
18 ristretto::{CompressedRistretto, RistrettoPoint},
19 scalar::Scalar,
20 traits::VartimeMultiscalarMul,
21 },
22 };
23
24 pub fn validate_ristretto(point: &PodRistrettoPoint) -> bool {
25 point.validate_point()
26 }
27
28 pub fn add_ristretto(
29 left_point: &PodRistrettoPoint,
30 right_point: &PodRistrettoPoint,
31 ) -> Option<PodRistrettoPoint> {
32 PodRistrettoPoint::add(left_point, right_point)
33 }
34
35 pub fn subtract_ristretto(
36 left_point: &PodRistrettoPoint,
37 right_point: &PodRistrettoPoint,
38 ) -> Option<PodRistrettoPoint> {
39 PodRistrettoPoint::subtract(left_point, right_point)
40 }
41
42 pub fn multiply_ristretto(
43 scalar: &PodScalar,
44 point: &PodRistrettoPoint,
45 ) -> Option<PodRistrettoPoint> {
46 PodRistrettoPoint::multiply(scalar, point)
47 }
48
49 pub fn multiscalar_multiply_ristretto(
50 scalars: &[PodScalar],
51 points: &[PodRistrettoPoint],
52 ) -> Option<PodRistrettoPoint> {
53 PodRistrettoPoint::multiscalar_multiply(scalars, points)
54 }
55
56 impl From<&RistrettoPoint> for PodRistrettoPoint {
57 fn from(point: &RistrettoPoint) -> Self {
58 Self(point.compress().to_bytes())
59 }
60 }
61
62 impl TryFrom<&PodRistrettoPoint> for RistrettoPoint {
63 type Error = Curve25519Error;
64
65 fn try_from(pod: &PodRistrettoPoint) -> Result<Self, Self::Error> {
66 let Ok(compressed_ristretto) = CompressedRistretto::from_slice(&pod.0) else {
67 return Err(Curve25519Error::PodConversion);
68 };
69 compressed_ristretto
70 .decompress()
71 .ok_or(Curve25519Error::PodConversion)
72 }
73 }
74
75 impl PointValidation for PodRistrettoPoint {
76 type Point = Self;
77
78 fn validate_point(&self) -> bool {
79 let Ok(compressed_ristretto) = CompressedRistretto::from_slice(&self.0) else {
80 return false;
81 };
82 compressed_ristretto.decompress().is_some()
83 }
84 }
85
86 impl GroupOperations for PodRistrettoPoint {
87 type Scalar = PodScalar;
88 type Point = Self;
89
90 fn add(left_point: &Self, right_point: &Self) -> Option<Self> {
91 let left_point: RistrettoPoint = left_point.try_into().ok()?;
92 let right_point: RistrettoPoint = right_point.try_into().ok()?;
93
94 let result = &left_point + &right_point;
95 Some((&result).into())
96 }
97
98 fn subtract(left_point: &Self, right_point: &Self) -> Option<Self> {
99 let left_point: RistrettoPoint = left_point.try_into().ok()?;
100 let right_point: RistrettoPoint = right_point.try_into().ok()?;
101
102 let result = &left_point - &right_point;
103 Some((&result).into())
104 }
105
106 #[cfg(not(target_os = "solana"))]
107 fn multiply(scalar: &PodScalar, point: &Self) -> Option<Self> {
108 let scalar: Scalar = scalar.try_into().ok()?;
109 let point: RistrettoPoint = point.try_into().ok()?;
110
111 let result = &scalar * &point;
112 Some((&result).into())
113 }
114 }
115
116 impl MultiScalarMultiplication for PodRistrettoPoint {
117 type Scalar = PodScalar;
118 type Point = Self;
119
120 fn multiscalar_multiply(scalars: &[PodScalar], points: &[Self]) -> Option<Self> {
121 let scalars = scalars
122 .iter()
123 .map(|scalar| Scalar::try_from(scalar).ok())
124 .collect::<Option<Vec<_>>>()?;
125
126 RistrettoPoint::optional_multiscalar_mul(
127 scalars,
128 points
129 .iter()
130 .map(|point| RistrettoPoint::try_from(point).ok()),
131 )
132 .map(|result| PodRistrettoPoint::from(&result))
133 }
134 }
135}
136
/// On-chain (Solana) implementation: every operation is delegated to the runtime's curve
/// syscalls. A syscall return code of `0` is treated as success.
#[cfg(target_os = "solana")]
// NOTE(review): kept from the original; possibly obsolete, since every binding below is used.
#[allow(unused_variables)]
mod target_arch {
    use {
        super::*,
        crate::{
            curve_syscall_traits::{ADD, CURVE25519_RISTRETTO, MUL, SUB},
            scalar::PodScalar,
        },
        bytemuck::Zeroable,
    };

    /// Returns `true` if the `sol_curve_validate_point` syscall reports `point` as valid.
    pub fn validate_ristretto(point: &PodRistrettoPoint) -> bool {
        // Passed because the syscall signature requires it; only the return code is inspected.
        let mut validate_result = 0u8;
        // SAFETY: `point.0` is a live, borrowed 32-byte array and `validate_result` is a live,
        // writable local for the whole duration of the call.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_validate_point(
                CURVE25519_RISTRETTO,
                &point.0 as *const u8,
                &mut validate_result,
            )
        };

        result == 0
    }

    /// Computes `left_point + right_point` via the `sol_curve_group_op` syscall.
    ///
    /// Returns `None` if the syscall reports failure.
    pub fn add_ristretto(
        left_point: &PodRistrettoPoint,
        right_point: &PodRistrettoPoint,
    ) -> Option<PodRistrettoPoint> {
        let mut result_point = PodRistrettoPoint::zeroed();
        // SAFETY: both inputs are live, borrowed 32-byte arrays, and `result_point.0` is a
        // live, writable 32-byte local buffer.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_group_op(
                CURVE25519_RISTRETTO,
                ADD,
                &left_point.0 as *const u8,
                &right_point.0 as *const u8,
                &mut result_point.0 as *mut u8,
            )
        };

        (result == 0).then_some(result_point)
    }

    /// Computes `left_point - right_point` via the `sol_curve_group_op` syscall.
    ///
    /// Returns `None` if the syscall reports failure.
    pub fn subtract_ristretto(
        left_point: &PodRistrettoPoint,
        right_point: &PodRistrettoPoint,
    ) -> Option<PodRistrettoPoint> {
        let mut result_point = PodRistrettoPoint::zeroed();
        // SAFETY: both inputs are live, borrowed 32-byte arrays, and `result_point.0` is a
        // live, writable 32-byte local buffer.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_group_op(
                CURVE25519_RISTRETTO,
                SUB,
                &left_point.0 as *const u8,
                &right_point.0 as *const u8,
                &mut result_point.0 as *mut u8,
            )
        };

        (result == 0).then_some(result_point)
    }

    /// Computes `scalar * point` via the `sol_curve_group_op` syscall.
    ///
    /// The scalar is passed as the left operand. Returns `None` if the syscall reports failure.
    pub fn multiply_ristretto(
        scalar: &PodScalar,
        point: &PodRistrettoPoint,
    ) -> Option<PodRistrettoPoint> {
        let mut result_point = PodRistrettoPoint::zeroed();
        // SAFETY: `scalar.0` and `point.0` are live, borrowed byte arrays, and `result_point.0`
        // is a live, writable 32-byte local buffer.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_group_op(
                CURVE25519_RISTRETTO,
                MUL,
                &scalar.0 as *const u8,
                &point.0 as *const u8,
                &mut result_point.0 as *mut u8,
            )
        };

        (result == 0).then_some(result_point)
    }

    /// Computes `sum(scalars[i] * points[i])` via the `sol_curve_multiscalar_mul` syscall.
    ///
    /// Returns `None` if the slices differ in length or the syscall reports failure.
    pub fn multiscalar_multiply_ristretto(
        scalars: &[PodScalar],
        points: &[PodRistrettoPoint],
    ) -> Option<PodRistrettoPoint> {
        // The syscall receives a single element count (`points.len()`) for both buffers. A
        // shorter `scalars` slice would therefore be read past its end, so reject mismatches.
        if scalars.len() != points.len() {
            return None;
        }

        let mut result_point = PodRistrettoPoint::zeroed();
        // SAFETY: `scalars` and `points` are live slices that both hold exactly `points.len()`
        // elements (checked above), and `result_point.0` is a live, writable 32-byte buffer.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_multiscalar_mul(
                CURVE25519_RISTRETTO,
                scalars.as_ptr() as *const u8,
                points.as_ptr() as *const u8,
                points.len() as u64,
                &mut result_point.0 as *mut u8,
            )
        };

        (result == 0).then_some(result_point)
    }
}
250
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::scalar::PodScalar,
        curve25519_dalek::{
            constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, traits::Identity,
        },
    };

    /// Wraps the compressed encoding of a dalek point in the pod type under test.
    fn pod_of(point: &RistrettoPoint) -> PodRistrettoPoint {
        PodRistrettoPoint(point.compress().to_bytes())
    }

    #[test]
    fn test_validate_ristretto() {
        // The compressed basepoint must be accepted.
        assert!(validate_ristretto(&pod_of(&G)));

        // These bytes do not encode any Ristretto point and must be rejected.
        let not_a_point = PodRistrettoPoint([
            120, 140, 152, 233, 41, 227, 203, 27, 87, 115, 25, 251, 219, 5, 84, 148, 117, 38, 84,
            60, 87, 144, 161, 146, 42, 34, 91, 155, 158, 189, 121, 79,
        ]);
        assert!(!validate_ristretto(&not_a_point));
    }

    #[test]
    fn test_add_subtract_ristretto() {
        let zero = pod_of(&RistrettoPoint::identity());
        let p = PodRistrettoPoint([
            210, 174, 124, 127, 67, 77, 11, 114, 71, 63, 168, 136, 113, 20, 141, 228, 195, 254,
            232, 229, 220, 249, 213, 232, 61, 238, 152, 249, 83, 225, 206, 16,
        ]);

        // The identity is neutral for both addition and subtraction.
        assert_eq!(add_ristretto(&p, &zero), Some(p));
        assert_eq!(subtract_ristretto(&p, &zero), Some(p));

        let a = PodRistrettoPoint([
            208, 165, 125, 204, 2, 100, 218, 17, 170, 194, 23, 9, 102, 156, 134, 136, 217, 190, 98,
            34, 183, 194, 228, 153, 92, 11, 108, 103, 28, 57, 88, 15,
        ]);
        let b = PodRistrettoPoint([
            208, 241, 72, 163, 73, 53, 32, 174, 54, 194, 71, 8, 70, 181, 244, 199, 93, 147, 99,
            231, 162, 127, 25, 40, 39, 19, 140, 132, 112, 212, 145, 108,
        ]);
        let c = PodRistrettoPoint([
            250, 61, 200, 25, 195, 15, 144, 179, 24, 17, 252, 167, 247, 44, 47, 41, 104, 237, 49,
            137, 231, 173, 86, 106, 121, 249, 245, 247, 70, 188, 31, 49,
        ]);

        let a_plus_b = add_ristretto(&a, &b).unwrap();
        let b_plus_c = add_ristretto(&b, &c).unwrap();

        // Associativity: (a + b) + c == a + (b + c).
        assert_eq!(add_ristretto(&a_plus_b, &c), add_ristretto(&a, &b_plus_c));

        // Subtraction: (a - b) - c == a - (b + c).
        let a_minus_b = subtract_ristretto(&a, &b).unwrap();
        assert_eq!(
            subtract_ristretto(&a_minus_b, &c),
            subtract_ristretto(&a, &b_plus_c),
        );

        // Commutativity: a + b == b + a.
        assert_eq!(a_plus_b, add_ristretto(&b, &a).unwrap());

        // Negation: 0 - G == -G.
        let generator = pod_of(&G);
        let negated_generator = pod_of(&(-G));
        assert_eq!(subtract_ristretto(&zero, &generator), Some(negated_generator));
    }

    #[test]
    fn test_multiply_ristretto() {
        let x = PodScalar([
            254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250,
            78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6,
        ]);
        let a = PodRistrettoPoint([
            68, 80, 232, 181, 241, 77, 60, 81, 154, 51, 173, 35, 98, 234, 149, 37, 1, 39, 191, 201,
            193, 48, 88, 189, 97, 126, 63, 35, 144, 145, 203, 31,
        ]);
        let b = PodRistrettoPoint([
            200, 236, 1, 12, 244, 130, 226, 214, 28, 125, 43, 163, 222, 234, 81, 213, 201, 156, 31,
            4, 167, 132, 240, 76, 164, 18, 45, 20, 48, 85, 206, 121,
        ]);

        // Distributivity: x*a + x*b == x*(a + b).
        let xa = multiply_ristretto(&x, &a).unwrap();
        let xb = multiply_ristretto(&x, &b).unwrap();
        let a_plus_b = add_ristretto(&a, &b).unwrap();
        assert_eq!(add_ristretto(&xa, &xb), multiply_ristretto(&x, &a_plus_b));
    }

    #[test]
    fn test_multiscalar_multiplication_ristretto() {
        // A single-term MSM must agree with plain scalar multiplication.
        let s = PodScalar([
            123, 108, 109, 66, 154, 185, 88, 122, 178, 43, 17, 154, 201, 223, 31, 238, 59, 215, 71,
            154, 215, 143, 177, 158, 9, 136, 32, 223, 139, 13, 133, 5,
        ]);
        let p = PodRistrettoPoint([
            158, 2, 130, 90, 148, 36, 172, 155, 86, 196, 74, 139, 30, 98, 44, 225, 155, 207, 135,
            111, 238, 167, 235, 67, 234, 125, 0, 227, 146, 31, 24, 113,
        ]);

        let single_expected = multiply_ristretto(&s, &p).unwrap();
        let single_msm = multiscalar_multiply_ristretto(&[s], &[p]).unwrap();
        assert_eq!(single_expected, single_msm);

        // A two-term MSM must agree with a*X + b*Y computed term by term.
        let a = PodScalar([
            8, 161, 219, 155, 192, 137, 153, 26, 27, 40, 30, 17, 124, 194, 26, 41, 32, 7, 161, 45,
            212, 198, 212, 81, 133, 185, 164, 85, 95, 232, 106, 10,
        ]);
        let b = PodScalar([
            135, 207, 106, 208, 107, 127, 46, 82, 66, 22, 136, 125, 105, 62, 69, 34, 213, 210, 17,
            196, 120, 114, 238, 237, 149, 170, 5, 243, 54, 77, 172, 12,
        ]);
        let x = PodRistrettoPoint([
            130, 35, 97, 25, 18, 199, 33, 239, 85, 143, 119, 111, 49, 51, 224, 40, 167, 185, 240,
            179, 25, 194, 213, 41, 14, 155, 104, 18, 181, 197, 15, 112,
        ]);
        let y = PodRistrettoPoint([
            152, 156, 155, 197, 152, 232, 92, 206, 219, 159, 193, 134, 121, 128, 139, 36, 56, 191,
            51, 143, 72, 204, 87, 76, 110, 124, 101, 96, 238, 158, 42, 108,
        ]);

        let ax = multiply_ristretto(&a, &x).unwrap();
        let by = multiply_ristretto(&b, &y).unwrap();
        let pair_expected = add_ristretto(&ax, &by).unwrap();
        let pair_msm = multiscalar_multiply_ristretto(&[a, b], &[x, y]).unwrap();
        assert_eq!(pair_expected, pair_msm);
    }
}