// solana_zk_sdk/encryption/discrete_log.rs
#[cfg(not(target_arch = "wasm32"))]
18use std::thread;
19use {
20 crate::RISTRETTO_POINT_LEN,
21 curve25519_dalek::{
22 constants::RISTRETTO_BASEPOINT_POINT as G,
23 ristretto::RistrettoPoint,
24 scalar::Scalar,
25 traits::{Identity, IsIdentity},
26 },
27 itertools::Itertools,
28 serde::{Deserialize, Serialize},
29 std::{collections::HashMap, num::NonZeroUsize},
30 thiserror::Error,
31};
32
/// 2^16: the number of candidate low halves (`x_lo`) searched per decode and the number of
/// high halves (`x_hi`) stored in the precomputed table.
const TWO16: u64 = 1 << 16;
/// 2^17: the multiplier applied to the generator when building the precomputed table
/// (see `decode_u32_precomputation`).
const TWO17: u64 = 1 << 17;
/// Largest worker-thread count accepted by `DiscreteLog::num_threads`.
#[cfg(not(target_arch = "wasm32"))]
const MAX_THREAD: usize = 1 << 16;
39
40#[derive(Error, Clone, Debug, Eq, PartialEq)]
41pub enum DiscreteLogError {
42 #[error("discrete log number of threads not power-of-two")]
43 DiscreteLogThreads,
44 #[error("discrete log batch size too large")]
45 DiscreteLogBatchSize,
46}
47
/// State for decoding a point `target` back to a value `x` in `0..2^32` with `target == x * G`.
///
/// The search is split into a low half (`x_lo`, found by stepping down from `target`) and a
/// high half (`x_hi`, found via the precomputed table `DECODE_PRECOMPUTATION_FOR_G`).
///
/// NOTE(review): serde derives serialize fields in declaration order, so reordering the
/// fields below changes the serialized form.
#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
pub struct DiscreteLog {
    /// Generator point supplied to `new`.
    /// NOTE(review): decoding always uses the basepoint `G` (step point and precomputed table);
    /// this field is not read by the decoding code in this file.
    pub generator: RistrettoPoint,
    /// Point whose discrete log is being searched for.
    pub target: RistrettoPoint,
    /// Worker-thread count set via `num_threads`; `None` means single-threaded decoding.
    num_threads: Option<NonZeroUsize>,
    /// Number of points each search (per thread) visits: 2^16 by default, or
    /// 2^16 / num_threads after `num_threads` is set.
    range_bound: NonZeroUsize,
    /// Point subtracted from the running point at each search step: `G` by default, or
    /// `num_threads * G` after `num_threads` is set.
    step_point: RistrettoPoint,
    /// Number of points doubled-and-compressed together per batch (default 32).
    compression_batch_size: NonZeroUsize,
}
67
/// Lookup table used to recover the high 16 bits (`x_hi`) of a decoded value.
///
/// As built by `decode_u32_precomputation(generator)`, it maps the compressed encoding of
/// `x_hi * 2^17 * generator` to `x_hi` for every `x_hi` in `0..2^16`.
#[derive(Serialize, Deserialize, Default)]
pub struct DecodePrecomputation(HashMap<[u8; RISTRETTO_POINT_LEN], u16>);
70
71#[allow(dead_code)]
73fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
74 let mut hashmap = HashMap::new();
75
76 let two17_scalar = Scalar::from(TWO17);
77 let identity = RistrettoPoint::identity(); let generator = two17_scalar * generator; let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
82 for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
83 let key = point.compress().to_bytes();
84 hashmap.insert(key, x_hi as u16);
85 }
86
87 DecodePrecomputation(hashmap)
88}
89
90lazy_static::lazy_static! {
91 pub static ref DECODE_PRECOMPUTATION_FOR_G: DecodePrecomputation = {
93 static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] =
94 include_bytes!("decode_u32_precomputation_for_G.bincode");
95 bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default()
96 };
97}
98
99impl DiscreteLog {
101 pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
105 Self {
106 generator,
107 target,
108 num_threads: None,
109 range_bound: (TWO16 as usize).try_into().unwrap(),
110 step_point: G,
111 compression_batch_size: 32.try_into().unwrap(),
112 }
113 }
114
115 #[cfg(not(target_arch = "wasm32"))]
117 pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
118 if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
120 return Err(DiscreteLogError::DiscreteLogThreads);
121 }
122
123 self.num_threads = Some(num_threads);
124 self.range_bound = (TWO16 as usize)
125 .checked_div(num_threads.get())
126 .and_then(|range_bound| range_bound.try_into().ok())
127 .unwrap(); self.step_point = Scalar::from(num_threads.get() as u64) * G;
129
130 Ok(())
131 }
132
133 pub fn set_compression_batch_size(
135 &mut self,
136 compression_batch_size: NonZeroUsize,
137 ) -> Result<(), DiscreteLogError> {
138 if compression_batch_size.get() >= TWO16 as usize {
139 return Err(DiscreteLogError::DiscreteLogBatchSize);
140 }
141 self.compression_batch_size = compression_batch_size;
142
143 Ok(())
144 }
145
146 pub fn decode_u32(self) -> Option<u64> {
149 #[allow(unused_variables)]
150 if let Some(num_threads) = self.num_threads {
151 #[cfg(not(target_arch = "wasm32"))]
152 {
153 let mut starting_point = self.target;
154 let handles = (0..num_threads.get())
155 .map(|i| {
156 let ristretto_iterator = RistrettoIterator::new(
157 (starting_point, i as u64),
158 (-(&self.step_point), num_threads.get() as u64),
159 );
160
161 let handle = thread::spawn(move || {
162 Self::decode_range(
163 ristretto_iterator,
164 self.range_bound,
165 self.compression_batch_size,
166 )
167 });
168
169 starting_point -= G;
170 handle
171 })
172 .collect::<Vec<_>>();
173
174 handles
175 .into_iter()
176 .map_while(|h| h.join().ok())
177 .find(|x| x.is_some())
178 .flatten()
179 }
180 #[cfg(target_arch = "wasm32")]
181 unreachable!() } else {
183 let ristretto_iterator =
184 RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));
185
186 Self::decode_range(
187 ristretto_iterator,
188 self.range_bound,
189 self.compression_batch_size,
190 )
191 }
192 }
193
194 fn decode_range(
195 ristretto_iterator: RistrettoIterator,
196 range_bound: NonZeroUsize,
197 compression_batch_size: NonZeroUsize,
198 ) -> Option<u64> {
199 let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
200 let mut decoded = None;
201
202 for batch in &ristretto_iterator
203 .take(range_bound.get())
204 .chunks(compression_batch_size.get())
205 {
206 let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
208 .filter(|(point, index)| {
209 if point.is_identity() {
210 decoded = Some(*index);
211 return false;
212 }
213 true
214 })
215 .unzip();
216
217 let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
218
219 for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
220 let key = point.to_bytes();
221 if hashmap.0.contains_key(&key) {
222 let x_hi = hashmap.0[&key];
223 decoded = Some(x_lo + TWO16 * x_hi as u64);
224 }
225 }
226 }
227
228 decoded
229 }
230}
231
/// Unbounded iterator over `(point, index)` pairs forming an arithmetic progression.
///
/// Each call to `next` yields `current` and then adds `step` to it component-wise: point
/// addition for the first element, integer addition for the second.
struct RistrettoIterator {
    /// The pair returned by the next call to `next`.
    pub current: (RistrettoPoint, u64),
    /// The amount added to `current` after each yielded item.
    pub step: (RistrettoPoint, u64),
}
240
241impl RistrettoIterator {
242 fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
243 RistrettoIterator { current, step }
244 }
245}
246
247impl Iterator for RistrettoIterator {
248 type Item = (RistrettoPoint, u64);
249
250 fn next(&mut self) -> Option<Self::Item> {
251 let r = self.current;
252 self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
253 Some(r)
254 }
255}
256
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    /// Regenerates the bundled table if it differs from a freshly computed one.
    #[test]
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
        // A stale or undeserializable bundled table shows up as a mismatch here. Rewrite the
        // file and fail, so that `include_bytes!` picks up the new contents on the next build.
        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            use std::{fs::File, io::Write, path::PathBuf};
            let mut f = File::create(PathBuf::from(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
            ))
            .unwrap();
            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
                .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    #[test]
    fn test_decode_correctness() {
        // Largest value a u32 decode can return: 2^32 - 1.
        let amount: u64 = 4294967295;

        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("single thread discrete log computation secs: {computation_secs:?} sec");
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_decode_correctness_threaded() {
        let amount: u64 = 55;

        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(4.try_into().unwrap()).unwrap();

        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("4 thread discrete log computation: {computation_secs:?} sec");

        // Edge cases, decoded both single-threaded and with 4 workers. Amounts 0..=3 each land on
        // the identity point at a different worker's starting offset; 2^32 - 1 is the last
        // step of the last worker.
        for amount in [0_u64, 1, 2, 3, (1_u64 << 32) - 1] {
            let single = DiscreteLog::new(G, Scalar::from(amount) * G);
            assert_eq!(Some(amount), single.decode_u32(), "single-threaded, amount {amount}");

            let mut threaded = DiscreteLog::new(G, Scalar::from(amount) * G);
            threaded.num_threads(4.try_into().unwrap()).unwrap();
            assert_eq!(Some(amount), threaded.decode_u32(), "4 threads, amount {amount}");
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_num_threads_validation() {
        let mut instance = DiscreteLog::new(G, G);

        assert_eq!(
            instance.num_threads(3.try_into().unwrap()),
            Err(DiscreteLogError::DiscreteLogThreads)
        );
        assert_eq!(
            instance.num_threads((2 * MAX_THREAD).try_into().unwrap()),
            Err(DiscreteLogError::DiscreteLogThreads)
        );
        assert_eq!(instance.num_threads(MAX_THREAD.try_into().unwrap()), Ok(()));
    }

    #[test]
    fn test_compression_batch_size() {
        // 70_000 = 4_464 + 2^16 * 1, so both the low and high halves are non-zero.
        let amount: u64 = 70_000;
        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);

        assert_eq!(
            instance.set_compression_batch_size((TWO16 as usize).try_into().unwrap()),
            Err(DiscreteLogError::DiscreteLogBatchSize)
        );

        instance
            .set_compression_batch_size(1.try_into().unwrap())
            .unwrap();
        assert_eq!(Some(amount), instance.decode_u32());
    }
}