solana_zk_token_sdk/encryption/
grouped_elgamal.rs1use {
15 crate::{
16 encryption::{
17 discrete_log::DiscreteLog,
18 elgamal::{DecryptHandle, ElGamalCiphertext, ElGamalPubkey, ElGamalSecretKey},
19 pedersen::{Pedersen, PedersenCommitment, PedersenOpening},
20 },
21 RISTRETTO_POINT_LEN,
22 },
23 curve25519_dalek::scalar::Scalar,
24 thiserror::Error,
25};
26
27#[derive(Error, Clone, Debug, Eq, PartialEq)]
28pub enum GroupedElGamalError {
29 #[error("index out of bounds")]
30 IndexOutOfBounds,
31}
32
33pub struct GroupedElGamal<const N: usize>;
35impl<const N: usize> GroupedElGamal<N> {
36 pub fn encrypt<T: Into<Scalar>>(
40 pubkeys: [&ElGamalPubkey; N],
41 amount: T,
42 ) -> GroupedElGamalCiphertext<N> {
43 let (commitment, opening) = Pedersen::new(amount);
44 let handles: [DecryptHandle; N] = pubkeys
45 .iter()
46 .map(|handle| handle.decrypt_handle(&opening))
47 .collect::<Vec<DecryptHandle>>()
48 .try_into()
49 .unwrap();
50
51 GroupedElGamalCiphertext {
52 commitment,
53 handles,
54 }
55 }
56
57 pub fn encrypt_with<T: Into<Scalar>>(
60 pubkeys: [&ElGamalPubkey; N],
61 amount: T,
62 opening: &PedersenOpening,
63 ) -> GroupedElGamalCiphertext<N> {
64 let commitment = Pedersen::with(amount, opening);
65 let handles: [DecryptHandle; N] = pubkeys
66 .iter()
67 .map(|handle| handle.decrypt_handle(opening))
68 .collect::<Vec<DecryptHandle>>()
69 .try_into()
70 .unwrap();
71
72 GroupedElGamalCiphertext {
73 commitment,
74 handles,
75 }
76 }
77
78 fn to_elgamal_ciphertext(
81 grouped_ciphertext: &GroupedElGamalCiphertext<N>,
82 index: usize,
83 ) -> Result<ElGamalCiphertext, GroupedElGamalError> {
84 let handle = grouped_ciphertext
85 .handles
86 .get(index)
87 .ok_or(GroupedElGamalError::IndexOutOfBounds)?;
88
89 Ok(ElGamalCiphertext {
90 commitment: grouped_ciphertext.commitment,
91 handle: *handle,
92 })
93 }
94
95 fn decrypt(
101 grouped_ciphertext: &GroupedElGamalCiphertext<N>,
102 secret: &ElGamalSecretKey,
103 index: usize,
104 ) -> Result<DiscreteLog, GroupedElGamalError> {
105 Self::to_elgamal_ciphertext(grouped_ciphertext, index)
106 .map(|ciphertext| ciphertext.decrypt(secret))
107 }
108
109 fn decrypt_u32(
115 grouped_ciphertext: &GroupedElGamalCiphertext<N>,
116 secret: &ElGamalSecretKey,
117 index: usize,
118 ) -> Result<Option<u64>, GroupedElGamalError> {
119 Self::to_elgamal_ciphertext(grouped_ciphertext, index)
120 .map(|ciphertext| ciphertext.decrypt_u32(secret))
121 }
122}
123
124#[derive(Clone, Copy, Debug, Eq, PartialEq)]
129pub struct GroupedElGamalCiphertext<const N: usize> {
130 pub commitment: PedersenCommitment,
131 pub handles: [DecryptHandle; N],
132}
133
134impl<const N: usize> GroupedElGamalCiphertext<N> {
135 pub fn decrypt(
141 &self,
142 secret: &ElGamalSecretKey,
143 index: usize,
144 ) -> Result<DiscreteLog, GroupedElGamalError> {
145 GroupedElGamal::decrypt(self, secret, index)
146 }
147
148 pub fn decrypt_u32(
154 &self,
155 secret: &ElGamalSecretKey,
156 index: usize,
157 ) -> Result<Option<u64>, GroupedElGamalError> {
158 GroupedElGamal::decrypt_u32(self, secret, index)
159 }
160
161 fn expected_byte_length() -> usize {
168 N.checked_add(1)
169 .and_then(|length| length.checked_mul(RISTRETTO_POINT_LEN))
170 .unwrap()
171 }
172
173 pub fn to_bytes(&self) -> Vec<u8> {
174 let mut buf = Vec::with_capacity(Self::expected_byte_length());
175 buf.extend_from_slice(&self.commitment.to_bytes());
176 self.handles
177 .iter()
178 .for_each(|handle| buf.extend_from_slice(&handle.to_bytes()));
179 buf
180 }
181
182 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
183 if bytes.len() != Self::expected_byte_length() {
184 return None;
185 }
186
187 let mut iter = bytes.chunks(RISTRETTO_POINT_LEN);
188 let commitment = PedersenCommitment::from_bytes(iter.next()?)?;
189
190 let mut handles = Vec::with_capacity(N);
191 for handle_bytes in iter {
192 handles.push(DecryptHandle::from_bytes(handle_bytes)?);
193 }
194
195 Some(Self {
196 commitment,
197 handles: handles.try_into().unwrap(),
198 })
199 }
200}
201
#[cfg(test)]
mod tests {
    use {super::*, crate::encryption::elgamal::ElGamalKeypair};

