// solana_zk_sdk/encryption/auth_encryption.rs
use {
6 crate::{
7 encryption::{AE_CIPHERTEXT_LEN, AE_KEY_LEN},
8 errors::AuthenticatedEncryptionError,
9 },
10 aes_gcm_siv::{
11 aead::{Aead, KeyInit},
12 Aes128GcmSiv,
13 },
14 base64::{prelude::BASE64_STANDARD, Engine},
15 rand::{rngs::OsRng, Rng},
16 std::{convert::TryInto, fmt},
17 zeroize::Zeroize,
18};
19#[cfg(not(target_arch = "wasm32"))]
24use {
25 sha3::Digest,
26 sha3::Sha3_512,
27 solana_derivation_path::DerivationPath,
28 solana_sdk::{
29 signature::Signature,
30 signer::{
31 keypair::generate_seed_from_seed_phrase_and_passphrase, EncodableKey, SeedDerivable,
32 Signer, SignerError,
33 },
34 },
35 std::{
36 error,
37 io::{Read, Write},
38 },
39 subtle::ConstantTimeEq,
40};
41
/// Length in bytes of the random AES-GCM-SIV nonce stored alongside each ciphertext.
const NONCE_LEN: usize = 12;

/// Length in bytes of a sealed amount: the 8-byte little-endian `u64` plaintext
/// followed by the 16-byte authentication tag produced by AES-GCM-SIV.
const CIPHERTEXT_LEN: usize = 8 + 16;
47
48struct AuthenticatedEncryption;
49impl AuthenticatedEncryption {
50 fn keygen() -> AeKey {
54 AeKey(OsRng.gen::<[u8; AE_KEY_LEN]>())
55 }
56
57 fn encrypt(key: &AeKey, balance: u64) -> AeCiphertext {
60 let mut plaintext = balance.to_le_bytes();
61 let nonce: Nonce = OsRng.gen::<[u8; NONCE_LEN]>();
62
63 let ciphertext = Aes128GcmSiv::new(&key.0.into())
65 .encrypt(&nonce.into(), plaintext.as_ref())
66 .expect("authenticated encryption");
67
68 plaintext.zeroize();
69
70 AeCiphertext {
71 nonce,
72 ciphertext: ciphertext.try_into().unwrap(),
73 }
74 }
75
76 fn decrypt(key: &AeKey, ciphertext: &AeCiphertext) -> Option<u64> {
79 let plaintext = Aes128GcmSiv::new(&key.0.into())
80 .decrypt(&ciphertext.nonce.into(), ciphertext.ciphertext.as_ref());
81
82 if let Ok(plaintext) = plaintext {
83 let amount_bytes: [u8; 8] = plaintext.try_into().unwrap();
84 Some(u64::from_le_bytes(amount_bytes))
85 } else {
86 None
87 }
88 }
89}
90
91#[derive(Clone, Debug, Zeroize, Eq, PartialEq)]
92pub struct AeKey([u8; AE_KEY_LEN]);
93impl AeKey {
94 pub fn new_rand() -> Self {
98 AuthenticatedEncryption::keygen()
99 }
100
101 pub fn encrypt(&self, amount: u64) -> AeCiphertext {
103 AuthenticatedEncryption::encrypt(self, amount)
104 }
105
106 pub fn decrypt(&self, ciphertext: &AeCiphertext) -> Option<u64> {
107 AuthenticatedEncryption::decrypt(self, ciphertext)
108 }
109}
110
111#[cfg(not(target_arch = "wasm32"))]
112impl AeKey {
113 pub fn new_from_signer(
120 signer: &dyn Signer,
121 public_seed: &[u8],
122 ) -> Result<Self, Box<dyn error::Error>> {
123 let seed = Self::seed_from_signer(signer, public_seed)?;
124 Self::from_seed(&seed)
125 }
126
127 pub fn seed_from_signer(
131 signer: &dyn Signer,
132 public_seed: &[u8],
133 ) -> Result<Vec<u8>, SignerError> {
134 let message = [b"AeKey", public_seed].concat();
135 let signature = signer.try_sign_message(&message)?;
136
137 if bool::from(signature.as_ref().ct_eq(Signature::default().as_ref())) {
140 return Err(SignerError::Custom("Rejecting default signature".into()));
141 }
142
143 Ok(Self::seed_from_signature(&signature))
144 }
145
146 pub fn new_from_signature(signature: &Signature) -> Result<Self, Box<dyn error::Error>> {
148 let seed = Self::seed_from_signature(signature);
149 Self::from_seed(&seed)
150 }
151
152 pub fn seed_from_signature(signature: &Signature) -> Vec<u8> {
154 let mut hasher = Sha3_512::new();
155 hasher.update(signature);
156 let result = hasher.finalize();
157
158 result.to_vec()
159 }
160}
161
162#[cfg(not(target_arch = "wasm32"))]
163impl EncodableKey for AeKey {
164 fn read<R: Read>(reader: &mut R) -> Result<Self, Box<dyn error::Error>> {
165 let bytes: [u8; AE_KEY_LEN] = serde_json::from_reader(reader)?;
166 Ok(Self(bytes))
167 }
168
169 fn write<W: Write>(&self, writer: &mut W) -> Result<String, Box<dyn error::Error>> {
170 let bytes = self.0;
171 let json = serde_json::to_string(&bytes.to_vec())?;
172 writer.write_all(&json.clone().into_bytes())?;
173 Ok(json)
174 }
175}
176
177#[cfg(not(target_arch = "wasm32"))]
178impl SeedDerivable for AeKey {
179 fn from_seed(seed: &[u8]) -> Result<Self, Box<dyn error::Error>> {
180 const MINIMUM_SEED_LEN: usize = AE_KEY_LEN;
181 const MAXIMUM_SEED_LEN: usize = 65535;
182
183 if seed.len() < MINIMUM_SEED_LEN {
184 return Err(AuthenticatedEncryptionError::SeedLengthTooShort.into());
185 }
186 if seed.len() > MAXIMUM_SEED_LEN {
187 return Err(AuthenticatedEncryptionError::SeedLengthTooLong.into());
188 }
189
190 let mut hasher = Sha3_512::new();
191 hasher.update(seed);
192 let result = hasher.finalize();
193
194 Ok(Self(result[..AE_KEY_LEN].try_into()?))
195 }
196
197 fn from_seed_and_derivation_path(
198 _seed: &[u8],
199 _derivation_path: Option<DerivationPath>,
200 ) -> Result<Self, Box<dyn error::Error>> {
201 Err(AuthenticatedEncryptionError::DerivationMethodNotSupported.into())
202 }
203
204 fn from_seed_phrase_and_passphrase(
205 seed_phrase: &str,
206 passphrase: &str,
207 ) -> Result<Self, Box<dyn error::Error>> {
208 Self::from_seed(&generate_seed_from_seed_phrase_and_passphrase(
209 seed_phrase,
210 passphrase,
211 ))
212 }
213}
214
215impl From<[u8; AE_KEY_LEN]> for AeKey {
216 fn from(bytes: [u8; AE_KEY_LEN]) -> Self {
217 Self(bytes)
218 }
219}
220
221impl From<AeKey> for [u8; AE_KEY_LEN] {
222 fn from(key: AeKey) -> Self {
223 key.0
224 }
225}
226
227impl TryFrom<&[u8]> for AeKey {
228 type Error = AuthenticatedEncryptionError;
229 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
230 if bytes.len() != AE_KEY_LEN {
231 return Err(AuthenticatedEncryptionError::Deserialization);
232 }
233 bytes
234 .try_into()
235 .map(Self)
236 .map_err(|_| AuthenticatedEncryptionError::Deserialization)
237 }
238}
239
240type Nonce = [u8; NONCE_LEN];
243type Ciphertext = [u8; CIPHERTEXT_LEN];
244
245#[derive(Clone, Copy, Debug, Default)]
247pub struct AeCiphertext {
248 nonce: Nonce,
249 ciphertext: Ciphertext,
250}
251impl AeCiphertext {
252 pub fn decrypt(&self, key: &AeKey) -> Option<u64> {
253 AuthenticatedEncryption::decrypt(key, self)
254 }
255
256 pub fn to_bytes(&self) -> [u8; AE_CIPHERTEXT_LEN] {
257 let mut buf = [0_u8; AE_CIPHERTEXT_LEN];
258 buf[..NONCE_LEN].copy_from_slice(&self.nonce);
259 buf[NONCE_LEN..].copy_from_slice(&self.ciphertext);
260 buf
261 }
262
263 pub fn from_bytes(bytes: &[u8]) -> Option<AeCiphertext> {
264 if bytes.len() != AE_CIPHERTEXT_LEN {
265 return None;
266 }
267
268 let nonce = bytes[..NONCE_LEN].try_into().ok()?;
269 let ciphertext = bytes[NONCE_LEN..].try_into().ok()?;
270
271 Some(AeCiphertext { nonce, ciphertext })
272 }
273}
274
275impl fmt::Display for AeCiphertext {
276 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
277 write!(f, "{}", BASE64_STANDARD.encode(self.to_bytes()))
278 }
279}
280
#[cfg(test)]
mod tests {
    use {
        super::*,
        solana_sdk::{pubkey::Pubkey, signature::Keypair, signer::null_signer::NullSigner},
    };

