// solana_zk_token_sdk/encryption/discrete_log.rs
#![cfg(not(target_os = "solana"))]
#[cfg(not(target_arch = "wasm32"))]
use std::thread;
use {
crate::RISTRETTO_POINT_LEN,
curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_POINT as G,
ristretto::RistrettoPoint,
scalar::Scalar,
traits::{Identity, IsIdentity},
},
itertools::Itertools,
serde::{Deserialize, Serialize},
std::{collections::HashMap, num::NonZeroUsize},
thiserror::Error,
};
/// 2^16: the number of possible values of each 16-bit half of a `u32`.
const TWO16: u64 = 1 << 16;
/// 2^17: the stride between precomputed table entries (2^16 doubled, because the
/// decoder compresses doubled points).
const TWO17: u64 = 1 << 17;

/// Maximum number of threads permitted for discrete log computation.
#[cfg(not(target_arch = "wasm32"))]
const MAX_THREAD: usize = 1 << 16;
#[derive(Error, Clone, Debug, Eq, PartialEq)]
pub enum DiscreteLogError {
#[error("discrete log number of threads not power-of-two")]
DiscreteLogThreads,
#[error("discrete log batch size too large")]
DiscreteLogBatchSize,
}
/// State for recovering `x` from `target = x * G`.
///
/// The search splits `x` into 16-bit halves. It walks the low half and looks up the high
/// half in a precomputed table.
///
/// NOTE(review): the decoding methods step by the basepoint `G` and consult the table
/// precomputed for `G`. They never read `generator`, so results are only meaningful when
/// `generator == G`. Confirm that callers respect this.
///
/// NOTE: the field order below is part of the serde (de)serialized layout.
#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
pub struct DiscreteLog {
    /// Base point of the discrete log (see the review note above).
    pub generator: RistrettoPoint,
    /// Point whose discrete log is being recovered.
    pub target: RistrettoPoint,
    /// Number of worker threads. `None` selects the single-threaded search.
    num_threads: Option<NonZeroUsize>,
    /// Number of low-half candidates scanned by the search, or by each thread:
    /// 2^16 by default and 2^16 / num_threads when threaded.
    range_bound: NonZeroUsize,
    /// Point subtracted from the running point at every search step:
    /// `G` by default and `num_threads * G` when threaded.
    step_point: RistrettoPoint,
    /// Number of points passed to one batched double-and-compress call (32 by default).
    compression_batch_size: NonZeroUsize,
}
/// Lookup table mapping the compressed bytes of `x_hi * 2^17 * generator` to `x_hi`,
/// for `x_hi` in `0..2^16`.
///
/// The table is keyed on multiples of 2^17 rather than 2^16 because the search compresses
/// *doubled* points: `2 * (x_hi * 2^16 * G) = x_hi * 2^17 * G`.
#[derive(Serialize, Deserialize, Default)]
pub struct DecodePrecomputation(HashMap<[u8; RISTRETTO_POINT_LEN], u16>);
/// Builds the high-half lookup table for `generator`.
///
/// Every `x_hi` in `0..2^16` is keyed by the compressed encoding of
/// `x_hi * 2^17 * generator`.
// Only exercised by the test that regenerates the embedded bincode table.
#[allow(dead_code)]
fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
    let stride = Scalar::from(TWO17) * generator;
    let multiples = RistrettoIterator::new((RistrettoPoint::identity(), 0), (stride, 1));

    let table: HashMap<_, _> = multiples
        .take(TWO16 as usize)
        .map(|(multiple, x_hi)| (multiple.compress().to_bytes(), x_hi as u16))
        .collect();

    DecodePrecomputation(table)
}
lazy_static::lazy_static! {
    /// High-half lookup table for the basepoint `G`.
    ///
    /// It is deserialized on first access from the bincode file embedded at compile time.
    /// If deserialization fails, an empty table is used instead of panicking. Decoding then
    /// only recovers values whose high 16 bits are zero.
    pub static ref DECODE_PRECOMPUTATION_FOR_G: DecodePrecomputation = {
        let embedded_table: &[u8] =
            include_bytes!("decode_u32_precomputation_for_G.bincode");
        match bincode::deserialize(embedded_table) {
            Ok(table) => table,
            Err(_) => DecodePrecomputation::default(),
        }
    };
}
impl DiscreteLog {
    /// Creates a single-threaded discrete log instance for `target` with respect to `generator`.
    ///
    /// Defaults:
    /// - no worker threads (`num_threads == None`),
    /// - a search range of 2^16 low-half candidates,
    /// - a step of `G`,
    /// - a compression batch size of 32.
    pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
        Self {
            generator,
            target,
            num_threads: None,
            range_bound: NonZeroUsize::new(TWO16 as usize).expect("2^16 is non-zero"),
            step_point: G,
            compression_batch_size: NonZeroUsize::new(32).expect("32 is non-zero"),
        }
    }

    /// Configures the instance to split the search across `num_threads` threads.
    ///
    /// Each thread scans `2^16 / num_threads` candidates. Together the threads cover the
    /// same 2^16 low-half candidates as the single-threaded search.
    ///
    /// # Errors
    ///
    /// Returns [`DiscreteLogError::DiscreteLogThreads`] if `num_threads` is not a power of two
    /// or is greater than `MAX_THREAD`. The instance is left unchanged in that case.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
        // A power of two no larger than 2^16 divides 2^16 exactly, so the per-thread ranges
        // tile the search space with no gaps and no overlap.
        if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
            return Err(DiscreteLogError::DiscreteLogThreads);
        }

        self.num_threads = Some(num_threads);
        self.range_bound = NonZeroUsize::new(TWO16 as usize / num_threads.get())
            .expect("num_threads <= 2^16, so every thread has at least one candidate");
        self.step_point = Scalar::from(num_threads.get() as u64) * G;

        Ok(())
    }

    /// Sets how many points are double-and-compressed together in one batch.
    ///
    /// # Errors
    ///
    /// Returns [`DiscreteLogError::DiscreteLogBatchSize`] if `compression_batch_size` is 2^16
    /// or larger. The instance is left unchanged in that case.
    pub fn set_compression_batch_size(
        &mut self,
        compression_batch_size: NonZeroUsize,
    ) -> Result<(), DiscreteLogError> {
        if compression_batch_size.get() >= TWO16 as usize {
            return Err(DiscreteLogError::DiscreteLogBatchSize);
        }
        self.compression_batch_size = compression_batch_size;

        Ok(())
    }

    /// Recovers `x` such that `target = x * G`, assuming `x` fits in a `u32`.
    ///
    /// The search writes `x = x_hi * 2^16 + x_lo`. It walks `target - x_lo * G` over the low
    /// half and looks up the high half in `DECODE_PRECOMPUTATION_FOR_G`. When `num_threads`
    /// is set, the low-half candidates are interleaved across that many spawned threads.
    ///
    /// Returns `None` if no candidate matches. In threaded mode it also returns `None` if a
    /// worker that is joined before the one holding the solution panicked.
    ///
    /// NOTE(review): the search always steps by the basepoint `G` and uses the table for `G`.
    /// `self.generator` is not read, so the result is only meaningful when `generator == G`.
    pub fn decode_u32(self) -> Option<u64> {
        if let Some(num_threads) = self.num_threads {
            #[cfg(not(target_arch = "wasm32"))]
            {
                let n = num_threads.get();
                // Thread `i` scans the candidates x_lo = i, i + n, i + 2n, ...
                // It starts from `target - i*G` and repeatedly adds `-step_point` (= -n*G).
                let mut starting_point = self.target;
                let handles = (0..n)
                    .map(|i| {
                        let ristretto_iterator = RistrettoIterator::new(
                            (starting_point, i as u64),
                            (-(&self.step_point), n as u64),
                        );

                        let handle = thread::spawn(move || {
                            Self::decode_range(
                                ristretto_iterator,
                                self.range_bound,
                                self.compression_batch_size,
                            )
                        });

                        starting_point -= G;
                        handle
                    })
                    .collect::<Vec<_>>();

                // Join the threads in spawn order and stop at the first one that found a
                // solution. Handles left unjoined are dropped, which detaches those threads;
                // they run to completion in the background. A panicked thread stops
                // `map_while`, which yields `None`.
                handles
                    .into_iter()
                    .map_while(|h| h.join().ok())
                    .find(|x| x.is_some())
                    .flatten()
            }
            // `num_threads` (the only setter of this field) is not compiled on wasm, so the
            // field is `None` there.
            // NOTE(review): a deserialized instance could still carry `Some`.
            #[cfg(target_arch = "wasm32")]
            unreachable!()
        } else {
            let ristretto_iterator =
                RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));

            Self::decode_range(
                ristretto_iterator,
                self.range_bound,
                self.compression_batch_size,
            )
        }
    }

    /// Scans the first `range_bound` `(target - x_lo*G, x_lo)` pairs yielded by
    /// `ristretto_iterator`, in chunks of `compression_batch_size`.
    ///
    /// For each pair it checks whether the point equals `x_hi * 2^16 * G` for some `x_hi` in
    /// the precomputed table. It returns `x_lo + 2^16 * x_hi` for the last match found, or
    /// `None` if no pair matches.
    fn decode_range(
        ristretto_iterator: RistrettoIterator,
        range_bound: NonZeroUsize,
        compression_batch_size: NonZeroUsize,
    ) -> Option<u64> {
        let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
        let mut decoded = None;

        for batch in &ristretto_iterator
            .take(range_bound.get())
            .chunks(compression_batch_size.get())
        {
            // Identity points are removed before batch compression. Presumably the batched
            // double-and-compress cannot handle the identity; confirm against the
            // curve25519-dalek docs. An identity at index `x_lo` means `target == x_lo * G`,
            // i.e. the high half is zero.
            let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
                .filter(|(point, index)| {
                    if point.is_identity() {
                        decoded = Some(*index);
                        return false;
                    }
                    true
                })
                .unzip();

            // The table is keyed on `x_hi * 2^17 * G`, which is exactly the doubled point
            // `2 * (x_hi * 2^16 * G)` produced here.
            let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);

            for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
                // A single lookup, instead of `contains_key` followed by indexing.
                if let Some(&x_hi) = hashmap.0.get(&point.to_bytes()) {
                    decoded = Some(x_lo + TWO16 * x_hi as u64);
                }
            }
        }

        decoded
    }
}
/// Infinite iterator over an arithmetic progression of `(point, index)` pairs.
///
/// Each call to `next` yields the current pair and then advances both components by
/// `step`: point addition for the first component, integer addition for the second. It
/// never returns `None`, so callers must bound it (e.g. with `take`).
struct RistrettoIterator {
    pub current: (RistrettoPoint, u64),
    pub step: (RistrettoPoint, u64),
}

