solana_zk_token_sdk/encryption/
discrete_log.rs#![cfg(not(target_os = "solana"))]
use {
crate::RISTRETTO_POINT_LEN,
curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_POINT as G,
ristretto::RistrettoPoint,
scalar::Scalar,
traits::{Identity, IsIdentity},
},
itertools::Itertools,
serde::{Deserialize, Serialize},
std::{collections::HashMap, thread},
thiserror::Error,
};
/// 2^16: size of each half of a 32-bit value (`x = x_lo + 2^16 * x_hi`).
const TWO16: u64 = 1 << 16;
/// 2^17: scalar applied to the generator when building the lookup table, since the search
/// side compares doubled points.
const TWO17: u64 = 1 << 17;
/// Largest thread count accepted by `DiscreteLog::num_threads`.
const MAX_THREAD: usize = 1 << 16;
#[derive(Error, Clone, Debug, Eq, PartialEq)]
pub enum DiscreteLogError {
#[error("discrete log number of threads not power-of-two")]
DiscreteLogThreads,
#[error("discrete log batch size too large")]
DiscreteLogBatchSize,
}
/// Configuration and input for decoding a 32-bit discrete logarithm.
///
/// Field order is significant for serialization and must not be changed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DiscreteLog {
    /// Generator supplied to `DiscreteLog::new`.
    /// NOTE(review): the decoding code in this file works relative to `G` and does not read
    /// this field.
    pub generator: RistrettoPoint,
    /// Point whose discrete logarithm is searched for.
    pub target: RistrettoPoint,
    /// Number of worker threads used during decoding (a power of two).
    num_threads: usize,
    /// Number of low-half candidates each worker thread checks (`2^16 / num_threads`).
    range_bound: usize,
    /// `num_threads * G`; each worker subtracts this between consecutive candidates.
    step_point: RistrettoPoint,
    /// Number of candidate points compressed together in one batch.
    compression_batch_size: usize,
}
/// Lookup table from compressed point bytes to the high 16 bits (`x_hi`) of a decoded value.
///
/// As built by `decode_u32_precomputation`, each key is the compression of
/// `2^17 * x_hi * generator`.
#[derive(Default, Serialize, Deserialize)]
pub struct DecodePrecomputation(HashMap<[u8; RISTRETTO_POINT_LEN], u16>);
/// Builds the table mapping `compress(2^17 * x_hi * generator)` to `x_hi` for every
/// `x_hi` in `[0, 2^16)`.
///
/// The factor is `2^17` rather than `2^16` because the search side compresses doubled
/// points (`double_and_compress_batch`), so a stored key equals `compress(2 * (2^16 * x_hi * generator))`.
// Only referenced from the test module in this file (to regenerate/verify the embedded
// table), hence the dead_code allowance in non-test builds.
#[allow(dead_code)]
fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
    let table_step = Scalar::from(TWO17) * generator;
    let table = RistrettoIterator::new((RistrettoPoint::identity(), 0), (table_step, 1))
        .take(TWO16 as usize)
        .map(|(point, x_hi)| (point.compress().to_bytes(), x_hi as u16))
        .collect::<HashMap<_, _>>();
    DecodePrecomputation(table)
}
lazy_static::lazy_static! {
    /// Decode table for the base point `G`, deserialized from a bincode file embedded at
    /// compile time.
    ///
    /// Deserialization failure falls back to an empty table rather than panicking; the
    /// test module compares this table against a freshly built one and rewrites the file
    /// when they differ, which relies on this fallback.
    pub static ref DECODE_PRECOMPUTATION_FOR_G: DecodePrecomputation = {
        let embedded_table: &[u8] = include_bytes!("decode_u32_precomputation_for_G.bincode");
        bincode::deserialize(embedded_table).unwrap_or_default()
    };
}
impl DiscreteLog {
    /// Creates a decoder for `target` with single-threaded defaults.
    ///
    /// Defaults:
    /// - one worker thread;
    /// - all 2^16 low-half candidates are checked by that thread;
    /// - consecutive candidates differ by `G`;
    /// - points are compressed in batches of 32.
    ///
    /// NOTE(review): `generator` is stored but never read by [`DiscreteLog::decode_u32`].
    /// Decoding always works relative to `G`, because the precomputed table and the step
    /// points are built from `G`.
    pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
        Self {
            generator,
            target,
            num_threads: 1,
            range_bound: TWO16 as usize,
            step_point: G,
            compression_batch_size: 32,
        }
    }

    /// Sets the number of worker threads used by [`DiscreteLog::decode_u32`].
    ///
    /// The 2^16 low-half candidates are split evenly between the threads. With `n`
    /// threads, thread `i` checks `x_lo = i, i + n, i + 2n, ...`.
    ///
    /// # Errors
    ///
    /// Returns [`DiscreteLogError::DiscreteLogThreads`] if `num_threads` is not a power
    /// of two (zero included) or exceeds `MAX_THREAD`.
    pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> {
        // A power of two no larger than MAX_THREAD (= 2^16) divides 2^16 exactly, so
        // every candidate is assigned to exactly one thread.
        if !num_threads.is_power_of_two() || num_threads > MAX_THREAD {
            return Err(DiscreteLogError::DiscreteLogThreads);
        }
        self.num_threads = num_threads;
        // Division is safe: zero was rejected above (0 is not a power of two).
        self.range_bound = TWO16 as usize / num_threads;
        self.step_point = Scalar::from(num_threads as u64) * G;
        Ok(())
    }

    /// Sets how many candidate points are compressed together per batch.
    ///
    /// # Errors
    ///
    /// Returns [`DiscreteLogError::DiscreteLogBatchSize`] unless
    /// `1 <= compression_batch_size < 2^16`. Zero must be rejected because it is used as
    /// an `Itertools::chunks` size, which panics on 0.
    pub fn set_compression_batch_size(
        &mut self,
        compression_batch_size: usize,
    ) -> Result<(), DiscreteLogError> {
        if compression_batch_size == 0 || compression_batch_size >= TWO16 as usize {
            return Err(DiscreteLogError::DiscreteLogBatchSize);
        }
        self.compression_batch_size = compression_batch_size;
        Ok(())
    }

    /// Searches for `x` in `[0, 2^32)` with `target == x * G` and returns it. Returns
    /// `None` if nothing matches.
    ///
    /// `x` is split as `x_lo + 2^16 * x_hi`:
    /// - the `x_lo` candidates are split across `num_threads` spawned threads;
    /// - each thread checks `target - x_lo * G` against the precomputed `x_hi` table.
    ///
    /// Results with `x_hi >= 1` depend on `DECODE_PRECOMPUTATION_FOR_G` having loaded a
    /// valid table. It silently falls back to an empty table on deserialization failure.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread panics.
    pub fn decode_u32(self) -> Option<u64> {
        let range_bound = self.range_bound;
        let compression_batch_size = self.compression_batch_size;
        // Every thread walks downward by `num_threads * G` while its index advances by
        // `num_threads`, so the threads interleave over the x_lo candidates.
        let step = (-(&self.step_point), self.num_threads as u64);

        let mut starting_point = self.target;
        let handles = (0..self.num_threads)
            .map(|i| {
                // Thread i starts at `target - i * G`, i.e. candidate x_lo = i.
                let ristretto_iterator = RistrettoIterator::new((starting_point, i as u64), step);
                let handle = thread::spawn(move || {
                    Self::decode_range(ristretto_iterator, range_bound, compression_batch_size)
                });
                starting_point -= G;
                handle
            })
            .collect::<Vec<_>>();

        // Join every worker in spawn order; a later thread's result overrides an earlier
        // one, matching the previous behavior.
        let mut solution = None;
        for handle in handles {
            if let Some(found) = handle
                .join()
                .expect("discrete log worker thread panicked")
            {
                solution = Some(found);
            }
        }
        solution
    }

    /// Checks the first `range_bound` candidates produced by `ristretto_iterator`.
    ///
    /// Each item is a `(target - x_lo * G, x_lo)` pair. Returns the last
    /// `x_lo + 2^16 * x_hi` that matched, or `None` if none did.
    fn decode_range(
        ristretto_iterator: RistrettoIterator,
        range_bound: usize,
        compression_batch_size: usize,
    ) -> Option<u64> {
        let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
        let mut decoded = None;

        for batch in &ristretto_iterator
            .take(range_bound)
            .chunks(compression_batch_size)
        {
            // Identity points are kept out of batch compression. An identity candidate
            // means `target == x_lo * G`, i.e. x_hi == 0, so record it directly.
            // NOTE(review): presumably excluded because batch compression cannot handle the
            // identity point — confirm against curve25519-dalek docs.
            let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
                .filter(|(point, index)| {
                    if point.is_identity() {
                        decoded = Some(*index);
                        return false;
                    }
                    true
                })
                .unzip();

            // Table keys are compress(2^17 * x_hi * G) = compress(2 * (2^16 * x_hi * G)),
            // so the candidates are doubled before compression to line up with them.
            let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);

            for (point, &x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
                // Single lookup instead of contains_key followed by indexing.
                if let Some(&x_hi) = hashmap.0.get(&point.to_bytes()) {
                    decoded = Some(x_lo + TWO16 * u64::from(x_hi));
                }
            }
        }

        decoded
    }
}
/// Unbounded arithmetic progression over `(point, index)` pairs.
///
/// Each step adds `step.0` to the point and `step.1` to the index. Callers bound it with
/// `take`.
struct RistrettoIterator {
    /// Next pair to be yielded.
    pub current: (RistrettoPoint, u64),
    /// Increment applied to `current` after each yield.
    pub step: (RistrettoPoint, u64),
}
impl RistrettoIterator {
fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
RistrettoIterator { current, step }
}
}
impl Iterator for RistrettoIterator {
    type Item = (RistrettoPoint, u64);