    #[test]
    fn test_aes_encrypt_decrypt_correctness() {
        let key = AeKey::new_rand();
        let amount = 55;

        let ciphertext = key.encrypt(amount);
        let decrypted_amount = ciphertext.decrypt(&key).unwrap();

        assert_eq!(amount, decrypted_amount);
    }

    #[test]
    fn test_aes_decrypt_with_wrong_key_fails() {
        let key = AeKey::new_rand();
        let other_key = AeKey::new_rand();

        let ciphertext = key.encrypt(55);
        assert_eq!(other_key.decrypt(&ciphertext), None);
    }

    #[test]
    fn test_aes_tampered_ciphertext_fails() {
        let key = AeKey::new_rand();
        let mut bytes = key.encrypt(55).to_bytes();
        // Flip one bit of the authentication tag region.
        bytes[AE_CIPHERTEXT_LEN - 1] ^= 1;

        let tampered = AeCiphertext::from_bytes(&bytes).unwrap();
        assert_eq!(key.decrypt(&tampered), None);
    }

    #[test]
    fn test_aes_ciphertext_bytes_round_trip() {
        let key = AeKey::new_rand();
        let ciphertext = key.encrypt(u64::MAX);

        let bytes = ciphertext.to_bytes();
        let decoded = AeCiphertext::from_bytes(&bytes).unwrap();
        assert_eq!(decoded.to_bytes(), bytes);
        assert_eq!(key.decrypt(&decoded), Some(u64::MAX));

        assert!(AeCiphertext::from_bytes(&bytes[1..]).is_none());
        assert!(AeCiphertext::from_bytes(&[]).is_none());
    }

