//! Discrete-log decoding of Ristretto points.
//!
//! Recovers a value `x < 2^32` from the point `x * generator` by splitting it as
//! `x = x_lo + 2^16 * x_hi`. The `x_hi` half is looked up in a precomputed table of 2^16
//! points (the table for the Ristretto basepoint `G` is embedded in the binary). The `x_lo`
//! half is found by an online search that can be spread across several threads.
//!
//! The search examines every candidate instead of stopping at the first match. However,
//! hash-table lookups, batching and threading mean its running time is not guaranteed to be
//! constant.

#![cfg(not(target_os = "solana"))]

use {
    crate::encryption::errors::DiscreteLogError,
    curve25519_dalek::{
        constants::RISTRETTO_BASEPOINT_POINT as G,
        ristretto::RistrettoPoint,
        scalar::Scalar,
        traits::{Identity, IsIdentity},
    },
    itertools::Itertools,
    serde::{Deserialize, Serialize},
    std::{collections::HashMap, sync::Arc, thread},
};

32const TWO16: u64 = 65536; const TWO17: u64 = 131072; #[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
39pub struct DiscreteLog {
40 pub generator: RistrettoPoint,
42 pub target: RistrettoPoint,
44 num_threads: usize,
46 range_bound: usize,
49 step_point: RistrettoPoint,
51 compression_batch_size: usize,
53}
54
55#[derive(Serialize, Deserialize, Default)]
56pub struct DecodePrecomputation(HashMap<[u8; 32], u16>);
57
58#[allow(dead_code)]
60fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
61 let mut hashmap = HashMap::new();
62
63 let two17_scalar = Scalar::from(TWO17);
64 let identity = RistrettoPoint::identity(); let generator = two17_scalar * generator; let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
69 for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
70 let key = point.compress().to_bytes();
71 hashmap.insert(key, x_hi as u16);
72 }
73
74 DecodePrecomputation(hashmap)
75}
76
77lazy_static::lazy_static! {
78 pub static ref DECODE_PRECOMPUTATION_FOR_G: DecodePrecomputation = {
80 static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] =
81 include_bytes!("decode_u32_precomputation_for_G.bincode");
82 bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default()
83 };
84}
85
86impl DiscreteLog {
88 pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
92 Self {
93 generator,
94 target,
95 num_threads: 1,
96 range_bound: TWO16 as usize,
97 step_point: G,
98 compression_batch_size: 32,
99 }
100 }
101
102 pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> {
104 if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > 65536 {
106 return Err(DiscreteLogError::DiscreteLogThreads);
107 }
108
109 self.num_threads = num_threads;
110 self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap();
111 self.step_point = Scalar::from(num_threads as u64) * G;
112
113 Ok(())
114 }
115
116 pub fn set_compression_batch_size(
118 &mut self,
119 compression_batch_size: usize,
120 ) -> Result<(), DiscreteLogError> {
121 if compression_batch_size >= TWO16 as usize {
122 return Err(DiscreteLogError::DiscreteLogBatchSize);
123 }
124 self.compression_batch_size = compression_batch_size;
125
126 Ok(())
127 }
128
129 pub fn decode_u32(self) -> Option<u64> {
132 let mut starting_point = self.target;
133 let handles = (0..self.num_threads)
134 .into_iter()
135 .map(|i| {
136 let ristretto_iterator = RistrettoIterator::new(
137 (starting_point, i as u64),
138 (-(&self.step_point), self.num_threads as u64),
139 );
140
141 let handle = thread::spawn(move || {
142 Self::decode_range(
143 ristretto_iterator,
144 self.range_bound,
145 self.compression_batch_size,
146 )
147 });
148
149 starting_point -= G;
150 handle
151 })
152 .collect::<Vec<_>>();
153
154 let mut solution = None;
155 for handle in handles {
156 let discrete_log = handle.join().unwrap();
157 if discrete_log.is_some() {
158 solution = discrete_log;
159 }
160 }
161 solution
162 }
163
164 fn decode_range(
165 ristretto_iterator: RistrettoIterator,
166 range_bound: usize,
167 compression_batch_size: usize,
168 ) -> Option<u64> {
169 let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
170 let mut decoded = None;
171
172 for batch in &ristretto_iterator
173 .take(range_bound)
174 .chunks(compression_batch_size)
175 {
176 let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
178 .filter(|(point, index)| {
179 if point.is_identity() {
180 decoded = Some(*index);
181 return false;
182 }
183 true
184 })
185 .unzip();
186
187 let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
188
189 for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
190 let key = point.to_bytes();
191 if hashmap.0.contains_key(&key) {
192 let x_hi = hashmap.0[&key];
193 decoded = Some(x_lo + TWO16 * x_hi as u64);
194 }
195 }
196 }
197
198 decoded
199 }
200}
201
202struct RistrettoIterator {
207 pub current: (RistrettoPoint, u64),
208 pub step: (RistrettoPoint, u64),
209}
210
211impl RistrettoIterator {
212 fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
213 RistrettoIterator { current, step }
214 }
215}
216
217impl Iterator for RistrettoIterator {
218 type Item = (RistrettoPoint, u64);
219
220 fn next(&mut self) -> Option<Self::Item> {
221 let r = self.current;
222 self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
223 Some(r)
224 }
225}
226
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    #[test]
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
        // If the embedded table is stale, rewrite the file. The new bytes only take effect after
        // recompiling, because the table is embedded with `include_bytes!`.
        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            use std::{fs::File, io::Write, path::PathBuf};
            let mut f = File::create(PathBuf::from(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
            ))
            .unwrap();
            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
                .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    #[test]
    fn test_decode_correctness() {
        // Largest u32: exercises the maximal x_lo and x_hi.
        let amount: u64 = 4294967295;

        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!(
            "single thread discrete log computation secs: {:?} sec",
            computation_secs
        );
    }

    #[test]
    fn test_decode_correctness_threaded() {
        let amount: u64 = 55;

        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(4).unwrap();

        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!(
            "4 thread discrete log computation: {:?} sec",
            computation_secs
        );

        // Boundary values, each decoded with the same 4-thread configuration. Together they hit
        // every worker's residue class (x_lo mod 4), including the identity case at 0.
        for amount in [0_u64, 1, 2, 3, u64::from(u32::MAX)] {
            let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
            instance.num_threads(4).unwrap();

            assert_eq!(Some(amount), instance.decode_u32(), "amount = {}", amount);
        }
    }
}