    /// Yields the current pair and advances by `step`; never returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let (point, index) = self.current;
        let (step_point, step_index) = self.step;
        self.current = (point + step_point, index + step_index);
        Some((point, index))
    }
}
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    /// Decodes `amount * G` with the given thread count and compression batch size, and
    /// asserts the amount is recovered.
    fn assert_decodes(amount: u64, num_threads: usize, compression_batch_size: usize) {
        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(num_threads).unwrap();
        instance
            .set_compression_batch_size(compression_batch_size)
            .unwrap();
        assert_eq!(
            Some(amount),
            instance.decode_u32(),
            "amount {amount}, {num_threads} thread(s), batch size {compression_batch_size}"
        );
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            use std::{fs::File, io::Write, path::PathBuf};
            let mut f = File::create(PathBuf::from(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
            ))
            .unwrap();
            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
                .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    #[test]
    fn test_decode_correctness() {
        let amount: u64 = 4294967295;
        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();
        assert_eq!(amount, decoded.unwrap());
        println!("single thread discrete log computation secs: {computation_secs:?} sec");
    }

    #[test]
    fn test_decode_correctness_threaded() {
        let amount: u64 = 55;
        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(4).unwrap();
        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();
        assert_eq!(amount, decoded.unwrap());
        println!("4 thread discrete log computation: {computation_secs:?} sec");

        // Boundary amounts. With 4 threads:
        // - 0..=3 are the first candidate (identity point) of threads 0..=3 respectively;
        // - 2^32 - 1 is the largest decodable value.
        // Run each single-threaded and with 4 threads so the work split is exercised
        // at the edges too.
        for amount in [0, 1, 2, 3, (1_u64 << 32) - 1] {
            for num_threads in [1, 4] {
                assert_decodes(amount, num_threads, 32);
            }
        }
    }

    #[test]
    fn test_decode_small_batch_threaded() {
        // Batch size 1 makes a batch containing only the identity point filter down to
        // nothing. `(1 << 16) + 3` needs a table hit with x_hi = 1.
        for amount in [2, (1_u64 << 16) + 3] {
            assert_decodes(amount, 4, 1);
        }
    }

    #[test]
    fn test_parameter_validation() {
        let mut instance = DiscreteLog::new(G, G);

        for bad in [0, 3, 6, MAX_THREAD * 2] {
            assert_eq!(
                instance.num_threads(bad),
                Err(DiscreteLogError::DiscreteLogThreads),
                "num_threads({bad}) should be rejected"
            );
        }
        for good in [1, 2, 4, MAX_THREAD] {
            assert_eq!(instance.num_threads(good), Ok(()), "num_threads({good})");
        }

        for bad in [0, TWO16 as usize] {
            assert_eq!(
                instance.set_compression_batch_size(bad),
                Err(DiscreteLogError::DiscreteLogBatchSize),
                "set_compression_batch_size({bad}) should be rejected"
            );
        }
        for good in [1, TWO16 as usize - 1] {
            assert_eq!(
                instance.set_compression_batch_size(good),
                Ok(()),
                "set_compression_batch_size({good})"
            );
        }
    }
}