    #[test]
    fn test_aes_new() {
        let keypair1 = Keypair::new();
        let keypair2 = Keypair::new();

        assert_ne!(
            AeKey::new_from_signer(&keypair1, Pubkey::default().as_ref())
                .unwrap()
                .0,
            AeKey::new_from_signer(&keypair2, Pubkey::default().as_ref())
                .unwrap()
                .0,
        );

        let null_signer = NullSigner::new(&Pubkey::default());
        assert!(AeKey::new_from_signer(&null_signer, Pubkey::default().as_ref()).is_err());
    }

    #[test]
    fn test_aes_key_from_seed() {
        let good_seed = vec![0; 32];
        assert!(AeKey::from_seed(&good_seed).is_ok());

        let too_short_seed = vec![0; 15];
        assert!(AeKey::from_seed(&too_short_seed).is_err());

        let too_long_seed = vec![0; 65536];
        assert!(AeKey::from_seed(&too_long_seed).is_err());
    }

    #[test]
    fn test_aes_key_from() {
        let key = AeKey::from_seed(&[0; 32]).unwrap();
        let key_bytes: [u8; AE_KEY_LEN] = AeKey::from_seed(&[0; 32]).unwrap().into();

        assert_eq!(key, AeKey::from(key_bytes));
    }

    #[test]
    fn test_aes_key_try_from() {
        let key = AeKey::from_seed(&[0; 32]).unwrap();
        let key_bytes: [u8; AE_KEY_LEN] = AeKey::from_seed(&[0; 32]).unwrap().into();

        assert_eq!(key, AeKey::try_from(key_bytes.as_slice()).unwrap());
    }

    #[test]
    fn test_aes_key_try_from_error() {
        let too_many_bytes = vec![0_u8; 32];
        assert!(AeKey::try_from(too_many_bytes.as_slice()).is_err());

        let too_few_bytes = vec![0_u8; AE_KEY_LEN - 1];
        assert!(AeKey::try_from(too_few_bytes.as_slice()).is_err());
    }

    #[test]
    fn test_aes_key_encodable_round_trip() {
        let key = AeKey::new_rand();

        let mut buffer = Vec::new();
        key.write(&mut buffer).unwrap();
        let decoded = AeKey::read(&mut buffer.as_slice()).unwrap();

        assert_eq!(key, decoded);
    }
}