solana_zk_token_sdk/encryption/
discrete_log.rs

1//! The discrete log implementation for the twisted ElGamal decryption.
2//!
3//! The implementation uses the baby-step giant-step method, which consists of a precomputation
4//! step and an online step. The precomputation step involves computing a hash table of a number
5//! of Ristretto points that is independent of a discrete log instance. The online phase computes
6//! the final discrete log solution using the discrete log instance and the pre-computed hash
7//! table. More details on the baby-step giant-step algorithm and the implementation can be found
8//! in the [spl documentation](https://spl.solana.com).
9//!
//! The implementation is NOT intended to run in constant-time. There are some measures to prevent
//! straightforward timing attacks. For instance, it does not short-circuit the search when a
//! solution is found. However, the use of hashtables, batching, and threads makes the
//! implementation inherently not constant-time. This may theoretically allow an adversary to gain
//! information on a discrete log solution depending on the execution time of the implementation.
15//!
16
17#![cfg(not(target_os = "solana"))]
18
19#[cfg(not(target_arch = "wasm32"))]
20use std::thread;
21use {
22    crate::RISTRETTO_POINT_LEN,
23    curve25519_dalek::{
24        constants::RISTRETTO_BASEPOINT_POINT as G,
25        ristretto::RistrettoPoint,
26        scalar::Scalar,
27        traits::{Identity, IsIdentity},
28    },
29    itertools::Itertools,
30    serde::{Deserialize, Serialize},
31    std::{collections::HashMap, num::NonZeroUsize},
32    thiserror::Error,
33};
34
/// 2^16: size of the precomputed table and of the online search range.
const TWO16: u64 = 1 << 16;
/// 2^17: scalar stride between consecutive entries of the precomputed table.
const TWO17: u64 = 1 << 17;

