solana_zk_token_sdk/encryption/
auth_encryption.rs1use {
6 crate::errors::AuthenticatedEncryptionError,
7 base64::{prelude::BASE64_STANDARD, Engine},
8 sha3::{Digest, Sha3_512},
9 solana_derivation_path::DerivationPath,
10 solana_sdk::{
11 signature::Signature,
12 signer::{
13 keypair::generate_seed_from_seed_phrase_and_passphrase, EncodableKey, SeedDerivable,
14 Signer, SignerError,
15 },
16 },
17 std::{
18 convert::TryInto,
19 error, fmt,
20 io::{Read, Write},
21 },
22 subtle::ConstantTimeEq,
23 zeroize::Zeroize,
24};
25#[cfg(not(target_os = "solana"))]
26use {
27 aes_gcm_siv::{
28 aead::{Aead, KeyInit},
29 Aes128GcmSiv,
30 },
31 rand::{rngs::OsRng, Rng},
32};
33
/// Byte length of an authenticated encryption key (AES-128).
pub const AE_KEY_LEN: usize = 16;

/// Byte length of an AES-GCM-SIV nonce.
const NONCE_LEN: usize = 12;

/// Byte length of an encrypted `u64` amount: the 8-byte plaintext plus the 16-byte
/// AES-GCM-SIV authentication tag.
const CIPHERTEXT_LEN: usize = std::mem::size_of::<u64>() + 16;

/// Byte length of a serialized [`AeCiphertext`]: the nonce followed by the ciphertext.
const AE_CIPHERTEXT_LEN: usize = NONCE_LEN + CIPHERTEXT_LEN;

