solana_zk_token_sdk/encryption/
discrete_log.rs1#![cfg(not(target_os = "solana"))]
18
19#[cfg(not(target_arch = "wasm32"))]
20use std::thread;
21use {
22 crate::RISTRETTO_POINT_LEN,
23 curve25519_dalek::{
24 constants::RISTRETTO_BASEPOINT_POINT as G,
25 ristretto::RistrettoPoint,
26 scalar::Scalar,
27 traits::{Identity, IsIdentity},
28 },
29 itertools::Itertools,
30 serde::{Deserialize, Serialize},
31 std::{collections::HashMap, num::NonZeroUsize},
32 thiserror::Error,
33};
34
/// 2^16: length of the low-order (`x_lo`) search range and number of lookup-table entries.
const TWO16: u64 = 1 << 16;
/// 2^17: multiplier applied to the generator when building the lookup table. It is 2^17
/// rather than 2^16 because the search compares against doubled points.
const TWO17: u64 = 1 << 17;

/// Largest thread count accepted by `DiscreteLog::num_threads`.
#[cfg(not(target_arch = "wasm32"))]
const MAX_THREAD: usize = 1 << 16;
41
42#[derive(Error, Clone, Debug, Eq, PartialEq)]
43pub enum DiscreteLogError {
44 #[error("discrete log number of threads not power-of-two")]
45 DiscreteLogThreads,
46 #[error("discrete log batch size too large")]
47 DiscreteLogBatchSize,
48}
49
50#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
54pub struct DiscreteLog {
55 pub generator: RistrettoPoint,
57 pub target: RistrettoPoint,
59 num_threads: Option<NonZeroUsize>,
61 range_bound: NonZeroUsize,
64 step_point: RistrettoPoint,
66 compression_batch_size: NonZeroUsize,
68}
69
70#[derive(Serialize, Deserialize, Default)]
71pub struct DecodePrecomputation(HashMap<[u8; RISTRETTO_POINT_LEN], u16>);
72
73#[allow(dead_code)]
75fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
76 let mut hashmap = HashMap::new();
77
78 let two17_scalar = Scalar::from(TWO17);
79 let identity = RistrettoPoint::identity(); let generator = two17_scalar * generator; let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
84 for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
85 let key = point.compress().to_bytes();
86 hashmap.insert(key, x_hi as u16);
87 }
88
89 DecodePrecomputation(hashmap)
90}
91
92lazy_static::lazy_static! {
93 pub static ref DECODE_PRECOMPUTATION_FOR_G: DecodePrecomputation = {
95 static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] =
96 include_bytes!("decode_u32_precomputation_for_G.bincode");
97 bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default()
98 };
99}
100
101impl DiscreteLog {
103 pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
107 Self {
108 generator,
109 target,
110 num_threads: None,
111 range_bound: (TWO16 as usize).try_into().unwrap(),
112 step_point: G,
113 compression_batch_size: 32.try_into().unwrap(),
114 }
115 }
116
117 #[cfg(not(target_arch = "wasm32"))]
119 pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
120 if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
122 return Err(DiscreteLogError::DiscreteLogThreads);
123 }
124
125 self.num_threads = Some(num_threads);
126 self.range_bound = (TWO16 as usize)
127 .checked_div(num_threads.get())
128 .and_then(|range_bound| range_bound.try_into().ok())
129 .unwrap(); self.step_point = Scalar::from(num_threads.get() as u64) * G;
131
132 Ok(())
133 }
134
135 pub fn set_compression_batch_size(
137 &mut self,
138 compression_batch_size: NonZeroUsize,
139 ) -> Result<(), DiscreteLogError> {
140 if compression_batch_size.get() >= TWO16 as usize {
141 return Err(DiscreteLogError::DiscreteLogBatchSize);
142 }
143 self.compression_batch_size = compression_batch_size;
144
145 Ok(())
146 }
147
148 pub fn decode_u32(self) -> Option<u64> {
151 if let Some(num_threads) = self.num_threads {
152 #[cfg(not(target_arch = "wasm32"))]
153 {
154 let mut starting_point = self.target;
155 let handles = (0..num_threads.get())
156 .map(|i| {
157 let ristretto_iterator = RistrettoIterator::new(
158 (starting_point, i as u64),
159 (-(&self.step_point), num_threads.get() as u64),
160 );
161
162 let handle = thread::spawn(move || {
163 Self::decode_range(
164 ristretto_iterator,
165 self.range_bound,
166 self.compression_batch_size,
167 )
168 });
169
170 starting_point -= G;
171 handle
172 })
173 .collect::<Vec<_>>();
174
175 handles
176 .into_iter()
177 .map_while(|h| h.join().ok())
178 .find(|x| x.is_some())
179 .flatten()
180 }
181 #[cfg(target_arch = "wasm32")]
182 unreachable!() } else {
184 let ristretto_iterator =
185 RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));
186
187 Self::decode_range(
188 ristretto_iterator,
189 self.range_bound,
190 self.compression_batch_size,
191 )
192 }
193 }
194
195 fn decode_range(
196 ristretto_iterator: RistrettoIterator,
197 range_bound: NonZeroUsize,
198 compression_batch_size: NonZeroUsize,
199 ) -> Option<u64> {
200 let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
201 let mut decoded = None;
202
203 for batch in &ristretto_iterator
204 .take(range_bound.get())
205 .chunks(compression_batch_size.get())
206 {
207 let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
209 .filter(|(point, index)| {
210 if point.is_identity() {
211 decoded = Some(*index);
212 return false;
213 }
214 true
215 })
216 .unzip();
217
218 let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
219
220 for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
221 let key = point.to_bytes();
222 if hashmap.0.contains_key(&key) {
223 let x_hi = hashmap.0[&key];
224 decoded = Some(x_lo + TWO16 * x_hi as u64);
225 }
226 }
227 }
228
229 decoded
230 }
231}
232
233struct RistrettoIterator {
238 pub current: (RistrettoPoint, u64),
239 pub step: (RistrettoPoint, u64),
240}
241
242impl RistrettoIterator {
243 fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
244 RistrettoIterator { current, step }
245 }
246}
247
248impl Iterator for RistrettoIterator {
249 type Item = (RistrettoPoint, u64);
250
251 fn next(&mut self) -> Option<Self::Item> {
252 let r = self.current;
253 self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
254 Some(r)
255 }
256}
257
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    /// Compares a freshly built table for `G` with the embedded one. On mismatch, it
    /// rewrites the bincode file and fails, so the next build embeds the new bytes.
    #[test]
    // `G` mirrors the mathematical notation for the basepoint.
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            use std::{fs::File, io::Write, path::PathBuf};
            let mut f = File::create(PathBuf::from(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
            ))
            .unwrap();
            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
                .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    /// Single-threaded decode of u32::MAX, which needs the largest `x_lo` and `x_hi`.
    #[test]
    fn test_decode_correctness() {
        let amount: u64 = 4294967295;

        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("single thread discrete log computation secs: {computation_secs:?} sec");
    }

    /// Threaded decoding, checked against the single-threaded path on edge cases.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_decode_correctness_threaded() {
        let amount: u64 = 55;

        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(4.try_into().unwrap()).unwrap();

        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("4 thread discrete log computation: {computation_secs:?} sec");

        // With 4 threads, amounts 0..=3 are each found through the identity check in a
        // different worker. u32::MAX needs the largest x_lo and x_hi. Check both modes.
        for amount in [0_u64, 1, 2, 3, (1_u64 << 32) - 1] {
            let single = DiscreteLog::new(G, Scalar::from(amount) * G);
            assert_eq!(Some(amount), single.decode_u32());

            let mut threaded = DiscreteLog::new(G, Scalar::from(amount) * G);
            threaded.num_threads(4.try_into().unwrap()).unwrap();
            assert_eq!(Some(amount), threaded.decode_u32());
        }

        // 2^32 would need x_hi = 2^16, which is outside the table, so it must not decode.
        let out_of_range = DiscreteLog::new(G, Scalar::from(1_u64 << 32) * G);
        assert_eq!(None, out_of_range.decode_u32());
    }
}