// solana_zk_sdk/encryption/discrete_log.rs

//! The discrete log implementation for the twisted ElGamal decryption.
//!
//! The implementation uses the baby-step giant-step method, which consists of a precomputation
//! step and an online step. The precomputation step builds a hash table of Ristretto points; this
//! table does not depend on any particular discrete log instance. The online step computes the
//! final discrete log solution from the discrete log instance and the precomputed hash table.
//! More details on the baby-step giant-step algorithm and the implementation can be found in the
//! [spl documentation](https://spl.solana.com).
//!
//! The implementation is NOT intended to run in constant time. There are some measures to prevent
//! straightforward timing attacks. For instance, it does not short-circuit the search when a
//! solution is found. However, the use of hash tables, batching, and threads makes the
//! implementation inherently not constant-time. This may theoretically allow an adversary to gain
//! information on a discrete log solution from the execution time of the implementation.
//!
16
17#[cfg(not(target_arch = "wasm32"))]
18use std::thread;
19use {
20    crate::RISTRETTO_POINT_LEN,
21    curve25519_dalek::{
22        constants::RISTRETTO_BASEPOINT_POINT as G,
23        ristretto::RistrettoPoint,
24        scalar::Scalar,
25        traits::{Identity, IsIdentity},
26    },
27    itertools::Itertools,
28    serde::{Deserialize, Serialize},
29    std::{collections::HashMap, num::NonZeroUsize},
30    thiserror::Error,
31};
32
/// `2^16`: the number of low-part candidates searched online and the size of the precomputed table.
const TWO16: u64 = 1 << 16;
/// `2^17`: the multiplier applied to the generator when building the precomputed table.
const TWO17: u64 = 1 << 17;
35
/// Largest thread count accepted by `DiscreteLog::num_threads` (`2^16`, i.e. one low-part
/// candidate per thread).
#[cfg(not(target_arch = "wasm32"))]
const MAX_THREAD: usize = 1 << 16;
39
40#[derive(Error, Clone, Debug, Eq, PartialEq)]
41pub enum DiscreteLogError {
42    #[error("discrete log number of threads not power-of-two")]
43    DiscreteLogThreads,
44    #[error("discrete log batch size too large")]
45    DiscreteLogBatchSize,
46}
47
/// Type that captures a discrete log challenge.
///
/// The goal of discrete log is to find x such that x * generator = target.
///
/// Build an instance with `DiscreteLog::new`, optionally tune it with `num_threads` and
/// `set_compression_batch_size`, then call `decode_u32`.
#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
pub struct DiscreteLog {
    /// Generator point for discrete log
    ///
    /// NOTE(review): `decode_u32` always searches relative to `G` (via `step_point` and the
    /// `G`-specific precomputed table) and never reads this field; confirm callers only pass `G`.
    pub generator: RistrettoPoint,
    /// Target point for discrete log
    pub target: RistrettoPoint,
    /// Number of threads used for discrete log computation; `None` means the search runs on the
    /// calling thread
    num_threads: Option<NonZeroUsize>,
    /// Number of candidates each search range checks: `2^16` when single-threaded, otherwise
    /// `2^16 / num_threads`
    range_bound: NonZeroUsize,
    /// Distance between consecutive candidates of one search range (`G`, or `num_threads * G`);
    /// the search subtracts it from the running point at every step
    step_point: RistrettoPoint,
    /// Number of points compressed together in one `double_and_compress_batch` call
    compression_batch_size: NonZeroUsize,
}
67
68#[derive(Serialize, Deserialize, Default)]
69pub struct DecodePrecomputation(HashMap<[u8; RISTRETTO_POINT_LEN], u16>);
70
71/// Builds a HashMap of 2^16 elements
72#[allow(dead_code)]
73fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
74    let mut hashmap = HashMap::new();
75
76    let two17_scalar = Scalar::from(TWO17);
77    let identity = RistrettoPoint::identity(); // 0 * G
78    let generator = two17_scalar * generator; // 2^17 * G
79
80    // iterator for 2^17*0G , 2^17*1G, 2^17*2G, ...
81    let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
82    for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
83        let key = point.compress().to_bytes();
84        hashmap.insert(key, x_hi as u16);
85    }
86
87    DecodePrecomputation(hashmap)
88}
89
lazy_static::lazy_static! {
    /// Pre-computed HashMap needed for decryption. The HashMap is independent of (works for) any key.
    ///
    /// Maps the compressed bytes of `x_hi * 2^17 * G` to `x_hi`; its contents are checked against
    /// `decode_u32_precomputation(G)` by `test_serialize_decode_u32_precomputation_for_G`.
    pub static ref DECODE_PRECOMPUTATION_FOR_G: DecodePrecomputation = {
        // The serialized table is embedded into the binary at compile time.
        static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] =
            include_bytes!("decode_u32_precomputation_for_G.bincode");
        // Deliberately falls back to an empty table instead of panicking when the bundled bytes
        // do not deserialize, so the regeneration test can still run and rewrite the file. With
        // an empty table, `decode_range` can only succeed through its identity-point check, i.e.
        // when the solution is itself one of the scanned low-part candidates.
        bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default()
    };
}
98
99/// Solves the discrete log instance using a 16/16 bit offline/online split
100impl DiscreteLog {
101    /// Discrete log instance constructor.
102    ///
103    /// Default number of threads set to 1.
104    pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
105        Self {
106            generator,
107            target,
108            num_threads: None,
109            range_bound: (TWO16 as usize).try_into().unwrap(),
110            step_point: G,
111            compression_batch_size: 32.try_into().unwrap(),
112        }
113    }
114
115    /// Adjusts number of threads in a discrete log instance.
116    #[cfg(not(target_arch = "wasm32"))]
117    pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
118        // number of threads must be a positive power-of-two integer
119        if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
120            return Err(DiscreteLogError::DiscreteLogThreads);
121        }
122
123        self.num_threads = Some(num_threads);
124        self.range_bound = (TWO16 as usize)
125            .checked_div(num_threads.get())
126            .and_then(|range_bound| range_bound.try_into().ok())
127            .unwrap(); // `num_threads` cannot exceed `TWO16`, so `range_bound` always non-zero
128        self.step_point = Scalar::from(num_threads.get() as u64) * G;
129
130        Ok(())
131    }
132
133    /// Adjusts inversion batch size in a discrete log instance.
134    pub fn set_compression_batch_size(
135        &mut self,
136        compression_batch_size: NonZeroUsize,
137    ) -> Result<(), DiscreteLogError> {
138        if compression_batch_size.get() >= TWO16 as usize {
139            return Err(DiscreteLogError::DiscreteLogBatchSize);
140        }
141        self.compression_batch_size = compression_batch_size;
142
143        Ok(())
144    }
145
146    /// Solves the discrete log problem under the assumption that the solution
147    /// is a positive 32-bit number.
148    pub fn decode_u32(self) -> Option<u64> {
149        #[allow(unused_variables)]
150        if let Some(num_threads) = self.num_threads {
151            #[cfg(not(target_arch = "wasm32"))]
152            {
153                let mut starting_point = self.target;
154                let handles = (0..num_threads.get())
155                    .map(|i| {
156                        let ristretto_iterator = RistrettoIterator::new(
157                            (starting_point, i as u64),
158                            (-(&self.step_point), num_threads.get() as u64),
159                        );
160
161                        let handle = thread::spawn(move || {
162                            Self::decode_range(
163                                ristretto_iterator,
164                                self.range_bound,
165                                self.compression_batch_size,
166                            )
167                        });
168
169                        starting_point -= G;
170                        handle
171                    })
172                    .collect::<Vec<_>>();
173
174                handles
175                    .into_iter()
176                    .map_while(|h| h.join().ok())
177                    .find(|x| x.is_some())
178                    .flatten()
179            }
180            #[cfg(target_arch = "wasm32")]
181            unreachable!() // `self.num_threads` always `None` on wasm target
182        } else {
183            let ristretto_iterator =
184                RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));
185
186            Self::decode_range(
187                ristretto_iterator,
188                self.range_bound,
189                self.compression_batch_size,
190            )
191        }
192    }
193
194    fn decode_range(
195        ristretto_iterator: RistrettoIterator,
196        range_bound: NonZeroUsize,
197        compression_batch_size: NonZeroUsize,
198    ) -> Option<u64> {
199        let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
200        let mut decoded = None;
201
202        for batch in &ristretto_iterator
203            .take(range_bound.get())
204            .chunks(compression_batch_size.get())
205        {
206            // batch compression currently errors if any point in the batch is the identity point
207            let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
208                .filter(|(point, index)| {
209                    if point.is_identity() {
210                        decoded = Some(*index);
211                        return false;
212                    }
213                    true
214                })
215                .unzip();
216
217            let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
218
219            for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
220                let key = point.to_bytes();
221                if hashmap.0.contains_key(&key) {
222                    let x_hi = hashmap.0[&key];
223                    decoded = Some(x_lo + TWO16 * x_hi as u64);
224                }
225            }
226        }
227
228        decoded
229    }
230}
231
232/// Hashable Ristretto iterator.
233///
234/// Given an initial point X and a stepping point P, the iterator iterates through
235/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
236struct RistrettoIterator {
237    pub current: (RistrettoPoint, u64),
238    pub step: (RistrettoPoint, u64),
239}
240
241impl RistrettoIterator {
242    fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
243        RistrettoIterator { current, step }
244    }
245}
246
247impl Iterator for RistrettoIterator {
248    type Item = (RistrettoPoint, u64);
249
250    fn next(&mut self) -> Option<Self::Item> {
251        let r = self.current;
252        self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
253        Some(r)
254    }
255}
256
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    #[test]
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);

        // If the bundled table is stale (or failed to deserialize), regenerate it and ask for a
        // rebuild so that `include_bytes!` picks up the new file.
        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            use std::{fs::File, io::Write, path::PathBuf};
            let mut f = File::create(PathBuf::from(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
            ))
            .unwrap();
            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
                .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    #[test]
    fn test_decode_correctness() {
        // general case
        let amount: u64 = 4294967295;

        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

        // Very informal measurements for now
        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("single thread discrete log computation secs: {computation_secs:?} sec");
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_decode_correctness_threaded() {
        // general case
        let amount: u64 = 55;

        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(4.try_into().unwrap()).unwrap();

        // Very informal measurements for now
        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(amount, decoded.unwrap());

        println!("4 thread discrete log computation: {computation_secs:?} sec");

        // edge cases: amounts 0-3 and the max amount, each decoded with the same 4 threads
        for amount in [0_u64, 1, 2, 3, (1_u64 << 32) - 1] {
            let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
            instance.num_threads(4.try_into().unwrap()).unwrap();

            assert_eq!(Some(amount), instance.decode_u32(), "amount {amount}");
        }
    }

    #[test]
    fn test_decode_with_uneven_compression_batch() {
        // 2^16 is not a multiple of 33, so the final batch is a partial one; the low 16 bits of
        // the max amount (0xFFFF) are the last candidate and fall in that partial batch.
        let amount: u64 = (1_u64 << 32) - 1;

        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance
            .set_compression_batch_size(33.try_into().unwrap())
            .unwrap();

        assert_eq!(Some(amount), instance.decode_u32());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_num_threads_validation() {
        let mut instance = DiscreteLog::new(G, G);

        assert_eq!(
            instance.num_threads(3.try_into().unwrap()),
            Err(DiscreteLogError::DiscreteLogThreads)
        );
        assert_eq!(
            instance.num_threads((2 * MAX_THREAD).try_into().unwrap()),
            Err(DiscreteLogError::DiscreteLogThreads)
        );
        assert_eq!(instance.num_threads(MAX_THREAD.try_into().unwrap()), Ok(()));
    }

    #[test]
    fn test_compression_batch_size_validation() {
        let mut instance = DiscreteLog::new(G, G);

        assert_eq!(
            instance.set_compression_batch_size((TWO16 as usize).try_into().unwrap()),
            Err(DiscreteLogError::DiscreteLogBatchSize)
        );
        assert_eq!(
            instance.set_compression_batch_size(64.try_into().unwrap()),
            Ok(())
        );
    }
}