// snarkvm_algorithms_cuda/lib.rs
#[allow(unused_imports)]
use blst::*;
use core::ffi::c_void;
// Macro from `sppark`. The rest of this file relies on the `cuda::Error` type
// (with a numeric `code` field, where 0 means success) being in scope here.
// NOTE(review): that this macro is what defines `cuda::Error` is inferred from
// usage in this file — confirm against the sppark sources.
sppark::cuda_error!();
/// Element ordering of the NTT input and output buffers, passed by value to the
/// CUDA side (`snarkvm_ntt`).
///
/// The first letter refers to the input order and the second to the output order.
/// NOTE(review): `N`/`R` presumably mean natural / bit-reversed order — confirm
/// against the CUDA implementation.
///
/// `#[repr(C)]` with explicit discriminants: these values cross the FFI boundary,
/// so they must stay in sync with the C side.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NTTInputOutputOrder {
    NN = 0,
    NR = 1,
    RN = 2,
    RR = 3,
}
/// Direction of the number-theoretic transform, passed by value to the CUDA side
/// (`snarkvm_ntt`).
///
/// `#[repr(C)]` with explicit discriminants: these values cross the FFI boundary,
/// so they must stay in sync with the C side.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NTTDirection {
    Forward = 0,
    Inverse = 1,
}
/// Variant of the transform to run, passed by value to the CUDA side (`snarkvm_ntt`).
///
/// NOTE(review): `Coset` presumably selects a transform over a coset of the
/// evaluation domain rather than the domain itself — confirm against the CUDA code.
///
/// `#[repr(C)]` with explicit discriminants: these values cross the FFI boundary,
/// so they must stay in sync with the C side.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NTTType {
    Standard = 0,
    Coset = 1,
}
// Raw entry points implemented on the CUDA/C++ side. Every buffer crosses the
// boundary as an untyped pointer, so the callers in this file are responsible
// for passing correctly sized, correctly typed memory. A returned
// `cuda::Error` with `code == 0` is treated as success by those callers.
extern "C" {
    /// Transforms the buffer at `inout` in place. The domain size is given as
    /// its base-2 logarithm.
    fn snarkvm_ntt(
        inout: *mut c_void,
        lg_domain_size: u32,
        ntt_order: NTTInputOutputOrder,
        ntt_direction: NTTDirection,
        ntt_type: NTTType,
    ) -> cuda::Error;

    /// `polynomials` / `evaluations` point to arrays of `pcount` / `ecount`
    /// data pointers. `plens` / `elens` point to the matching `usize` lengths.
    /// The result is written to `out`.
    fn snarkvm_polymul(
        out: *mut c_void,
        pcount: usize,
        polynomials: *const c_void,
        plens: *const c_void,
        ecount: usize,
        evaluations: *const c_void,
        elens: *const c_void,
        lg_domain_size: u32,
    ) -> cuda::Error;

    /// Writes a single result to `out` from `npoints` points and scalars.
    /// `ffi_affine_sz` is the Rust-side size in bytes of one point.
    fn snarkvm_msm(
        out: *mut c_void,
        points_with_infinity: *const c_void,
        npoints: usize,
        scalars: *const c_void,
        ffi_affine_sz: usize,
    ) -> cuda::Error;
}
#[allow(non_snake_case)]
pub fn NTT<T>(
domain_size: usize,
inout: &mut [T],
ntt_order: NTTInputOutputOrder,
ntt_direction: NTTDirection,
ntt_type: NTTType,
) -> Result<(), cuda::Error> {
if (domain_size & (domain_size - 1)) != 0 {
panic!("domain_size is not power of 2");
}
let lg_domain_size = domain_size.trailing_zeros();
let err = unsafe {
snarkvm_ntt(inout.as_mut_ptr() as *mut core::ffi::c_void, lg_domain_size, ntt_order, ntt_direction, ntt_type)
};
if err.code != 0 {
return Err(err);
}
Ok(())
}
pub fn polymul<T: std::clone::Clone>(
domain: usize,
polynomials: &Vec<Vec<T>>,
evaluations: &Vec<Vec<T>>,
zero: &T,
) -> Result<Vec<T>, cuda::Error> {
let initial_domain_size = domain;
if (initial_domain_size & (initial_domain_size - 1)) != 0 {
panic!("domain_size is not power of 2");
}
let lg_domain_size = initial_domain_size.trailing_zeros();
let mut pptrs = Vec::new();
let mut plens = Vec::new();
for polynomial in polynomials {
pptrs.push(polynomial.as_ptr() as *const core::ffi::c_void);
plens.push(polynomial.len());
}
let mut eptrs = Vec::new();
let mut elens = Vec::new();
for evaluation in evaluations {
eptrs.push(evaluation.as_ptr() as *const core::ffi::c_void);
elens.push(evaluation.len());
}
let mut out = Vec::new();
out.resize(initial_domain_size, zero.clone());
let err = unsafe {
snarkvm_polymul(
out.as_mut_ptr() as *mut core::ffi::c_void,
pptrs.len(),
pptrs.as_ptr() as *const core::ffi::c_void,
plens.as_ptr() as *const core::ffi::c_void,
eptrs.len(),
eptrs.as_ptr() as *const core::ffi::c_void,
elens.as_ptr() as *const core::ffi::c_void,
lg_domain_size,
)
};
if err.code != 0 {
return Err(err);
}
Ok(out)
}
/// Runs a multi-scalar multiplication on the GPU via the CUDA `snarkvm_msm` entry point.
///
/// Only the first `scalars.len()` entries of `points` are used. The size of one
/// `Affine` point is forwarded so the C side can step through `points`.
///
/// # Panics
///
/// Panics if `points` is shorter than `scalars`.
///
/// # Errors
///
/// Returns the `cuda::Error` reported by `snarkvm_msm` when its `code` is non-zero.
/// In that case no `Projective` value is produced or dropped.
pub fn msm<Affine, Projective, Scalar>(points: &[Affine], scalars: &[Scalar]) -> Result<Projective, cuda::Error> {
    let npoints = scalars.len();
    if npoints > points.len() {
        panic!("length mismatch {} points < {} scalars", points.len(), npoints)
    }
    // Keep the output uninitialised behind `MaybeUninit` until the kernel has filled
    // it. Calling `assume_init` on uninitialised memory (as before) is undefined behaviour.
    let mut ret = std::mem::MaybeUninit::<Projective>::uninit();
    // SAFETY: `ret` provides writable storage for one `Projective`. `points` has at
    // least `npoints` elements (checked above) and `scalars` has exactly `npoints`.
    // Both slices stay borrowed for the duration of the call.
    let err = unsafe {
        snarkvm_msm(
            ret.as_mut_ptr().cast::<c_void>(),
            points.as_ptr().cast::<c_void>(),
            npoints,
            scalars.as_ptr().cast::<c_void>(),
            std::mem::size_of::<Affine>(),
        )
    };
    if err.code != 0 {
        return Err(err);
    }
    // SAFETY: a zero error code is taken to mean the kernel wrote a complete
    // `Projective` into `ret`. NOTE(review): confirm the C side always writes the
    // result on success, including when `npoints == 0`.
    Ok(unsafe { ret.assume_init() })
}