aws_lc_rs/aead/
unbound_key.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0 OR ISC
3
4use super::aead_ctx::AeadCtx;
5use super::{
6    Algorithm, Nonce, Tag, AES_128_GCM, AES_128_GCM_SIV, AES_192_GCM, AES_256_GCM, AES_256_GCM_SIV,
7    CHACHA20_POLY1305, MAX_KEY_LEN, MAX_TAG_LEN, NONCE_LEN,
8};
9use crate::aws_lc::{
10    EVP_AEAD_CTX_open, EVP_AEAD_CTX_open_gather, EVP_AEAD_CTX_seal, EVP_AEAD_CTX_seal_scatter,
11};
12use crate::error::Unspecified;
13use crate::fips::indicator_check;
14use crate::hkdf;
15use crate::iv::FixedLength;
16use core::fmt::Debug;
17use core::mem::MaybeUninit;
18use core::ops::RangeFrom;
19use core::ptr::null;
20
21/// The maximum length of a nonce returned by our AEAD API.
22const MAX_NONCE_LEN: usize = NONCE_LEN;
23
24/// The maximum required tag buffer needed if using AWS-LC generated nonce construction
25const MAX_TAG_NONCE_BUFFER_LEN: usize = MAX_TAG_LEN + MAX_NONCE_LEN;
26
27/// An AEAD key without a designated role or nonce sequence.
28pub struct UnboundKey {
29    ctx: AeadCtx,
30    algorithm: &'static Algorithm,
31}
32
33#[allow(clippy::missing_fields_in_debug)]
34impl Debug for UnboundKey {
35    fn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> {
36        f.debug_struct("UnboundKey")
37            .field("algorithm", &self.algorithm)
38            .finish()
39    }
40}
41
42impl UnboundKey {
43    /// Constructs an `UnboundKey`.
44    /// # Errors
45    /// `error::Unspecified` if `key_bytes.len() != algorithm.key_len()`.
46    pub fn new(algorithm: &'static Algorithm, key_bytes: &[u8]) -> Result<Self, Unspecified> {
47        Ok(Self {
48            ctx: (algorithm.init)(key_bytes, algorithm.tag_len())?,
49            algorithm,
50        })
51    }
52
53    #[inline]
54    pub(crate) fn open_within<'in_out>(
55        &self,
56        nonce: Nonce,
57        aad: &[u8],
58        in_out: &'in_out mut [u8],
59        ciphertext_and_tag: RangeFrom<usize>,
60    ) -> Result<&'in_out mut [u8], Unspecified> {
61        let in_prefix_len = ciphertext_and_tag.start;
62        let ciphertext_and_tag_len = in_out.len().checked_sub(in_prefix_len).ok_or(Unspecified)?;
63        let ciphertext_len = ciphertext_and_tag_len
64            .checked_sub(self.algorithm().tag_len())
65            .ok_or(Unspecified)?;
66        self.check_per_nonce_max_bytes(ciphertext_len)?;
67
68        match self.ctx {
69            AeadCtx::AES_128_GCM_RANDNONCE(_) | AeadCtx::AES_256_GCM_RANDNONCE(_) => {
70                self.open_combined_randnonce(nonce, aad, &mut in_out[in_prefix_len..])
71            }
72            _ => self.open_combined(nonce, aad.as_ref(), &mut in_out[in_prefix_len..]),
73        }?;
74
75        // shift the plaintext to the left
76        in_out.copy_within(in_prefix_len..in_prefix_len + ciphertext_len, 0);
77
78        // `ciphertext_len` is also the plaintext length.
79        Ok(&mut in_out[..ciphertext_len])
80    }
81
82    #[inline]
83    pub(crate) fn open_separate_gather(
84        &self,
85        nonce: &Nonce,
86        aad: &[u8],
87        in_ciphertext: &[u8],
88        in_tag: &[u8],
89        out_plaintext: &mut [u8],
90    ) -> Result<(), Unspecified> {
91        self.check_per_nonce_max_bytes(in_ciphertext.len())?;
92
93        // ensure that the lengths match
94        {
95            let actual = in_ciphertext.len();
96            let expected = out_plaintext.len();
97
98            if actual != expected {
99                return Err(Unspecified);
100            }
101        }
102
103        unsafe {
104            let aead_ctx = self.ctx.as_ref();
105            let nonce = nonce.as_ref();
106
107            if 1 != EVP_AEAD_CTX_open_gather(
108                *aead_ctx.as_const(),
109                out_plaintext.as_mut_ptr(),
110                nonce.as_ptr(),
111                nonce.len(),
112                in_ciphertext.as_ptr(),
113                in_ciphertext.len(),
114                in_tag.as_ptr(),
115                in_tag.len(),
116                aad.as_ptr(),
117                aad.len(),
118            ) {
119                return Err(Unspecified);
120            }
121            Ok(())
122        }
123    }
124
125    #[inline]
126    pub(crate) fn seal_in_place_append_tag<'a, InOut>(
127        &self,
128        nonce: Option<Nonce>,
129        aad: &[u8],
130        in_out: &'a mut InOut,
131    ) -> Result<Nonce, Unspecified>
132    where
133        InOut: AsMut<[u8]> + for<'in_out> Extend<&'in_out u8>,
134    {
135        self.check_per_nonce_max_bytes(in_out.as_mut().len())?;
136        match nonce {
137            Some(nonce) => self.seal_combined(nonce, aad, in_out),
138            None => self.seal_combined_randnonce(aad, in_out),
139        }
140    }
141
142    #[inline]
143    pub(crate) fn seal_in_place_separate_tag(
144        &self,
145        nonce: Option<Nonce>,
146        aad: &[u8],
147        in_out: &mut [u8],
148    ) -> Result<(Nonce, Tag), Unspecified> {
149        self.check_per_nonce_max_bytes(in_out.len())?;
150        match nonce {
151            Some(nonce) => self.seal_separate(nonce, aad, in_out),
152            None => self.seal_separate_randnonce(aad, in_out),
153        }
154    }
155
156    #[inline]
157    #[allow(clippy::needless_pass_by_value)]
158    pub(crate) fn seal_in_place_separate_scatter(
159        &self,
160        nonce: Nonce,
161        aad: &[u8],
162        in_out: &mut [u8],
163        extra_in: &[u8],
164        extra_out_and_tag: &mut [u8],
165    ) -> Result<(), Unspecified> {
166        self.check_per_nonce_max_bytes(in_out.len())?;
167        // ensure that the extra lengths match
168        {
169            let actual = extra_in.len() + self.algorithm().tag_len();
170            let expected = extra_out_and_tag.len();
171
172            if actual != expected {
173                return Err(Unspecified);
174            }
175        }
176
177        let nonce = nonce.as_ref();
178        let mut out_tag_len = extra_out_and_tag.len();
179
180        if 1 != unsafe {
181            EVP_AEAD_CTX_seal_scatter(
182                *self.ctx.as_ref().as_const(),
183                in_out.as_mut_ptr(),
184                extra_out_and_tag.as_mut_ptr(),
185                &mut out_tag_len,
186                extra_out_and_tag.len(),
187                nonce.as_ptr(),
188                nonce.len(),
189                in_out.as_ptr(),
190                in_out.len(),
191                extra_in.as_ptr(),
192                extra_in.len(),
193                aad.as_ptr(),
194                aad.len(),
195            )
196        } {
197            return Err(Unspecified);
198        }
199        Ok(())
200    }
201
202    /// The key's AEAD algorithm.
203    #[inline]
204    #[must_use]
205    pub fn algorithm(&self) -> &'static Algorithm {
206        self.algorithm
207    }
208
209    #[inline]
210    pub(crate) fn check_per_nonce_max_bytes(&self, in_out_len: usize) -> Result<(), Unspecified> {
211        if in_out_len as u64 > self.algorithm().max_input_len {
212            return Err(Unspecified);
213        }
214        Ok(())
215    }
216
217    #[inline]
218    #[allow(clippy::needless_pass_by_value)]
219    fn open_combined(
220        &self,
221        nonce: Nonce,
222        aad: &[u8],
223        in_out: &mut [u8],
224    ) -> Result<(), Unspecified> {
225        let nonce = nonce.as_ref();
226
227        debug_assert_eq!(nonce.len(), self.algorithm().nonce_len());
228
229        let plaintext_len = in_out.len() - self.algorithm().tag_len();
230
231        let mut out_len = MaybeUninit::<usize>::uninit();
232        if 1 != indicator_check!(unsafe {
233            EVP_AEAD_CTX_open(
234                *self.ctx.as_ref().as_const(),
235                in_out.as_mut_ptr(),
236                out_len.as_mut_ptr(),
237                plaintext_len,
238                nonce.as_ptr(),
239                nonce.len(),
240                in_out.as_ptr(),
241                plaintext_len + self.algorithm().tag_len(),
242                aad.as_ptr(),
243                aad.len(),
244            )
245        }) {
246            return Err(Unspecified);
247        }
248
249        Ok(())
250    }
251
252    #[inline]
253    #[allow(clippy::needless_pass_by_value)]
254    fn open_combined_randnonce(
255        &self,
256        nonce: Nonce,
257        aad: &[u8],
258        in_out: &mut [u8],
259    ) -> Result<(), Unspecified> {
260        let nonce = nonce.as_ref();
261
262        let alg_nonce_len = self.algorithm().nonce_len();
263        let alg_tag_len = self.algorithm().tag_len();
264
265        debug_assert_eq!(nonce.len(), alg_nonce_len);
266        debug_assert!(alg_tag_len + alg_nonce_len <= MAX_TAG_NONCE_BUFFER_LEN);
267
268        let plaintext_len = in_out.len() - alg_tag_len;
269
270        let mut tag_buffer = [0u8; MAX_TAG_NONCE_BUFFER_LEN];
271
272        tag_buffer[..alg_tag_len]
273            .copy_from_slice(&in_out[plaintext_len..plaintext_len + alg_tag_len]);
274        tag_buffer[alg_tag_len..alg_tag_len + alg_nonce_len].copy_from_slice(nonce);
275
276        let tag_slice = &tag_buffer[0..alg_tag_len + alg_nonce_len];
277
278        if 1 != indicator_check!(unsafe {
279            EVP_AEAD_CTX_open_gather(
280                *self.ctx.as_ref().as_const(),
281                in_out.as_mut_ptr(),
282                null(),
283                0,
284                in_out.as_ptr(),
285                plaintext_len,
286                tag_slice.as_ptr(),
287                tag_slice.len(),
288                aad.as_ptr(),
289                aad.len(),
290            )
291        }) {
292            return Err(Unspecified);
293        }
294
295        Ok(())
296    }
297
298    #[inline]
299    fn seal_combined<InOut>(
300        &self,
301        nonce: Nonce,
302        aad: &[u8],
303        in_out: &mut InOut,
304    ) -> Result<Nonce, Unspecified>
305    where
306        InOut: AsMut<[u8]> + for<'in_out> Extend<&'in_out u8>,
307    {
308        let plaintext_len = in_out.as_mut().len();
309
310        let alg_tag_len = self.algorithm().tag_len();
311
312        debug_assert!(alg_tag_len <= MAX_TAG_LEN);
313
314        let tag_buffer = [0u8; MAX_TAG_LEN];
315
316        in_out.extend(tag_buffer[..alg_tag_len].iter());
317
318        let mut out_len = MaybeUninit::<usize>::uninit();
319        let mut_in_out = in_out.as_mut();
320
321        {
322            let nonce = nonce.as_ref();
323
324            debug_assert_eq!(nonce.len(), self.algorithm().nonce_len());
325
326            if 1 != indicator_check!(unsafe {
327                EVP_AEAD_CTX_seal(
328                    *self.ctx.as_ref().as_const(),
329                    mut_in_out.as_mut_ptr(),
330                    out_len.as_mut_ptr(),
331                    plaintext_len + alg_tag_len,
332                    nonce.as_ptr(),
333                    nonce.len(),
334                    mut_in_out.as_ptr(),
335                    plaintext_len,
336                    aad.as_ptr(),
337                    aad.len(),
338                )
339            }) {
340                return Err(Unspecified);
341            }
342        }
343
344        Ok(nonce)
345    }
346
347    #[inline]
348    fn seal_combined_randnonce<InOut>(
349        &self,
350        aad: &[u8],
351        in_out: &mut InOut,
352    ) -> Result<Nonce, Unspecified>
353    where
354        InOut: AsMut<[u8]> + for<'in_out> Extend<&'in_out u8>,
355    {
356        let mut tag_buffer = [0u8; MAX_TAG_NONCE_BUFFER_LEN];
357
358        let mut out_tag_len = MaybeUninit::<usize>::uninit();
359
360        {
361            let plaintext_len = in_out.as_mut().len();
362            let in_out = in_out.as_mut();
363
364            if 1 != indicator_check!(unsafe {
365                EVP_AEAD_CTX_seal_scatter(
366                    *self.ctx.as_ref().as_const(),
367                    in_out.as_mut_ptr(),
368                    tag_buffer.as_mut_ptr(),
369                    out_tag_len.as_mut_ptr(),
370                    tag_buffer.len(),
371                    null(),
372                    0,
373                    in_out.as_ptr(),
374                    plaintext_len,
375                    null(),
376                    0,
377                    aad.as_ptr(),
378                    aad.len(),
379                )
380            }) {
381                return Err(Unspecified);
382            }
383        }
384
385        let tag_len = self.algorithm().tag_len();
386        let nonce_len = self.algorithm().nonce_len();
387
388        let nonce = Nonce(FixedLength::<NONCE_LEN>::try_from(
389            &tag_buffer[tag_len..tag_len + nonce_len],
390        )?);
391
392        in_out.extend(&tag_buffer[..tag_len]);
393
394        Ok(nonce)
395    }
396
397    #[inline]
398    fn seal_separate(
399        &self,
400        nonce: Nonce,
401        aad: &[u8],
402        in_out: &mut [u8],
403    ) -> Result<(Nonce, Tag), Unspecified> {
404        let mut tag = [0u8; MAX_TAG_LEN];
405        let mut out_tag_len = MaybeUninit::<usize>::uninit();
406        {
407            let nonce = nonce.as_ref();
408
409            debug_assert_eq!(nonce.len(), self.algorithm().nonce_len());
410
411            if 1 != indicator_check!(unsafe {
412                EVP_AEAD_CTX_seal_scatter(
413                    *self.ctx.as_ref().as_const(),
414                    in_out.as_mut_ptr(),
415                    tag.as_mut_ptr(),
416                    out_tag_len.as_mut_ptr(),
417                    tag.len(),
418                    nonce.as_ptr(),
419                    nonce.len(),
420                    in_out.as_ptr(),
421                    in_out.len(),
422                    null(),
423                    0usize,
424                    aad.as_ptr(),
425                    aad.len(),
426                )
427            }) {
428                return Err(Unspecified);
429            }
430        }
431        Ok((nonce, Tag(tag, unsafe { out_tag_len.assume_init() })))
432    }
433
434    #[inline]
435    fn seal_separate_randnonce(
436        &self,
437        aad: &[u8],
438        in_out: &mut [u8],
439    ) -> Result<(Nonce, Tag), Unspecified> {
440        let mut tag_buffer = [0u8; MAX_TAG_NONCE_BUFFER_LEN];
441
442        debug_assert!(
443            self.algorithm().tag_len() + self.algorithm().nonce_len() <= tag_buffer.len()
444        );
445
446        let mut out_tag_len = MaybeUninit::<usize>::uninit();
447
448        if 1 != indicator_check!(unsafe {
449            EVP_AEAD_CTX_seal_scatter(
450                *self.ctx.as_ref().as_const(),
451                in_out.as_mut_ptr(),
452                tag_buffer.as_mut_ptr(),
453                out_tag_len.as_mut_ptr(),
454                tag_buffer.len(),
455                null(),
456                0,
457                in_out.as_ptr(),
458                in_out.len(),
459                null(),
460                0usize,
461                aad.as_ptr(),
462                aad.len(),
463            )
464        }) {
465            return Err(Unspecified);
466        }
467
468        let tag_len = self.algorithm().tag_len();
469        let nonce_len = self.algorithm().nonce_len();
470
471        let nonce = Nonce(FixedLength::<NONCE_LEN>::try_from(
472            &tag_buffer[tag_len..tag_len + nonce_len],
473        )?);
474
475        let mut tag = [0u8; MAX_TAG_LEN];
476        tag.copy_from_slice(&tag_buffer[..tag_len]);
477
478        Ok((nonce, Tag(tag, tag_len)))
479    }
480}
481
482impl From<AeadCtx> for UnboundKey {
483    fn from(value: AeadCtx) -> Self {
484        let algorithm = match value {
485            AeadCtx::AES_128_GCM(_)
486            | AeadCtx::AES_128_GCM_TLS12(_)
487            | AeadCtx::AES_128_GCM_TLS13(_)
488            | AeadCtx::AES_128_GCM_RANDNONCE(_) => &AES_128_GCM,
489            AeadCtx::AES_192_GCM(_) => &AES_192_GCM,
490            AeadCtx::AES_128_GCM_SIV(_) => &AES_128_GCM_SIV,
491            AeadCtx::AES_256_GCM(_)
492            | AeadCtx::AES_256_GCM_RANDNONCE(_)
493            | AeadCtx::AES_256_GCM_TLS12(_)
494            | AeadCtx::AES_256_GCM_TLS13(_) => &AES_256_GCM,
495            AeadCtx::AES_256_GCM_SIV(_) => &AES_256_GCM_SIV,
496            AeadCtx::CHACHA20_POLY1305(_) => &CHACHA20_POLY1305,
497        };
498        Self {
499            ctx: value,
500            algorithm,
501        }
502    }
503}
504
505impl From<hkdf::Okm<'_, &'static Algorithm>> for UnboundKey {
506    fn from(okm: hkdf::Okm<&'static Algorithm>) -> Self {
507        let mut key_bytes = [0; MAX_KEY_LEN];
508        let key_bytes = &mut key_bytes[..okm.len().key_len];
509        let algorithm = *okm.len();
510        okm.fill(key_bytes).unwrap();
511        Self::new(algorithm, key_bytes).unwrap()
512    }
513}