47struct AuthenticatedEncryption;
48impl AuthenticatedEncryption {
49 #[cfg(not(target_os = "solana"))]
53 fn keygen() -> AeKey {
54 AeKey(OsRng.gen::<[u8; AE_KEY_LEN]>())
55 }
56
57 #[cfg(not(target_os = "solana"))]
60 fn encrypt(key: &AeKey, balance: u64) -> AeCiphertext {
61 let mut plaintext = balance.to_le_bytes();
62 let nonce: Nonce = OsRng.gen::<[u8; NONCE_LEN]>();
63
64 let ciphertext = Aes128GcmSiv::new(&key.0.into())
66 .encrypt(&nonce.into(), plaintext.as_ref())
67 .expect("authenticated encryption");
68
69 plaintext.zeroize();
70
71 AeCiphertext {
72 nonce,
73 ciphertext: ciphertext.try_into().unwrap(),
74 }
75 }
76
77 #[cfg(not(target_os = "solana"))]
80 fn decrypt(key: &AeKey, ciphertext: &AeCiphertext) -> Option<u64> {
81 let plaintext = Aes128GcmSiv::new(&key.0.into())
82 .decrypt(&ciphertext.nonce.into(), ciphertext.ciphertext.as_ref());
83
84 if let Ok(plaintext) = plaintext {
85 let amount_bytes: [u8; 8] = plaintext.try_into().unwrap();
86 Some(u64::from_le_bytes(amount_bytes))
87 } else {
88 None
89 }
90 }
91}
92
93#[derive(Debug, Zeroize, Eq, PartialEq)]
94pub struct AeKey([u8; AE_KEY_LEN]);
95impl AeKey {
96 pub fn new_from_signer(
103 signer: &dyn Signer,
104 public_seed: &[u8],
105 ) -> Result<Self, Box<dyn error::Error>> {
106 let seed = Self::seed_from_signer(signer, public_seed)?;
107 Self::from_seed(&seed)
108 }
109
110 pub fn seed_from_signer(
114 signer: &dyn Signer,
115 public_seed: &[u8],
116 ) -> Result<Vec<u8>, SignerError> {
117 let message = [b"AeKey", public_seed].concat();
118 let signature = signer.try_sign_message(&message)?;
119
120 if bool::from(signature.as_ref().ct_eq(Signature::default().as_ref())) {
123 return Err(SignerError::Custom("Rejecting default signature".into()));
124 }
125
126 let mut hasher = Sha3_512::new();
127 hasher.update(signature.as_ref());
128 let result = hasher.finalize();
129
130 Ok(result.to_vec())
131 }
132
133 pub fn new_rand() -> Self {
137 AuthenticatedEncryption::keygen()
138 }
139
140 pub fn encrypt(&self, amount: u64) -> AeCiphertext {
142 AuthenticatedEncryption::encrypt(self, amount)
143 }
144
145 pub fn decrypt(&self, ciphertext: &AeCiphertext) -> Option<u64> {
146 AuthenticatedEncryption::decrypt(self, ciphertext)
147 }
148}
149
150impl EncodableKey for AeKey {
151 fn read<R: Read>(reader: &mut R) -> Result<Self, Box<dyn error::Error>> {
152 let bytes: [u8; AE_KEY_LEN] = serde_json::from_reader(reader)?;
153 Ok(Self(bytes))
154 }
155
156 fn write<W: Write>(&self, writer: &mut W) -> Result<String, Box<dyn error::Error>> {
157 let bytes = self.0;
158 let json = serde_json::to_string(&bytes.to_vec())?;
159 writer.write_all(&json.clone().into_bytes())?;
160 Ok(json)
161 }
162}
163
164impl SeedDerivable for AeKey {
165 fn from_seed(seed: &[u8]) -> Result<Self, Box<dyn error::Error>> {
166 const MINIMUM_SEED_LEN: usize = AE_KEY_LEN;
167 const MAXIMUM_SEED_LEN: usize = 65535;
168
169 if seed.len() < MINIMUM_SEED_LEN {
170 return Err(AuthenticatedEncryptionError::SeedLengthTooShort.into());
171 }
172 if seed.len() > MAXIMUM_SEED_LEN {
173 return Err(AuthenticatedEncryptionError::SeedLengthTooLong.into());
174 }
175
176 let mut hasher = Sha3_512::new();
177 hasher.update(seed);
178 let result = hasher.finalize();
179
180 Ok(Self(result[..AE_KEY_LEN].try_into()?))
181 }
182
183 fn from_seed_and_derivation_path(
184 _seed: &[u8],
185 _derivation_path: Option<DerivationPath>,
186 ) -> Result<Self, Box<dyn error::Error>> {
187 Err(AuthenticatedEncryptionError::DerivationMethodNotSupported.into())
188 }
189
190 fn from_seed_phrase_and_passphrase(
191 seed_phrase: &str,
192 passphrase: &str,
193 ) -> Result<Self, Box<dyn error::Error>> {
194 Self::from_seed(&generate_seed_from_seed_phrase_and_passphrase(
195 seed_phrase,
196 passphrase,
197 ))
198 }
199}
200
201impl From<[u8; AE_KEY_LEN]> for AeKey {
202 fn from(bytes: [u8; AE_KEY_LEN]) -> Self {
203 Self(bytes)
204 }
205}
206
207impl From<AeKey> for [u8; AE_KEY_LEN] {
208 fn from(key: AeKey) -> Self {
209 key.0
210 }
211}
212
213impl TryFrom<&[u8]> for AeKey {
214 type Error = AuthenticatedEncryptionError;
215 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
216 if bytes.len() != AE_KEY_LEN {
217 return Err(AuthenticatedEncryptionError::Deserialization);
218 }
219 bytes
220 .try_into()
221 .map(Self)
222 .map_err(|_| AuthenticatedEncryptionError::Deserialization)
223 }
224}
225
226type Nonce = [u8; NONCE_LEN];
229type Ciphertext = [u8; CIPHERTEXT_LEN];
230
231#[derive(Debug, Default, Clone)]
233pub struct AeCiphertext {
234 nonce: Nonce,
235 ciphertext: Ciphertext,
236}
237impl AeCiphertext {
238 pub fn decrypt(&self, key: &AeKey) -> Option<u64> {
239 AuthenticatedEncryption::decrypt(key, self)
240 }
241
242 pub fn to_bytes(&self) -> [u8; AE_CIPHERTEXT_LEN] {
243 let mut buf = [0_u8; AE_CIPHERTEXT_LEN];
244 buf[..NONCE_LEN].copy_from_slice(&self.nonce);
245 buf[NONCE_LEN..].copy_from_slice(&self.ciphertext);
246 buf
247 }
248
249 pub fn from_bytes(bytes: &[u8]) -> Option<AeCiphertext> {
250 if bytes.len() != AE_CIPHERTEXT_LEN {
251 return None;
252 }
253
254 let nonce = bytes[..NONCE_LEN].try_into().ok()?;
255 let ciphertext = bytes[NONCE_LEN..].try_into().ok()?;
256
257 Some(AeCiphertext { nonce, ciphertext })
258 }
259}
260
261impl fmt::Display for AeCiphertext {
262 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
263 write!(f, "{}", BASE64_STANDARD.encode(self.to_bytes()))
264 }
265}
266
#[cfg(test)]
mod tests {
    use {
        super::*,
        solana_sdk::{pubkey::Pubkey, signature::Keypair, signer::null_signer::NullSigner},
    };

