safe_zk_token_sdk/encryption/
discrete_log.rs

1//! The discrete log implementation for the twisted ElGamal decryption.
2//!
3//! The implementation uses the baby-step giant-step method, which consists of a precomputation
4//! step and an online step. The precomputation step involves computing a hash table of a number
5//! of Ristretto points that is independent of a discrete log instance. The online phase computes
6//! the final discrete log solution using the discrete log instance and the pre-computed hash
7//! table. More details on the baby-step giant-step algorithm and the implementation can be found
8//! in the [spl documentation](https://spl.solana.com).
9//!
10//! The implementation is NOT intended to run in constant-time. There are some measures to prevent
11//! straightforward timing attacks. For instance, it does not short-circuit the search when a
12//! solution is found. However, the use of hashtables, batching, and threads make the
13//! implementation inherently not constant-time. This may theoretically allow an adversary to gain
14//! information on a discrete log solution depending on the execution time of the implementation.
15//!
16
17#![cfg(not(target_os = "solana"))]
18
19use {
20    crate::encryption::errors::DiscreteLogError,
21    curve25519_dalek::{
22        constants::RISTRETTO_BASEPOINT_POINT as G,
23        ristretto::RistrettoPoint,
24        scalar::Scalar,
25        traits::{Identity, IsIdentity},
26    },
27    itertools::Itertools,
28    serde::{Deserialize, Serialize},
29    std::{collections::HashMap, thread},
30};
31
/// `2^16`: number of values of each 16-bit half (`x_lo`, `x_hi`) of a 32-bit solution.
const TWO16: u64 = 1 << 16;
/// `2^17`: multiplier applied to `x_hi` in the precomputed table (`2 * 2^16`).
const TWO17: u64 = 1 << 17;
34
35/// Type that captures a discrete log challenge.
36///
37/// The goal of discrete log is to find x such that x * generator = target.
38#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
39pub struct DiscreteLog {
40    /// Generator point for discrete log
41    pub generator: RistrettoPoint,
42    /// Target point for discrete log
43    pub target: RistrettoPoint,
44    /// Number of threads used for discrete log computation
45    num_threads: usize,
46    /// Range bound for discrete log search derived from the max value to search for and
47    /// `num_threads`
48    range_bound: usize,
49    /// Ristretto point representing each step of the discrete log search
50    step_point: RistrettoPoint,
51    /// Ristretto point compression batch size
52    compression_batch_size: usize,
53}
54
55#[derive(Serialize, Deserialize, Default)]
56pub struct DecodePrecomputation(HashMap<[u8; 32], u16>);
57
58/// Builds a HashMap of 2^16 elements
59#[allow(dead_code)]
60fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
61    let mut hashmap = HashMap::new();
62
63    let two17_scalar = Scalar::from(TWO17);
64    let identity = RistrettoPoint::identity(); // 0 * G
65    let generator = two17_scalar * generator; // 2^17 * G
66
67    // iterator for 2^17*0G , 2^17*1G, 2^17*2G, ...
68    let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
69    for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
70        let key = point.compress().to_bytes();
71        hashmap.insert(key, x_hi as u16);
72    }
73
74    DecodePrecomputation(hashmap)
75}
76
77lazy_static::lazy_static! {
78    /// Pre-computed HashMap needed for decryption. The HashMap is independent of (works for) any key.
79    pub static ref DECODE_PRECOMPUTATION_FOR_G: DecodePrecomputation = {
80        static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] =
81            include_bytes!("decode_u32_precomputation_for_G.bincode");
82        bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default()
83    };
84}
85
86/// Safeves the discrete log instance using a 16/16 bit offline/online split
87impl DiscreteLog {
88    /// Discrete log instance constructor.
89    ///
90    /// Default number of threads set to 1.
91    pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
92        Self {
93            generator,
94            target,
95            num_threads: 1,
96            range_bound: TWO16 as usize,
97            step_point: G,
98            compression_batch_size: 32,
99        }
100    }
101
102    /// Adjusts number of threads in a discrete log instance.
103    pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> {
104        // number of threads must be a positive power-of-two integer
105        if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > 65536 {
106            return Err(DiscreteLogError::DiscreteLogThreads);
107        }
108
109        self.num_threads = num_threads;
110        self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap();
111        self.step_point = Scalar::from(num_threads as u64) * G;
112
113        Ok(())
114    }
115
116    /// Adjusts inversion batch size in a discrete log instance.
117    pub fn set_compression_batch_size(
118        &mut self,
119        compression_batch_size: usize,
120    ) -> Result<(), DiscreteLogError> {
121        if compression_batch_size >= TWO16 as usize {
122            return Err(DiscreteLogError::DiscreteLogBatchSize);
123        }
124        self.compression_batch_size = compression_batch_size;
125
126        Ok(())
127    }
128
129    /// Safeves the discrete log problem under the assumption that the solution
130    /// is a positive 32-bit number.
131    pub fn decode_u32(self) -> Option<u64> {
132        let mut starting_point = self.target;
133        let handles = (0..self.num_threads)
134            .into_iter()
135            .map(|i| {
136                let ristretto_iterator = RistrettoIterator::new(
137                    (starting_point, i as u64),
138                    (-(&self.step_point), self.num_threads as u64),
139                );
140
141                let handle = thread::spawn(move || {
142                    Self::decode_range(
143                        ristretto_iterator,
144                        self.range_bound,
145                        self.compression_batch_size,
146                    )
147                });
148
149                starting_point -= G;
150                handle
151            })
152            .collect::<Vec<_>>();
153
154        let mut solution = None;
155        for handle in handles {
156            let discrete_log = handle.join().unwrap();
157            if discrete_log.is_some() {
158                solution = discrete_log;
159            }
160        }
161        solution
162    }
163
164    fn decode_range(
165        ristretto_iterator: RistrettoIterator,
166        range_bound: usize,
167        compression_batch_size: usize,
168    ) -> Option<u64> {
169        let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
170        let mut decoded = None;
171
172        for batch in &ristretto_iterator
173            .take(range_bound)
174            .chunks(compression_batch_size)
175        {
176            // batch compression currently errors if any point in the batch is the identity point
177            let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
178                .filter(|(point, index)| {
179                    if point.is_identity() {
180                        decoded = Some(*index);
181                        return false;
182                    }
183                    true
184                })
185                .unzip();
186
187            let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
188
189            for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
190                let key = point.to_bytes();
191                if hashmap.0.contains_key(&key) {
192                    let x_hi = hashmap.0[&key];
193                    decoded = Some(x_lo + TWO16 * x_hi as u64);
194                }
195            }
196        }
197
198        decoded
199    }
200}
201
202/// Hashable Ristretto iterator.
203///
204/// Given an initial point X and a stepping point P, the iterator iterates through
205/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
206struct RistrettoIterator {
207    pub current: (RistrettoPoint, u64),
208    pub step: (RistrettoPoint, u64),
209}
210
211impl RistrettoIterator {
212    fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
213        RistrettoIterator { current, step }
214    }
215}
216
217impl Iterator for RistrettoIterator {
218    type Item = (RistrettoPoint, u64);
219
220    fn next(&mut self) -> Option<Self::Item> {
221        let r = self.current;
222        self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
223        Some(r)
224    }
225}
226
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    /// Runs the discrete log search for `amount * G` using `num_threads` threads.
    fn decode_with_threads(amount: u64, num_threads: usize) -> Option<u64> {
        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(num_threads).unwrap();
        instance.decode_u32()
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);

        // If the embedded table is stale (or failed to deserialize), rewrite it and fail so the
        // crate gets rebuilt with the new bytes. The relative path assumes the test runs from
        // the crate root, as `cargo test` does.
        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            std::fs::write(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
                bincode::serialize(&decode_u32_precomputation_for_G).unwrap(),
            )
            .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    #[test]
    fn test_decode_correctness() {
        // general case: the largest value representable in 32 bits
        let amount = u64::from(u32::MAX);
        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

        // Very informal measurements for now
        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(Some(amount), decoded);

        println!(
            "single thread discrete log computation secs: {:?} sec",
            computation_secs
        );

        // small boundary amounts, single-threaded
        for &amount in &[0u64, 1, 2, 3] {
            assert_eq!(Some(amount), decode_with_threads(amount, 1));
        }
    }

    #[test]
    fn test_decode_correctness_threaded() {
        // general case
        let amount: u64 = 55;

        // Very informal measurements for now
        let start_computation = Instant::now();
        let decoded = decode_with_threads(amount, 4);
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(Some(amount), decoded);

        println!(
            "4 thread discrete log computation: {:?} sec",
            computation_secs
        );

        // boundary amounts, exercised through the multi-threaded search
        for &amount in &[0u64, 1, 2, 3, u64::from(u32::MAX)] {
            assert_eq!(Some(amount), decode_with_threads(amount, 4));
        }
    }

    #[test]
    fn test_decode_correctness_uneven_batches() {
        // A batch size that does not divide the per-thread range (2^14 % 7 != 0) exercises the
        // final, shorter batch.
        let amount: u64 = 0x1234_5678;
        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(4).unwrap();
        instance.set_compression_batch_size(7).unwrap();

        assert_eq!(Some(amount), instance.decode_u32());
    }

    #[test]
    fn test_parameter_validation() {
        let mut instance = DiscreteLog::new(G, G);

        for &invalid in &[0usize, 3, 6, 1 << 17] {
            assert!(instance.num_threads(invalid).is_err());
        }
        for &valid in &[1usize, 2, 4, 1 << 16] {
            assert!(instance.num_threads(valid).is_ok());
        }

        assert!(instance.set_compression_batch_size(TWO16 as usize).is_err());
        assert!(instance
            .set_compression_batch_size(TWO16 as usize - 1)
            .is_ok());
        assert!(instance.set_compression_batch_size(1).is_ok());
    }
}