solana_zk_sdk/encryption/
auth_encryption.rs

1//! Authenticated encryption implementation.
2//!
3//! This module is a simple wrapper of the `Aes128GcmSiv` implementation specialized for SPL
4//! token-2022 where the plaintext is always `u64`.
5#[cfg(target_arch = "wasm32")]
6use wasm_bindgen::prelude::*;
7use {
8    crate::{
9        encryption::{AE_CIPHERTEXT_LEN, AE_KEY_LEN},
10        errors::AuthenticatedEncryptionError,
11    },
12    aes_gcm_siv::{
13        aead::{Aead, KeyInit},
14        Aes128GcmSiv,
15    },
16    base64::{prelude::BASE64_STANDARD, Engine},
17    rand::{rngs::OsRng, Rng},
18    std::{convert::TryInto, fmt},
19    zeroize::Zeroize,
20};
// Currently, `wasm_bindgen` exports not only the types and functions included in the current
// crate, but also all types and functions exported for wasm targets in all of its dependencies
// (https://github.com/rustwasm/wasm-bindgen/issues/3759). We specifically exclude some of the
// dependencies that will cause unnecessary bloat to the wasm binary.
25#[cfg(not(target_arch = "wasm32"))]
26use {
27    sha3::Digest,
28    sha3::Sha3_512,
29    solana_derivation_path::DerivationPath,
30    solana_seed_derivable::SeedDerivable,
31    solana_seed_phrase::generate_seed_from_seed_phrase_and_passphrase,
32    solana_signature::Signature,
33    solana_signer::{EncodableKey, Signer, SignerError},
34    std::{
35        error,
36        io::{Read, Write},
37    },
38    subtle::ConstantTimeEq,
39};
40
41/// Byte length of an authenticated encryption nonce component
42const NONCE_LEN: usize = 12;
43
44/// Byte length of an authenticated encryption ciphertext component
45const CIPHERTEXT_LEN: usize = 24;
46
47struct AuthenticatedEncryption;
48impl AuthenticatedEncryption {
49    /// Generates an authenticated encryption key.
50    ///
51    /// This function is randomized. It internally samples a 128-bit key using `OsRng`.
52    fn keygen() -> AeKey {
53        AeKey(OsRng.gen::<[u8; AE_KEY_LEN]>())
54    }
55
56    /// On input of an authenticated encryption key and an amount, the function returns a
57    /// corresponding authenticated encryption ciphertext.
58    fn encrypt(key: &AeKey, balance: u64) -> AeCiphertext {
59        let mut plaintext = balance.to_le_bytes();
60        let nonce: Nonce = OsRng.gen::<[u8; NONCE_LEN]>();
61
62        // The balance and the nonce have fixed length and therefore, encryption should not fail.
63        let ciphertext = Aes128GcmSiv::new(&key.0.into())
64            .encrypt(&nonce.into(), plaintext.as_ref())
65            .expect("authenticated encryption");
66
67        plaintext.zeroize();
68
69        AeCiphertext {
70            nonce,
71            ciphertext: ciphertext.try_into().unwrap(),
72        }
73    }
74
75    /// On input of an authenticated encryption key and a ciphertext, the function returns the
76    /// originally encrypted amount.
77    fn decrypt(key: &AeKey, ciphertext: &AeCiphertext) -> Option<u64> {
78        let plaintext = Aes128GcmSiv::new(&key.0.into())
79            .decrypt(&ciphertext.nonce.into(), ciphertext.ciphertext.as_ref());
80
81        if let Ok(plaintext) = plaintext {
82            let amount_bytes: [u8; 8] = plaintext.try_into().unwrap();
83            Some(u64::from_le_bytes(amount_bytes))
84        } else {
85            None
86        }
87    }
88}
89
90#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
91#[derive(Clone, Debug, Zeroize, Eq, PartialEq)]
92pub struct AeKey([u8; AE_KEY_LEN]);
93#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
94impl AeKey {
95    /// Generates a random authenticated encryption key.
96    ///
97    /// This function is randomized. It internally samples a scalar element using `OsRng`.
98    #[cfg_attr(target_arch = "wasm32", wasm_bindgen(js_name = newRand))]
99    pub fn new_rand() -> Self {
100        AuthenticatedEncryption::keygen()
101    }
102
103    /// Encrypts an amount under the authenticated encryption key.
104    pub fn encrypt(&self, amount: u64) -> AeCiphertext {
105        AuthenticatedEncryption::encrypt(self, amount)
106    }
107
108    pub fn decrypt(&self, ciphertext: &AeCiphertext) -> Option<u64> {
109        AuthenticatedEncryption::decrypt(self, ciphertext)
110    }
111}
112
113#[cfg(not(target_arch = "wasm32"))]
114impl AeKey {
115    /// Deterministically derives an authenticated encryption key from a Solana signer and a public
116    /// seed.
117    ///
118    /// This function exists for applications where a user may not wish to maintain a Solana signer
119    /// and an authenticated encryption key separately. Instead, a user can derive the ElGamal
120    /// keypair on-the-fly whenever encrytion/decryption is needed.
121    pub fn new_from_signer(
122        signer: &dyn Signer,
123        public_seed: &[u8],
124    ) -> Result<Self, Box<dyn error::Error>> {
125        let seed = Self::seed_from_signer(signer, public_seed)?;
126        Self::from_seed(&seed)
127    }
128
129    /// Derive a seed from a Solana signer used to generate an authenticated encryption key.
130    ///
131    /// The seed is derived as the hash of the signature of a public seed.
132    pub fn seed_from_signer(
133        signer: &dyn Signer,
134        public_seed: &[u8],
135    ) -> Result<Vec<u8>, SignerError> {
136        let message = [b"AeKey", public_seed].concat();
137        let signature = signer.try_sign_message(&message)?;
138
139        // Some `Signer` implementations return the default signature, which is not suitable for
140        // use as key material
141        if bool::from(signature.as_ref().ct_eq(Signature::default().as_ref())) {
142            return Err(SignerError::Custom("Rejecting default signature".into()));
143        }
144
145        Ok(Self::seed_from_signature(&signature))
146    }
147
148    /// Derive an authenticated encryption key from a signature.
149    pub fn new_from_signature(signature: &Signature) -> Result<Self, Box<dyn error::Error>> {
150        let seed = Self::seed_from_signature(signature);
151        Self::from_seed(&seed)
152    }
153
154    /// Derive a seed from a signature used to generate an authenticated encryption key.
155    pub fn seed_from_signature(signature: &Signature) -> Vec<u8> {
156        let mut hasher = Sha3_512::new();
157        hasher.update(signature);
158        let result = hasher.finalize();
159
160        result.to_vec()
161    }
162}
163
164#[cfg(not(target_arch = "wasm32"))]
165impl EncodableKey for AeKey {
166    fn read<R: Read>(reader: &mut R) -> Result<Self, Box<dyn error::Error>> {
167        let bytes: [u8; AE_KEY_LEN] = serde_json::from_reader(reader)?;
168        Ok(Self(bytes))
169    }
170
171    fn write<W: Write>(&self, writer: &mut W) -> Result<String, Box<dyn error::Error>> {
172        let bytes = self.0;
173        let json = serde_json::to_string(&bytes.to_vec())?;
174        writer.write_all(&json.clone().into_bytes())?;
175        Ok(json)
176    }
177}
178
179#[cfg(not(target_arch = "wasm32"))]
180impl SeedDerivable for AeKey {
181    fn from_seed(seed: &[u8]) -> Result<Self, Box<dyn error::Error>> {
182        const MINIMUM_SEED_LEN: usize = AE_KEY_LEN;
183        const MAXIMUM_SEED_LEN: usize = 65535;
184
185        if seed.len() < MINIMUM_SEED_LEN {
186            return Err(AuthenticatedEncryptionError::SeedLengthTooShort.into());
187        }
188        if seed.len() > MAXIMUM_SEED_LEN {
189            return Err(AuthenticatedEncryptionError::SeedLengthTooLong.into());
190        }
191
192        let mut hasher = Sha3_512::new();
193        hasher.update(seed);
194        let result = hasher.finalize();
195
196        Ok(Self(result[..AE_KEY_LEN].try_into()?))
197    }
198
199    fn from_seed_and_derivation_path(
200        _seed: &[u8],
201        _derivation_path: Option<DerivationPath>,
202    ) -> Result<Self, Box<dyn error::Error>> {
203        Err(AuthenticatedEncryptionError::DerivationMethodNotSupported.into())
204    }
205
206    fn from_seed_phrase_and_passphrase(
207        seed_phrase: &str,
208        passphrase: &str,
209    ) -> Result<Self, Box<dyn error::Error>> {
210        Self::from_seed(&generate_seed_from_seed_phrase_and_passphrase(
211            seed_phrase,
212            passphrase,
213        ))
214    }
215}
216
217impl From<[u8; AE_KEY_LEN]> for AeKey {
218    fn from(bytes: [u8; AE_KEY_LEN]) -> Self {
219        Self(bytes)
220    }
221}
222
223impl From<AeKey> for [u8; AE_KEY_LEN] {
224    fn from(key: AeKey) -> Self {
225        key.0
226    }
227}
228
229impl TryFrom<&[u8]> for AeKey {
230    type Error = AuthenticatedEncryptionError;
231    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
232        if bytes.len() != AE_KEY_LEN {
233            return Err(AuthenticatedEncryptionError::Deserialization);
234        }
235        bytes
236            .try_into()
237            .map(Self)
238            .map_err(|_| AuthenticatedEncryptionError::Deserialization)
239    }
240}
241
242/// For the purpose of encrypting balances for the spl token accounts, the nonce and ciphertext
243/// sizes should always be fixed.
244type Nonce = [u8; NONCE_LEN];
245type Ciphertext = [u8; CIPHERTEXT_LEN];
246
247/// Authenticated encryption nonce and ciphertext
248#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
249#[derive(Clone, Copy, Debug, Default)]
250pub struct AeCiphertext {
251    nonce: Nonce,
252    ciphertext: Ciphertext,
253}
254impl AeCiphertext {
255    pub fn decrypt(&self, key: &AeKey) -> Option<u64> {
256        AuthenticatedEncryption::decrypt(key, self)
257    }
258
259    pub fn to_bytes(&self) -> [u8; AE_CIPHERTEXT_LEN] {
260        let mut buf = [0_u8; AE_CIPHERTEXT_LEN];
261        buf[..NONCE_LEN].copy_from_slice(&self.nonce);
262        buf[NONCE_LEN..].copy_from_slice(&self.ciphertext);
263        buf
264    }
265
266    pub fn from_bytes(bytes: &[u8]) -> Option<AeCiphertext> {
267        if bytes.len() != AE_CIPHERTEXT_LEN {
268            return None;
269        }
270
271        let nonce = bytes[..NONCE_LEN].try_into().ok()?;
272        let ciphertext = bytes[NONCE_LEN..].try_into().ok()?;
273
274        Some(AeCiphertext { nonce, ciphertext })
275    }
276}
277
278impl fmt::Display for AeCiphertext {
279    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
280        write!(f, "{}", BASE64_STANDARD.encode(self.to_bytes()))
281    }
282}
283
#[cfg(test)]
mod tests {
    use {
        super::*, solana_keypair::Keypair, solana_pubkey::Pubkey,
        solana_signer::null_signer::NullSigner,
    };