    #[test]
    fn test_aes_encrypt_decrypt_correctness() {
        let key = AeKey::new_rand();
        let amount = 55;

        let ciphertext = key.encrypt(amount);
        let decrypted_amount = ciphertext.decrypt(&key).unwrap();

        assert_eq!(amount, decrypted_amount);
    }

    #[test]
    fn test_aes_encrypt_decrypt_boundary_amounts() {
        let key = AeKey::new_rand();
        for amount in [0, 1, u64::MAX] {
            assert_eq!(key.decrypt(&key.encrypt(amount)), Some(amount));
        }
    }

    #[test]
    fn test_aes_decrypt_with_wrong_key_fails() {
        let key = AeKey::from_seed(&[0; 32]).unwrap();
        let other_key = AeKey::from_seed(&[1; 32]).unwrap();

        let ciphertext = key.encrypt(55);
        assert_eq!(ciphertext.decrypt(&other_key), None);
    }

    #[test]
    fn test_aes_decrypt_tampered_ciphertext_fails() {
        let key = AeKey::new_rand();
        let mut bytes = key.encrypt(55).to_bytes();
        // Flip a bit in the authentication tag.
        bytes[AE_CIPHERTEXT_LEN - 1] ^= 1;

        let tampered = AeCiphertext::from_bytes(&bytes).unwrap();
        assert_eq!(tampered.decrypt(&key), None);
    }

    #[test]
    fn test_aes_ciphertext_bytes_round_trip() {
        let key = AeKey::new_rand();
        let ciphertext = key.encrypt(55);

        let bytes = ciphertext.to_bytes();
        let recovered = AeCiphertext::from_bytes(&bytes).unwrap();

        assert_eq!(recovered.to_bytes(), bytes);
        assert_eq!(recovered.decrypt(&key), Some(55));
        // 36 bytes encode to 48 padded base64 characters.
        assert_eq!(ciphertext.to_string().len(), 48);
    }

    #[test]
    fn test_aes_ciphertext_from_bytes_wrong_length() {
        assert!(AeCiphertext::from_bytes(&[]).is_none());
        assert!(AeCiphertext::from_bytes(&[0_u8; AE_CIPHERTEXT_LEN - 1]).is_none());
        assert!(AeCiphertext::from_bytes(&[0_u8; AE_CIPHERTEXT_LEN + 1]).is_none());
    }

    #[test]
    fn test_aes_new() {
        let keypair1 = Keypair::new();
        let keypair2 = Keypair::new();

        assert_ne!(
            AeKey::new_from_signer(&keypair1, Pubkey::default().as_ref())
                .unwrap()
                .0,
            AeKey::new_from_signer(&keypair2, Pubkey::default().as_ref())
                .unwrap()
                .0,
        );

        let null_signer = NullSigner::new(&Pubkey::default());
        assert!(AeKey::new_from_signer(&null_signer, Pubkey::default().as_ref()).is_err());
    }

    #[test]
    fn test_aes_key_from_seed() {
        let good_seed = vec![0; 32];
        assert!(AeKey::from_seed(&good_seed).is_ok());

        // Both seed-length limits are inclusive.
        assert!(AeKey::from_seed(&[0; AE_KEY_LEN]).is_ok());
        assert!(AeKey::from_seed(&vec![0; 65535]).is_ok());

        let too_short_seed = vec![0; 15];
        assert!(AeKey::from_seed(&too_short_seed).is_err());

        let too_long_seed = vec![0; 65536];
        assert!(AeKey::from_seed(&too_long_seed).is_err());
    }

    #[test]
    fn test_aes_key_from() {
        let key = AeKey::from_seed(&[0; 32]).unwrap();
        let key_bytes: [u8; AE_KEY_LEN] = AeKey::from_seed(&[0; 32]).unwrap().into();

        assert_eq!(key, AeKey::from(key_bytes));
    }

    #[test]
    fn test_aes_key_try_from() {
        let key = AeKey::from_seed(&[0; 32]).unwrap();
        let key_bytes: [u8; AE_KEY_LEN] = AeKey::from_seed(&[0; 32]).unwrap().into();

        assert_eq!(key, AeKey::try_from(key_bytes.as_slice()).unwrap());
    }

    #[test]
    fn test_aes_key_try_from_error() {
        let too_many_bytes = vec![0_u8; 32];
        assert!(AeKey::try_from(too_many_bytes.as_slice()).is_err());

        let too_few_bytes = vec![0_u8; AE_KEY_LEN - 1];
        assert!(AeKey::try_from(too_few_bytes.as_slice()).is_err());
    }
}