snarkvm_algorithms_cuda/
lib.rs1#[allow(unused_imports)]
17use blst::*;
18
19use core::ffi::c_void;
20sppark::cuda_error!();
21
/// Element ordering of the NTT input and output buffers, passed by value to
/// the native `snarkvm_ntt` routine.
///
/// NOTE(review): by naming convention the first letter is presumably the
/// input order and the second the output order, with `N` = natural and
/// `R` = bit-reversed — confirm against the CUDA kernel.
///
/// The explicit discriminants are part of the C ABI and must not change.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NTTInputOutputOrder {
    NN = 0,
    NR = 1,
    RN = 2,
    RR = 3,
}
29
/// Direction of the transform requested from the native `snarkvm_ntt`
/// routine (forward or inverse).
///
/// The explicit discriminants are part of the C ABI and must not change.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NTTDirection {
    Forward = 0,
    Inverse = 1,
}
35
/// Variant of the transform requested from the native `snarkvm_ntt` routine.
///
/// NOTE(review): `Coset` presumably selects a transform over a shifted
/// (coset) domain instead of the plain subgroup — confirm against the kernel.
///
/// The explicit discriminants are part of the C ABI and must not change.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NTTType {
    Standard = 0,
    Coset = 1,
}
41
42extern "C" {
43 fn snarkvm_ntt(
44 inout: *mut core::ffi::c_void,
45 lg_domain_size: u32,
46 ntt_order: NTTInputOutputOrder,
47 ntt_direction: NTTDirection,
48 ntt_type: NTTType,
49 ) -> cuda::Error;
50
51 fn snarkvm_polymul(
52 out: *mut core::ffi::c_void,
53 pcount: usize,
54 polynomials: *const core::ffi::c_void,
55 plens: *const core::ffi::c_void,
56 ecount: usize,
57 evaluations: *const core::ffi::c_void,
58 elens: *const core::ffi::c_void,
59 lg_domain_size: u32,
60 ) -> cuda::Error;
61
62 fn snarkvm_msm(
63 out: *mut c_void,
64 points_with_infinity: *const c_void,
65 npoints: usize,
66 scalars: *const c_void,
67 ffi_affine_sz: usize,
68 ) -> cuda::Error;
69}
70
71#[allow(non_snake_case)]
77pub fn NTT<T>(
78 domain_size: usize,
79 inout: &mut [T],
80 ntt_order: NTTInputOutputOrder,
81 ntt_direction: NTTDirection,
82 ntt_type: NTTType,
83) -> Result<(), cuda::Error> {
84 if (domain_size & (domain_size - 1)) != 0 {
85 panic!("domain_size is not power of 2");
86 }
87 let lg_domain_size = domain_size.trailing_zeros();
88
89 let err = unsafe {
90 snarkvm_ntt(inout.as_mut_ptr() as *mut core::ffi::c_void, lg_domain_size, ntt_order, ntt_direction, ntt_type)
91 };
92
93 if err.code != 0 {
94 return Err(err);
95 }
96 Ok(())
97}
98
99pub fn polymul<T: std::clone::Clone>(
101 domain: usize,
102 polynomials: &Vec<Vec<T>>,
103 evaluations: &Vec<Vec<T>>,
104 zero: &T,
105) -> Result<Vec<T>, cuda::Error> {
106 let initial_domain_size = domain;
107 if (initial_domain_size & (initial_domain_size - 1)) != 0 {
108 panic!("domain_size is not power of 2");
109 }
110
111 let lg_domain_size = initial_domain_size.trailing_zeros();
112
113 let mut pptrs = Vec::new();
114 let mut plens = Vec::new();
115 for polynomial in polynomials {
116 pptrs.push(polynomial.as_ptr() as *const core::ffi::c_void);
117 plens.push(polynomial.len());
118 }
119 let mut eptrs = Vec::new();
120 let mut elens = Vec::new();
121 for evaluation in evaluations {
122 eptrs.push(evaluation.as_ptr() as *const core::ffi::c_void);
123 elens.push(evaluation.len());
124 }
125
126 let mut out = Vec::new();
127 out.resize(initial_domain_size, zero.clone());
128 let err = unsafe {
129 snarkvm_polymul(
130 out.as_mut_ptr() as *mut core::ffi::c_void,
131 pptrs.len(),
132 pptrs.as_ptr() as *const core::ffi::c_void,
133 plens.as_ptr() as *const core::ffi::c_void,
134 eptrs.len(),
135 eptrs.as_ptr() as *const core::ffi::c_void,
136 elens.as_ptr() as *const core::ffi::c_void,
137 lg_domain_size,
138 )
139 };
140
141 if err.code != 0 {
142 return Err(err);
143 }
144 Ok(out)
145}
146
147pub fn msm<Affine, Projective, Scalar>(points: &[Affine], scalars: &[Scalar]) -> Result<Projective, cuda::Error> {
149 let npoints = scalars.len();
150 if npoints > points.len() {
151 panic!("length mismatch {} points < {} scalars", npoints, scalars.len())
152 }
153 #[allow(clippy::uninit_assumed_init)]
154 let mut ret: Projective = unsafe { std::mem::MaybeUninit::uninit().assume_init() };
155 let err = unsafe {
156 snarkvm_msm(
157 &mut ret as *mut _ as *mut c_void,
158 points as *const _ as *const c_void,
159 npoints,
160 scalars as *const _ as *const c_void,
161 std::mem::size_of::<Affine>(),
162 )
163 };
164 if err.code != 0 {
165 return Err(err);
166 }
167 Ok(ret)
168}