    /// An encrypted amount decrypts back to itself under the same key.
    #[test]
    fn test_aes_encrypt_decrypt_correctness() {
        let amount = 55;
        let key = AeKey::new_rand();

        let decrypted_amount = key.encrypt(amount).decrypt(&key).unwrap();

        assert_eq!(amount, decrypted_amount);
    }

    /// Distinct signers derive distinct keys, and a signer that returns the default signature
    /// is rejected.
    #[test]
    fn test_aes_new() {
        let keypair1 = Keypair::new();
        let keypair2 = Keypair::new();
        let public_seed = Pubkey::default();

        let key1 = AeKey::new_from_signer(&keypair1, public_seed.as_ref()).unwrap();
        let key2 = AeKey::new_from_signer(&keypair2, public_seed.as_ref()).unwrap();
        assert_ne!(key1.0, key2.0);

        let null_signer = NullSigner::new(&Pubkey::default());
        let null_result = AeKey::new_from_signer(&null_signer, public_seed.as_ref());
        assert!(null_result.is_err());
    }

    /// Seed length limits are enforced by `from_seed`.
    #[test]
    fn test_aes_key_from_seed() {
        let good_seed = [0_u8; 32];
        assert!(AeKey::from_seed(&good_seed).is_ok());

        let too_short_seed = [0_u8; 15];
        assert!(AeKey::from_seed(&too_short_seed).is_err());

        let too_long_seed = vec![0_u8; 65536];
        assert!(AeKey::from_seed(&too_long_seed).is_err());
    }