/// Maximum number of threads permitted for discrete log computation
///
/// Equal to 2^16 so that the per-thread search range (2^16 / threads) can never reach zero.
#[cfg(not(target_arch = "wasm32"))]
const MAX_THREAD: usize = TWO16 as usize;
41
/// Errors returned when configuring a [`DiscreteLog`] instance.
#[derive(Error, Clone, Debug, Eq, PartialEq)]
pub enum DiscreteLogError {
    /// Returned by `DiscreteLog::num_threads` when the requested thread count is not a power of
    /// two, and also when it exceeds `MAX_THREAD` (the message only mentions the first cause).
    #[error("discrete log number of threads not power-of-two")]
    DiscreteLogThreads,
    /// Returned by `DiscreteLog::set_compression_batch_size` when the requested batch size is
    /// 2^16 or larger.
    #[error("discrete log batch size too large")]
    DiscreteLogBatchSize,
}
49
/// Type that captures a discrete log challenge.
///
/// The goal of discrete log is to find x such that x * generator = target.
///
/// Only `generator` and `target` are public; the remaining fields are search parameters that are
/// initialized by `DiscreteLog::new` and adjusted through the setter methods.
#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
pub struct DiscreteLog {
    /// Generator point for discrete log
    // NOTE(review): the solver visible in this file steps by the fixed basepoint `G` and uses a
    // table precomputed for `G`; this field is stored but not read by that code — confirm that
    // callers always pass `G` here.
    pub generator: RistrettoPoint,
    /// Target point for discrete log
    pub target: RistrettoPoint,
    /// Number of threads used for discrete log computation
    ///
    /// `None` (the value set by `new`) means the search runs on the calling thread.
    num_threads: Option<NonZeroUsize>,
    /// Range bound for discrete log search derived from the max value to search for and
    /// `num_threads`
    ///
    /// 2^16 after `new`; 2^16 / `num_threads` once a thread count has been set.
    range_bound: NonZeroUsize,
    /// Ristretto point representing each step of the discrete log search
    ///
    /// `G` after `new`; `num_threads * G` once a thread count has been set.
    step_point: RistrettoPoint,
    /// Ristretto point compression batch size
    ///
    /// 32 after `new`; the setter rejects values of 2^16 or more.
    compression_batch_size: NonZeroUsize,
}
69
/// Lookup table used by the online phase of the discrete log search.
///
/// Maps the compressed byte encoding of a Ristretto point to a 16-bit value. As built by
/// `decode_u32_precomputation`, the key is the encoding of `x_hi * 2^17 * generator` and the value
/// is `x_hi`. The `Default` value is an empty table.
#[derive(Serialize, Deserialize, Default)]
pub struct DecodePrecomputation(HashMap<[u8; RISTRETTO_POINT_LEN], u16>);
72
73/// Builds a HashMap of 2^16 elements
74#[allow(dead_code)]
75fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
76    let mut hashmap = HashMap::new();
77
78    let two17_scalar = Scalar::from(TWO17);
79    let identity = RistrettoPoint::identity(); // 0 * G
80    let generator = two17_scalar * generator; // 2^17 * G
81
82    // iterator for 2^17*0G , 2^17*1G, 2^17*2G, ...
83    let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
84    for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
85        let key = point.compress().to_bytes();
86        hashmap.insert(key, x_hi as u16);
87    }
88
89    DecodePrecomputation(hashmap)
90}
91
lazy_static::lazy_static! {
    /// Pre-computed HashMap needed for decryption. The HashMap is independent of (works for) any key.
    ///
    /// The table is deserialized lazily from a bincode file that is embedded into the binary at
    /// compile time.
    pub static ref DECODE_PRECOMPUTATION_FOR_G: DecodePrecomputation = {
        // serialized table, embedded at compile time from the file next to this module
        static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] =
            include_bytes!("decode_u32_precomputation_for_G.bincode");
        // NOTE: a deserialization failure is not reported; it silently falls back to the
        // `Default` value, i.e. an empty table, in which case no table lookup can succeed.
        bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default()
    };
}
100
101/// Solves the discrete log instance using a 16/16 bit offline/online split
102impl DiscreteLog {
103    /// Discrete log instance constructor.
104    ///
105    /// Default number of threads set to 1.
106    pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
107        Self {
108            generator,
109            target,
110            num_threads: None,
111            range_bound: (TWO16 as usize).try_into().unwrap(),
112            step_point: G,
113            compression_batch_size: 32.try_into().unwrap(),
114        }
115    }
116
117    /// Adjusts number of threads in a discrete log instance.
118    #[cfg(not(target_arch = "wasm32"))]
119    pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
120        // number of threads must be a positive power-of-two integer
121        if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
122            return Err(DiscreteLogError::DiscreteLogThreads);
123        }
124
125        self.num_threads = Some(num_threads);
126        self.range_bound = (TWO16 as usize)
127            .checked_div(num_threads.get())
128            .and_then(|range_bound| range_bound.try_into().ok())
129            .unwrap(); // `num_threads` cannot exceed `TWO16`, so `range_bound` always non-zero
130        self.step_point = Scalar::from(num_threads.get() as u64) * G;
131
132        Ok(())
133    }
134
135    /// Adjusts inversion batch size in a discrete log instance.
136    pub fn set_compression_batch_size(
137        &mut self,
138        compression_batch_size: NonZeroUsize,
139    ) -> Result<(), DiscreteLogError> {
140        if compression_batch_size.get() >= TWO16 as usize {
141            return Err(DiscreteLogError::DiscreteLogBatchSize);
142        }
143        self.compression_batch_size = compression_batch_size;
144
145        Ok(())
146    }
147
148    /// Solves the discrete log problem under the assumption that the solution
149    /// is a positive 32-bit number.
150    pub fn decode_u32(self) -> Option<u64> {
151        if let Some(num_threads) = self.num_threads {
152            #[cfg(not(target_arch = "wasm32"))]
153            {
154                let mut starting_point = self.target;
155                let handles = (0..num_threads.get())
156                    .map(|i| {
157                        let ristretto_iterator = RistrettoIterator::new(
158                            (starting_point, i as u64),
159                            (-(&self.step_point), num_threads.get() as u64),
160                        );
161
162                        let handle = thread::spawn(move || {
163                            Self::decode_range(
164                                ristretto_iterator,
165                                self.range_bound,
166                                self.compression_batch_size,
167                            )
168                        });
169
170                        starting_point -= G;
171                        handle
172                    })
173                    .collect::<Vec<_>>();
174
175                handles
176                    .into_iter()
177                    .map_while(|h| h.join().ok())
178                    .find(|x| x.is_some())
179                    .flatten()
180            }
181            #[cfg(target_arch = "wasm32")]
182            unreachable!() // `self.num_threads` always `None` on wasm target
183        } else {
184            let ristretto_iterator =
185                RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));
186
187            Self::decode_range(
188                ristretto_iterator,
189                self.range_bound,
190                self.compression_batch_size,
191            )
192        }
193    }
194
195    fn decode_range(
196        ristretto_iterator: RistrettoIterator,
197        range_bound: NonZeroUsize,
198        compression_batch_size: NonZeroUsize,
199    ) -> Option<u64> {
200        let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
201        let mut decoded = None;
202
203        for batch in &ristretto_iterator
204            .take(range_bound.get())
205            .chunks(compression_batch_size.get())
206        {
207            // batch compression currently errors if any point in the batch is the identity point
208            let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
209                .filter(|(point, index)| {
210                    if point.is_identity() {
211                        decoded = Some(*index);
212                        return false;
213                    }
214                    true
215                })
216                .unzip();
217
218            let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
219
220            for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
221                let key = point.to_bytes();
222                if hashmap.0.contains_key(&key) {
223                    let x_hi = hashmap.0[&key];
224                    decoded = Some(x_lo + TWO16 * x_hi as u64);
225                }
226            }
227        }
228
229        decoded
230    }
231}
232
/// Ristretto point iterator that tracks a `u64` index alongside each point.
///
/// Given an initial point X and a stepping point P, the iterator iterates through
/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
///
/// Each yielded point is paired with a counter that starts at `current.1` and grows by `step.1`
/// per item. The iterator never ends on its own, so callers bound it (e.g. with `take`).
struct RistrettoIterator {
    /// Next item to yield: the point and its associated index
    pub current: (RistrettoPoint, u64),
    /// Increment applied after each item: point step and index step
    pub step: (RistrettoPoint, u64),
}
241
242impl RistrettoIterator {
243    fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
244        RistrettoIterator { current, step }
245    }
246}
247
248impl Iterator for RistrettoIterator {
249    type Item = (RistrettoPoint, u64);
250
251    fn next(&mut self) -> Option<Self::Item> {
252        let r = self.current;
253        self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
254        Some(r)
255    }
256}
257
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    /// Boundary amounts checked on both the single-threaded and the threaded path.
    const EDGE_AMOUNTS: [u64; 5] = [0, 1, 2, 3, (1_u64 << 32) - 1];

    /// Checks that the embedded bincode table matches a freshly computed one. On mismatch the
    /// file is rewritten and the test fails so that the crate gets rebuilt with the new table.
    #[test]
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);

        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            use std::{fs::File, io::Write, path::PathBuf};
            let mut f = File::create(PathBuf::from(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
            ))
            .unwrap();
            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
                .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    #[test]
    fn test_decode_correctness() {
        // general case
        let amount: u64 = 4294967295;

        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

        // Very informal measurements for now
        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("single thread discrete log computation secs: {computation_secs:?} sec");

        // boundary amounts on the single-threaded path
        for amount in EDGE_AMOUNTS {
            let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
            assert_eq!(Some(amount), instance.decode_u32());
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_decode_correctness_threaded() {
        let four_threads = NonZeroUsize::new(4).unwrap();

        // general case
        let amount: u64 = 55;

        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(four_threads).unwrap();

        // Very informal measurements for now
        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("4 thread discrete log computation: {computation_secs:?} sec");

        // boundary amounts; the thread count is set on every instance so that these cases really
        // go through the threaded path (amounts 0..=3 each land on a different worker's start)
        for amount in EDGE_AMOUNTS {
            let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
            instance.num_threads(four_threads).unwrap();
            assert_eq!(Some(amount), instance.decode_u32());
        }
    }
}