    #[test]
    fn test_grouped_elgamal_encrypt_decrypt_correctness() {
        let elgamal_keypair_0 = ElGamalKeypair::new_rand();
        let elgamal_keypair_1 = ElGamalKeypair::new_rand();
        let elgamal_keypair_2 = ElGamalKeypair::new_rand();

        let amount: u64 = 10;
        let grouped_ciphertext = GroupedElGamal::encrypt(
            [
                elgamal_keypair_0.pubkey(),
                elgamal_keypair_1.pubkey(),
                elgamal_keypair_2.pubkey(),
            ],
            amount,
        );

        // Each recipient decrypts its own component with its own secret key.
        let keypairs = [&elgamal_keypair_0, &elgamal_keypair_1, &elgamal_keypair_2];
        for (index, keypair) in keypairs.iter().enumerate() {
            assert_eq!(
                Some(amount),
                grouped_ciphertext
                    .decrypt_u32(keypair.secret(), index)
                    .unwrap()
            );
        }

        assert_eq!(
            GroupedElGamalError::IndexOutOfBounds,
            grouped_ciphertext
                .decrypt_u32(elgamal_keypair_0.secret(), 3)
                .unwrap_err()
        );
    }

    #[test]
    fn test_grouped_ciphertext_bytes() {
        let elgamal_keypair_0 = ElGamalKeypair::new_rand();
        let elgamal_keypair_1 = ElGamalKeypair::new_rand();
        let elgamal_keypair_2 = ElGamalKeypair::new_rand();

        let amount: u64 = 10;
        let grouped_ciphertext = GroupedElGamal::encrypt(
            [
                elgamal_keypair_0.pubkey(),
                elgamal_keypair_1.pubkey(),
                elgamal_keypair_2.pubkey(),
            ],
            amount,
        );

        // One commitment plus three handles, 32 bytes each.
        let produced_bytes = grouped_ciphertext.to_bytes();
        assert_eq!(produced_bytes.len(), 128);

        let decoded_grouped_ciphertext =
            GroupedElGamalCiphertext::<3>::from_bytes(&produced_bytes).unwrap();
        assert_eq!(decoded_grouped_ciphertext, grouped_ciphertext);

        let keypairs = [&elgamal_keypair_0, &elgamal_keypair_1, &elgamal_keypair_2];
        for (index, keypair) in keypairs.iter().enumerate() {
            assert_eq!(
                Some(amount),
                decoded_grouped_ciphertext
                    .decrypt_u32(keypair.secret(), index)
                    .unwrap()
            );
        }

        // Inputs that are not exactly 128 bytes long must be rejected, not panic.
        assert!(GroupedElGamalCiphertext::<3>::from_bytes(&[]).is_none());
        assert!(GroupedElGamalCiphertext::<3>::from_bytes(&produced_bytes[..127]).is_none());
        let mut extended_bytes = produced_bytes.clone();
        extended_bytes.extend_from_slice(&[0u8; 32]);
        assert!(GroupedElGamalCiphertext::<3>::from_bytes(&extended_bytes).is_none());
    }
}