    /// Converting a key to raw bytes and back with `From` yields the same key.
    #[test]
    fn test_aes_key_from() {
        let seed = [0_u8; 32];
        let key = AeKey::from_seed(&seed).unwrap();
        let key_bytes: [u8; AE_KEY_LEN] = AeKey::from_seed(&seed).unwrap().into();

        assert_eq!(key, AeKey::from(key_bytes));
    }

    /// Converting a key to raw bytes and back with `TryFrom<&[u8]>` yields the same key.
    #[test]
    fn test_aes_key_try_from() {
        let seed = [0_u8; 32];
        let key = AeKey::from_seed(&seed).unwrap();
        let key_bytes: [u8; AE_KEY_LEN] = AeKey::from_seed(&seed).unwrap().into();

        let round_tripped = AeKey::try_from(key_bytes.as_slice()).unwrap();
        assert_eq!(key, round_tripped);
    }

    /// A slice of the wrong length is rejected by `TryFrom<&[u8]>`.
    #[test]
    fn test_aes_key_try_from_error() {
        let too_many_bytes = [0_u8; 32];
        assert!(AeKey::try_from(too_many_bytes.as_slice()).is_err());
    }
}
351        assert!(AeKey::try_from(too_many_bytes.as_slice()).is_err());
352    }
353}