1use bytemuck_derive::{Pod, Zeroable};
2pub use target_arch::*;
3
4#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Pod, Zeroable)]
5#[repr(transparent)]
6pub struct PodEdwardsPoint(pub [u8; 32]);
7
8#[cfg(not(target_os = "solana"))]
9mod target_arch {
10 use {
11 super::*,
12 crate::{
13 curve_syscall_traits::{GroupOperations, MultiScalarMultiplication, PointValidation},
14 errors::Curve25519Error,
15 scalar::PodScalar,
16 },
17 curve25519_dalek::{
18 edwards::{CompressedEdwardsY, EdwardsPoint},
19 scalar::Scalar,
20 traits::VartimeMultiscalarMul,
21 },
22 };
23
24 pub fn validate_edwards(point: &PodEdwardsPoint) -> bool {
25 point.validate_point()
26 }
27
28 pub fn add_edwards(
29 left_point: &PodEdwardsPoint,
30 right_point: &PodEdwardsPoint,
31 ) -> Option<PodEdwardsPoint> {
32 PodEdwardsPoint::add(left_point, right_point)
33 }
34
35 pub fn subtract_edwards(
36 left_point: &PodEdwardsPoint,
37 right_point: &PodEdwardsPoint,
38 ) -> Option<PodEdwardsPoint> {
39 PodEdwardsPoint::subtract(left_point, right_point)
40 }
41
42 pub fn multiply_edwards(
43 scalar: &PodScalar,
44 point: &PodEdwardsPoint,
45 ) -> Option<PodEdwardsPoint> {
46 PodEdwardsPoint::multiply(scalar, point)
47 }
48
49 pub fn multiscalar_multiply_edwards(
50 scalars: &[PodScalar],
51 points: &[PodEdwardsPoint],
52 ) -> Option<PodEdwardsPoint> {
53 PodEdwardsPoint::multiscalar_multiply(scalars, points)
54 }
55
56 impl From<&EdwardsPoint> for PodEdwardsPoint {
57 fn from(point: &EdwardsPoint) -> Self {
58 Self(point.compress().to_bytes())
59 }
60 }
61
62 impl TryFrom<&PodEdwardsPoint> for EdwardsPoint {
63 type Error = Curve25519Error;
64
65 fn try_from(pod: &PodEdwardsPoint) -> Result<Self, Self::Error> {
66 let Ok(compressed_edwards_y) = CompressedEdwardsY::from_slice(&pod.0) else {
67 return Err(Curve25519Error::PodConversion);
68 };
69 compressed_edwards_y
70 .decompress()
71 .ok_or(Curve25519Error::PodConversion)
72 }
73 }
74
75 impl PointValidation for PodEdwardsPoint {
76 type Point = Self;
77
78 fn validate_point(&self) -> bool {
79 let Ok(compressed_edwards_y) = CompressedEdwardsY::from_slice(&self.0) else {
80 return false;
81 };
82 compressed_edwards_y.decompress().is_some()
83 }
84 }
85
86 impl GroupOperations for PodEdwardsPoint {
87 type Scalar = PodScalar;
88 type Point = Self;
89
90 fn add(left_point: &Self, right_point: &Self) -> Option<Self> {
91 let left_point: EdwardsPoint = left_point.try_into().ok()?;
92 let right_point: EdwardsPoint = right_point.try_into().ok()?;
93
94 let result = &left_point + &right_point;
95 Some((&result).into())
96 }
97
98 fn subtract(left_point: &Self, right_point: &Self) -> Option<Self> {
99 let left_point: EdwardsPoint = left_point.try_into().ok()?;
100 let right_point: EdwardsPoint = right_point.try_into().ok()?;
101
102 let result = &left_point - &right_point;
103 Some((&result).into())
104 }
105
106 #[cfg(not(target_os = "solana"))]
107 fn multiply(scalar: &PodScalar, point: &Self) -> Option<Self> {
108 let scalar: Scalar = scalar.try_into().ok()?;
109 let point: EdwardsPoint = point.try_into().ok()?;
110
111 let result = &scalar * &point;
112 Some((&result).into())
113 }
114 }
115
116 impl MultiScalarMultiplication for PodEdwardsPoint {
117 type Scalar = PodScalar;
118 type Point = Self;
119
120 fn multiscalar_multiply(scalars: &[PodScalar], points: &[Self]) -> Option<Self> {
121 let scalars = scalars
122 .iter()
123 .map(|scalar| Scalar::try_from(scalar).ok())
124 .collect::<Option<Vec<_>>>()?;
125
126 EdwardsPoint::optional_multiscalar_mul(
127 scalars,
128 points
129 .iter()
130 .map(|point| EdwardsPoint::try_from(point).ok()),
131 )
132 .map(|result| PodEdwardsPoint::from(&result))
133 }
134 }
135}
136
// On-chain (Solana) implementation: every operation is delegated to the curve
// syscalls, which signal success with a zero return status.
#[cfg(target_os = "solana")]
mod target_arch {
    use {
        super::*,
        crate::{
            curve_syscall_traits::{ADD, CURVE25519_EDWARDS, MUL, SUB},
            scalar::PodScalar,
        },
        bytemuck::Zeroable,
    };