impl RistrettoIterator {
    /// Creates an iterator that first yields `current` and then advances by `step`.
    fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
        Self { current, step }
    }
}

impl Iterator for RistrettoIterator {
    type Item = (RistrettoPoint, u64);

    fn next(&mut self) -> Option<Self::Item> {
        let yielded = self.current;
        let (point, index) = yielded;
        let (point_step, index_step) = self.step;
        self.current = (point + point_step, index + index_step);
        Some(yielded)
    }
}
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    /// Boundary amounts. The values 0..=3 start in different per-thread sub-ranges when 4
    /// threads are used; the last value is the largest `u32`.
    const EDGE_CASE_AMOUNTS: [u64; 5] = [0, 1, 2, 3, (1_u64 << 32) - 1];

    #[test]
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);

        // If the embedded table is stale, regenerate it and ask for a rebuild.
        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            use std::{fs::File, io::Write, path::PathBuf};
            let mut f = File::create(PathBuf::from(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
            ))
            .unwrap();
            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
                .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    #[test]
    fn test_decode_correctness() {
        // General case, with an informal timing measurement.
        let amount: u64 = 4294967295;
        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());
        println!("single thread discrete log computation secs: {computation_secs:?} sec");

        // Edge cases on the single-threaded path.
        for amount in EDGE_CASE_AMOUNTS {
            let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
            assert_eq!(Some(amount), instance.decode_u32(), "amount {amount}");
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_decode_correctness_threaded() {
        let num_threads: NonZeroUsize = 4.try_into().unwrap();

        // General case, with an informal timing measurement.
        let amount: u64 = 55;
        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(num_threads).unwrap();

        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());
        println!("4 thread discrete log computation: {computation_secs:?} sec");

        // Edge cases must also go through the threaded path. Previously they were built
        // without `num_threads`, so they silently ran single-threaded.
        for amount in EDGE_CASE_AMOUNTS {
            let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
            instance.num_threads(num_threads).unwrap();
            assert_eq!(Some(amount), instance.decode_u32(), "amount {amount}");
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_num_threads_validation() {
        let mut instance = DiscreteLog::new(G, G);

        assert_eq!(
            instance.num_threads(3.try_into().unwrap()),
            Err(DiscreteLogError::DiscreteLogThreads)
        );
        assert_eq!(
            instance.num_threads((2 * MAX_THREAD).try_into().unwrap()),
            Err(DiscreteLogError::DiscreteLogThreads)
        );
        // Rejected calls leave the instance untouched.
        assert_eq!(instance, DiscreteLog::new(G, G));

        assert!(instance.num_threads(1.try_into().unwrap()).is_ok());
    }

    #[test]
    fn test_compression_batch_size_validation() {
        let mut instance = DiscreteLog::new(G, G);

        assert_eq!(
            instance.set_compression_batch_size((TWO16 as usize).try_into().unwrap()),
            Err(DiscreteLogError::DiscreteLogBatchSize)
        );
        assert!(instance
            .set_compression_batch_size((TWO16 as usize - 1).try_into().unwrap())
            .is_ok());
    }
}