    /// Returns `true` if the `sol_curve_validate_point` syscall returns 0 for
    /// `point`.
    pub fn validate_edwards(point: &PodEdwardsPoint) -> bool {
        let mut validate_result = 0u8;
        // SAFETY: `point.0` is a live 32-byte array valid for reads, and
        // `validate_result` is a live, writable byte for the whole call.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_validate_point(
                CURVE25519_EDWARDS,
                &point.0 as *const u8,
                &mut validate_result,
            )
        };
        result == 0
    }

    /// Computes `left_point + right_point` via `sol_curve_group_op`.
    ///
    /// Returns `None` if the syscall reports a non-zero status.
    pub fn add_edwards(
        left_point: &PodEdwardsPoint,
        right_point: &PodEdwardsPoint,
    ) -> Option<PodEdwardsPoint> {
        let mut result_point = PodEdwardsPoint::zeroed();
        // SAFETY: both inputs are live 32-byte arrays valid for reads, and
        // `result_point.0` is a live 32-byte array valid for writes.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_group_op(
                CURVE25519_EDWARDS,
                ADD,
                &left_point.0 as *const u8,
                &right_point.0 as *const u8,
                &mut result_point.0 as *mut u8,
            )
        };

        (result == 0).then_some(result_point)
    }

    /// Computes `left_point - right_point` via `sol_curve_group_op`.
    ///
    /// Returns `None` if the syscall reports a non-zero status.
    pub fn subtract_edwards(
        left_point: &PodEdwardsPoint,
        right_point: &PodEdwardsPoint,
    ) -> Option<PodEdwardsPoint> {
        let mut result_point = PodEdwardsPoint::zeroed();
        // SAFETY: both inputs are live 32-byte arrays valid for reads, and
        // `result_point.0` is a live 32-byte array valid for writes.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_group_op(
                CURVE25519_EDWARDS,
                SUB,
                &left_point.0 as *const u8,
                &right_point.0 as *const u8,
                &mut result_point.0 as *mut u8,
            )
        };

        (result == 0).then_some(result_point)
    }

    /// Computes `scalar * point` via `sol_curve_group_op`.
    ///
    /// Returns `None` if the syscall reports a non-zero status.
    pub fn multiply_edwards(
        scalar: &PodScalar,
        point: &PodEdwardsPoint,
    ) -> Option<PodEdwardsPoint> {
        let mut result_point = PodEdwardsPoint::zeroed();
        // SAFETY: `scalar.0` and `point.0` are live buffers valid for reads
        // (assumes `PodScalar` wraps a 32-byte array — confirm in `scalar`),
        // and `result_point.0` is a live 32-byte array valid for writes.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_group_op(
                CURVE25519_EDWARDS,
                MUL,
                &scalar.0 as *const u8,
                &point.0 as *const u8,
                &mut result_point.0 as *mut u8,
            )
        };

        (result == 0).then_some(result_point)
    }

    /// Computes `sum(scalars[i] * points[i])` via `sol_curve_multiscalar_mul`.
    ///
    /// Returns `None` if the slices differ in length or the syscall reports a
    /// non-zero status.
    pub fn multiscalar_multiply_edwards(
        scalars: &[PodScalar],
        points: &[PodEdwardsPoint],
    ) -> Option<PodEdwardsPoint> {
        // The syscall receives a single element count (`points.len()`) for both
        // buffers. A shorter `scalars` slice would therefore be read past its
        // end, so reject mismatched inputs up front.
        if scalars.len() != points.len() {
            return None;
        }

        let mut result_point = PodEdwardsPoint::zeroed();
        // SAFETY: `scalars` and `points` are live slices holding exactly
        // `points.len()` POD elements each (checked above), valid for reads.
        // `result_point.0` is a live 32-byte array valid for writes.
        let result = unsafe {
            solana_define_syscall::definitions::sol_curve_multiscalar_mul(
                CURVE25519_EDWARDS,
                scalars.as_ptr() as *const u8,
                points.as_ptr() as *const u8,
                points.len() as u64,
                &mut result_point.0 as *mut u8,
            )
        };

        (result == 0).then_some(result_point)
    }
}
248
249#[cfg(test)]
250mod tests {
251 use {
252 super::*,
253 crate::scalar::PodScalar,
254 curve25519_dalek::{
255 constants::ED25519_BASEPOINT_POINT as G, edwards::EdwardsPoint, traits::Identity,
256 },
257 };
258
259 #[test]
260 fn test_validate_edwards() {
261 let pod = PodEdwardsPoint(G.compress().to_bytes());
262 assert!(validate_edwards(&pod));
263
264 let invalid_bytes = [
265 120, 140, 152, 233, 41, 227, 203, 27, 87, 115, 25, 251, 219, 5, 84, 148, 117, 38, 84,
266 60, 87, 144, 161, 146, 42, 34, 91, 155, 158, 189, 121, 79,
267 ];
268
269 assert!(!validate_edwards(&PodEdwardsPoint(invalid_bytes)));
270 }
271
272 #[test]
273 fn test_edwards_add_subtract() {
274 let identity = PodEdwardsPoint(EdwardsPoint::identity().compress().to_bytes());
276 let point = PodEdwardsPoint([
277 201, 179, 241, 122, 180, 185, 239, 50, 183, 52, 221, 0, 153, 195, 43, 18, 22, 38, 187,
278 206, 179, 192, 210, 58, 53, 45, 150, 98, 89, 17, 158, 11,
279 ]);
280
281 assert_eq!(add_edwards(&point, &identity).unwrap(), point);
282 assert_eq!(subtract_edwards(&point, &identity).unwrap(), point);
283
284 let point_a = PodEdwardsPoint([
286 33, 124, 71, 170, 117, 69, 151, 247, 59, 12, 95, 125, 133, 166, 64, 5, 2, 27, 90, 27,
287 200, 167, 59, 164, 52, 54, 52, 200, 29, 13, 34, 213,
288 ]);
289 let point_b = PodEdwardsPoint([
290 70, 222, 137, 221, 253, 204, 71, 51, 78, 8, 124, 1, 67, 200, 102, 225, 122, 228, 111,
291 183, 129, 14, 131, 210, 212, 95, 109, 246, 55, 10, 159, 91,
292 ]);
293 let point_c = PodEdwardsPoint([
294 72, 60, 66, 143, 59, 197, 111, 36, 181, 137, 25, 97, 157, 201, 247, 215, 123, 83, 220,
295 250, 154, 150, 180, 192, 196, 28, 215, 137, 34, 247, 39, 129,
296 ]);
297
298 assert_eq!(
299 add_edwards(&add_edwards(&point_a, &point_b).unwrap(), &point_c),
300 add_edwards(&point_a, &add_edwards(&point_b, &point_c).unwrap()),
301 );
302
303 assert_eq!(
304 subtract_edwards(&subtract_edwards(&point_a, &point_b).unwrap(), &point_c),
305 subtract_edwards(&point_a, &add_edwards(&point_b, &point_c).unwrap()),
306 );
307
308 assert_eq!(
310 add_edwards(&point_a, &point_b).unwrap(),
311 add_edwards(&point_b, &point_a).unwrap(),
312 );
313
314 let point = PodEdwardsPoint(G.compress().to_bytes());
316 let point_negated = PodEdwardsPoint((-G).compress().to_bytes());
317
318 assert_eq!(point_negated, subtract_edwards(&identity, &point).unwrap(),)
319 }
320
321 #[test]
322 fn test_edwards_mul() {
323 let scalar_a = PodScalar([
324 72, 191, 131, 55, 85, 86, 54, 60, 116, 10, 39, 130, 180, 3, 90, 227, 47, 228, 252, 99,
325 151, 71, 118, 29, 34, 102, 117, 114, 120, 50, 57, 8,
326 ]);
327 let point_x = PodEdwardsPoint([
328 176, 121, 6, 191, 108, 161, 206, 141, 73, 14, 235, 97, 49, 68, 48, 112, 98, 215, 145,
329 208, 44, 188, 70, 10, 180, 124, 230, 15, 98, 165, 104, 85,
330 ]);
331 let point_y = PodEdwardsPoint([
332 174, 86, 89, 208, 236, 123, 223, 128, 75, 54, 228, 232, 220, 100, 205, 108, 237, 97,
333 105, 79, 74, 192, 67, 224, 185, 23, 157, 116, 216, 151, 223, 81,
334 ]);
335
336 let ax = multiply_edwards(&scalar_a, &point_x).unwrap();
337 let bx = multiply_edwards(&scalar_a, &point_y).unwrap();
338
339 assert_eq!(
340 add_edwards(&ax, &bx),
341 multiply_edwards(&scalar_a, &add_edwards(&point_x, &point_y).unwrap()),
342 );
343 }
344
345 #[test]
346 fn test_multiscalar_multiplication_edwards() {
347 let scalar = PodScalar([
348 205, 73, 127, 173, 83, 80, 190, 66, 202, 3, 237, 77, 52, 223, 238, 70, 80, 242, 24, 87,
349 111, 84, 49, 63, 194, 76, 202, 108, 62, 240, 83, 15,
350 ]);
351 let point = PodEdwardsPoint([
352 222, 174, 184, 139, 143, 122, 253, 96, 0, 207, 120, 157, 112, 38, 54, 189, 91, 144, 78,
353 111, 111, 122, 140, 183, 65, 250, 191, 133, 6, 42, 212, 93,
354 ]);
355
356 let basic_product = multiply_edwards(&scalar, &point).unwrap();
357 let msm_product = multiscalar_multiply_edwards(&[scalar], &[point]).unwrap();
358
359 assert_eq!(basic_product, msm_product);
360
361 let scalar_a = PodScalar([
362 246, 154, 34, 110, 31, 185, 50, 1, 252, 194, 163, 56, 211, 18, 101, 192, 57, 225, 207,
363 69, 19, 84, 231, 118, 137, 175, 148, 218, 106, 212, 69, 9,
364 ]);
365 let scalar_b = PodScalar([
366 27, 58, 126, 136, 253, 178, 176, 245, 246, 55, 15, 202, 35, 183, 66, 199, 134, 187,
367 169, 154, 66, 120, 169, 193, 75, 4, 33, 241, 126, 227, 59, 3,
368 ]);
369 let point_x = PodEdwardsPoint([
370 252, 31, 230, 46, 173, 95, 144, 148, 158, 157, 63, 10, 8, 68, 58, 176, 142, 192, 168,
371 53, 61, 105, 194, 166, 43, 56, 246, 236, 28, 146, 114, 133,
372 ]);
373 let point_y = PodEdwardsPoint([
374 10, 111, 8, 236, 97, 189, 124, 69, 89, 176, 222, 39, 199, 253, 111, 11, 248, 186, 128,
375 90, 120, 128, 248, 210, 232, 183, 93, 104, 111, 150, 7, 241,
376 ]);
377
378 let ax = multiply_edwards(&scalar_a, &point_x).unwrap();
379 let by = multiply_edwards(&scalar_b, &point_y).unwrap();
380 let basic_product = add_edwards(&ax, &by).unwrap();
381 let msm_product =
382 multiscalar_multiply_edwards(&[scalar_a, scalar_b], &[point_x, point_y]).unwrap();
383
384 assert_eq!(basic_product, msm_product);
385 }
386}