1use super::aead_ctx::{self, AeadCtx};
5use super::{Aad, Algorithm, AlgorithmID, Nonce, Tag, UnboundKey};
6use crate::error::Unspecified;
7use core::fmt::Debug;
8use core::ops::RangeFrom;
9
/// The Transport Layer Security (TLS) protocol version a record key is created for.
///
/// Both variants are accepted by `TlsRecordSealingKey::new` and `TlsRecordOpeningKey::new`
/// for the `AES_128_GCM` and `AES_256_GCM` algorithms.
// The `Tls` prefix repeats the module name on purpose; silence the pedantic lint.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TlsProtocolId {
    /// TLS 1.2.
    TLS12,
    /// TLS 1.3.
    TLS13,
}
21
22#[allow(clippy::module_name_repetitions)]
33pub struct TlsRecordSealingKey {
34 key: UnboundKey,
39 protocol: TlsProtocolId,
40}
41
42impl TlsRecordSealingKey {
43 pub fn new(
49 algorithm: &'static Algorithm,
50 protocol: TlsProtocolId,
51 key_bytes: &[u8],
52 ) -> Result<Self, Unspecified> {
53 let ctx = match (algorithm.id, protocol) {
54 (AlgorithmID::AES_128_GCM, TlsProtocolId::TLS12) => AeadCtx::aes_128_gcm_tls12(
55 key_bytes,
56 algorithm.tag_len(),
57 aead_ctx::AeadDirection::Seal,
58 ),
59 (AlgorithmID::AES_128_GCM, TlsProtocolId::TLS13) => AeadCtx::aes_128_gcm_tls13(
60 key_bytes,
61 algorithm.tag_len(),
62 aead_ctx::AeadDirection::Seal,
63 ),
64 (AlgorithmID::AES_256_GCM, TlsProtocolId::TLS12) => AeadCtx::aes_256_gcm_tls12(
65 key_bytes,
66 algorithm.tag_len(),
67 aead_ctx::AeadDirection::Seal,
68 ),
69 (AlgorithmID::AES_256_GCM, TlsProtocolId::TLS13) => AeadCtx::aes_256_gcm_tls13(
70 key_bytes,
71 algorithm.tag_len(),
72 aead_ctx::AeadDirection::Seal,
73 ),
74 (
75 AlgorithmID::AES_128_GCM_SIV
76 | AlgorithmID::AES_192_GCM
77 | AlgorithmID::AES_256_GCM_SIV
78 | AlgorithmID::CHACHA20_POLY1305,
79 _,
80 ) => Err(Unspecified),
81 }?;
82 Ok(Self {
83 key: UnboundKey::from(ctx),
84 protocol,
85 })
86 }
87
88 #[inline]
97 #[allow(clippy::needless_pass_by_value)]
98 pub fn seal_in_place_append_tag<A, InOut>(
99 &mut self,
100 nonce: Nonce,
101 aad: Aad<A>,
102 in_out: &mut InOut,
103 ) -> Result<(), Unspecified>
104 where
105 A: AsRef<[u8]>,
106 InOut: AsMut<[u8]> + for<'in_out> Extend<&'in_out u8>,
107 {
108 self.key
109 .seal_in_place_append_tag(Some(nonce), aad.as_ref(), in_out)
110 .map(|_| ())
111 }
112
113 #[inline]
130 #[allow(clippy::needless_pass_by_value)]
131 pub fn seal_in_place_separate_tag<A>(
132 &mut self,
133 nonce: Nonce,
134 aad: Aad<A>,
135 in_out: &mut [u8],
136 ) -> Result<Tag, Unspecified>
137 where
138 A: AsRef<[u8]>,
139 {
140 self.key
141 .seal_in_place_separate_tag(Some(nonce), aad.as_ref(), in_out)
142 .map(|(_, tag)| tag)
143 }
144
145 #[inline]
147 #[must_use]
148 pub fn algorithm(&self) -> &'static Algorithm {
149 self.key.algorithm()
150 }
151
152 #[must_use]
154 pub fn tls_protocol_id(&self) -> TlsProtocolId {
155 self.protocol
156 }
157}
158
159#[allow(clippy::missing_fields_in_debug)]
160impl Debug for TlsRecordSealingKey {
161 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
162 f.debug_struct("TlsRecordSealingKey")
163 .field("key", &self.key)
164 .field("protocol", &self.protocol)
165 .finish()
166 }
167}
168
169#[allow(clippy::module_name_repetitions)]
179pub struct TlsRecordOpeningKey {
180 key: UnboundKey,
185 protocol: TlsProtocolId,
186}
187
188impl TlsRecordOpeningKey {
189 pub fn new(
195 algorithm: &'static Algorithm,
196 protocol: TlsProtocolId,
197 key_bytes: &[u8],
198 ) -> Result<Self, Unspecified> {
199 let ctx = match (algorithm.id, protocol) {
200 (AlgorithmID::AES_128_GCM, TlsProtocolId::TLS12) => AeadCtx::aes_128_gcm_tls12(
201 key_bytes,
202 algorithm.tag_len(),
203 aead_ctx::AeadDirection::Open,
204 ),
205 (AlgorithmID::AES_128_GCM, TlsProtocolId::TLS13) => AeadCtx::aes_128_gcm_tls13(
206 key_bytes,
207 algorithm.tag_len(),
208 aead_ctx::AeadDirection::Open,
209 ),
210 (AlgorithmID::AES_256_GCM, TlsProtocolId::TLS12) => AeadCtx::aes_256_gcm_tls12(
211 key_bytes,
212 algorithm.tag_len(),
213 aead_ctx::AeadDirection::Open,
214 ),
215 (AlgorithmID::AES_256_GCM, TlsProtocolId::TLS13) => AeadCtx::aes_256_gcm_tls13(
216 key_bytes,
217 algorithm.tag_len(),
218 aead_ctx::AeadDirection::Open,
219 ),
220 (
221 AlgorithmID::AES_128_GCM_SIV
222 | AlgorithmID::AES_192_GCM
223 | AlgorithmID::AES_256_GCM_SIV
224 | AlgorithmID::CHACHA20_POLY1305,
225 _,
226 ) => Err(Unspecified),
227 }?;
228 Ok(Self {
229 key: UnboundKey::from(ctx),
230 protocol,
231 })
232 }
233
234 #[inline]
239 #[allow(clippy::needless_pass_by_value)]
240 pub fn open_in_place<'in_out, A>(
241 &self,
242 nonce: Nonce,
243 aad: Aad<A>,
244 in_out: &'in_out mut [u8],
245 ) -> Result<&'in_out mut [u8], Unspecified>
246 where
247 A: AsRef<[u8]>,
248 {
249 self.key.open_within(nonce, aad.as_ref(), in_out, 0..)
250 }
251
252 #[inline]
257 #[allow(clippy::needless_pass_by_value)]
258 pub fn open_within<'in_out, A>(
259 &self,
260 nonce: Nonce,
261 aad: Aad<A>,
262 in_out: &'in_out mut [u8],
263 ciphertext_and_tag: RangeFrom<usize>,
264 ) -> Result<&'in_out mut [u8], Unspecified>
265 where
266 A: AsRef<[u8]>,
267 {
268 self.key
269 .open_within(nonce, aad.as_ref(), in_out, ciphertext_and_tag)
270 }
271
272 #[inline]
274 #[must_use]
275 pub fn algorithm(&self) -> &'static Algorithm {
276 self.key.algorithm()
277 }
278
279 #[must_use]
281 pub fn tls_protocol_id(&self) -> TlsProtocolId {
282 self.protocol
283 }
284}
285
286#[allow(clippy::missing_fields_in_debug)]
287impl Debug for TlsRecordOpeningKey {
288 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
289 f.debug_struct("TlsRecordOpeningKey")
290 .field("key", &self.key)
291 .field("protocol", &self.protocol)
292 .finish()
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::{TlsProtocolId, TlsRecordOpeningKey, TlsRecordSealingKey};
    use crate::aead::{Aad, Nonce, AES_128_GCM, AES_256_GCM, CHACHA20_POLY1305};
    use crate::test::from_hex;
    use paste::paste;

    const TEST_128_BIT_KEY: &[u8] = &[
        0xb0, 0x37, 0x9f, 0xf8, 0xfb, 0x8e, 0xa6, 0x31, 0xf4, 0x1c, 0xe6, 0x3e, 0xb5, 0xc5, 0x20,
        0x7c,
    ];

    const TEST_256_BIT_KEY: &[u8] = &[
        0x56, 0xd8, 0x96, 0x68, 0xbd, 0x96, 0xeb, 0xff, 0x5e, 0xa2, 0x0b, 0x34, 0xf2, 0x79, 0x84,
        0x6e, 0x2b, 0x13, 0x01, 0x3d, 0xab, 0x1d, 0xa4, 0x07, 0x5a, 0x16, 0xd5, 0x0b, 0x53, 0xb0,
        0xcc, 0x88,
    ];

    /// Hex-encoded plaintext that is sealed and re-opened for every nonce case.
    const TEST_PLAINTEXT_HEX: &str = "00112233445566778899aabbccddeeff";

    /// Junk bytes placed in front of a copy of the sealed record to exercise `open_within`.
    const OPEN_WITHIN_PREFIX: &[u8] = &[1, 2, 3, 4];

    /// A nonce to seal with, and whether sealing with it is expected to fail.
    ///
    /// Cases are applied in order against one sealing key, so the outcome of a case
    /// depends on the nonces used by the cases before it.
    struct TlsNonceTestCase {
        nonce: &'static str,
        expect_err: bool,
    }

    // The first two nonces increase; the third is lower than the second and is
    // expected to be rejected (presumably by a nonce-ordering check in the TLS AEAD
    // construction — see the `expect_err` flag).
    const TLS_NONCE_TEST_CASES: &[TlsNonceTestCase] = &[
        TlsNonceTestCase {
            nonce: "9fab40177c900aad9fc28cc3",
            expect_err: false,
        },
        TlsNonceTestCase {
            nonce: "9fab40177c900aad9fc28cc4",
            expect_err: false,
        },
        TlsNonceTestCase {
            nonce: "9fab40177c900aad9fc28cc2",
            expect_err: true,
        },
    ];

    /// Generates a test per (algorithm, protocol) combination.
    ///
    /// The 4-argument form asserts that both key types reject the combination; the
    /// 6-argument form runs a seal/open round trip for every `TLS_NONCE_TEST_CASES` entry.
    macro_rules! test_tls_aead {
        ($name:ident, $alg:expr, $proto:expr, $key:expr) => {
            paste! {
                #[test]
                fn [<test_ $name _tls_aead_unsupported>]() {
                    assert!(TlsRecordSealingKey::new($alg, $proto, $key).is_err());
                    assert!(TlsRecordOpeningKey::new($alg, $proto, $key).is_err());
                }
            }
        };
        ($name:ident, $alg:expr, $proto:expr, $key:expr, $expect_tag_len:expr, $expect_nonce_len:expr) => {
            paste! {
                #[test]
                fn [<test_ $name>]() {
                    let mut sealing_key =
                        TlsRecordSealingKey::new($alg, $proto, $key).unwrap();
                    let opening_key =
                        TlsRecordOpeningKey::new($alg, $proto, $key).unwrap();

                    // Key and algorithm metadata do not depend on the nonce; check once.
                    assert_eq!($alg, sealing_key.algorithm());
                    assert_eq!($alg, opening_key.algorithm());
                    assert_eq!($proto, sealing_key.tls_protocol_id());
                    assert_eq!($proto, opening_key.tls_protocol_id());
                    assert_eq!($expect_tag_len, $alg.tag_len());
                    assert_eq!($expect_nonce_len, $alg.nonce_len());

                    let plaintext = from_hex(TEST_PLAINTEXT_HEX).unwrap();

                    for case in TLS_NONCE_TEST_CASES {
                        let nonce = from_hex(case.nonce).unwrap();
                        let nonce_bytes = nonce.as_slice();

                        let mut in_out = plaintext.clone();
                        let result = sealing_key.seal_in_place_append_tag(
                            Nonce::try_assume_unique_for_key(nonce_bytes).unwrap(),
                            Aad::empty(),
                            &mut in_out,
                        );

                        match (result, case.expect_err) {
                            (Ok(()), true) => panic!("expected error for seal_in_place_append_tag"),
                            (Ok(()), false) => {}
                            // Move on to the next case: `return` here would silently skip
                            // every case listed after an expected failure.
                            (Err(_), true) => continue,
                            (Err(e), false) => panic!("{e}"),
                        }

                        assert_eq!(plaintext.len() + $alg.tag_len(), in_out.len());
                        assert_ne!(plaintext, in_out[..plaintext.len()]);

                        // Copy the sealed record behind a junk prefix before `in_out`
                        // is decrypted in place.
                        let mut offset_cipher_text = OPEN_WITHIN_PREFIX.to_vec();
                        offset_cipher_text.extend_from_slice(&in_out);

                        let opened = opening_key
                            .open_in_place(
                                Nonce::try_assume_unique_for_key(nonce_bytes).unwrap(),
                                Aad::empty(),
                                &mut in_out,
                            )
                            .unwrap();
                        assert_eq!(plaintext, opened);

                        let opened = opening_key
                            .open_within(
                                Nonce::try_assume_unique_for_key(nonce_bytes).unwrap(),
                                Aad::empty(),
                                &mut offset_cipher_text,
                                OPEN_WITHIN_PREFIX.len()..,
                            )
                            .unwrap();
                        assert_eq!(plaintext, opened);
                    }
                }
            }
        };
    }

    test_tls_aead!(
        aes_128_gcm_tls12,
        &AES_128_GCM,
        TlsProtocolId::TLS12,
        TEST_128_BIT_KEY,
        16,
        12
    );
    test_tls_aead!(
        aes_128_gcm_tls13,
        &AES_128_GCM,
        TlsProtocolId::TLS13,
        TEST_128_BIT_KEY,
        16,
        12
    );
    test_tls_aead!(
        aes_256_gcm_tls12,
        &AES_256_GCM,
        TlsProtocolId::TLS12,
        TEST_256_BIT_KEY,
        16,
        12
    );
    test_tls_aead!(
        aes_256_gcm_tls13,
        &AES_256_GCM,
        TlsProtocolId::TLS13,
        TEST_256_BIT_KEY,
        16,
        12
    );
    test_tls_aead!(
        chacha20_poly1305_tls12,
        &CHACHA20_POLY1305,
        TlsProtocolId::TLS12,
        TEST_256_BIT_KEY
    );
    test_tls_aead!(
        chacha20_poly1305_tls13,
        &CHACHA20_POLY1305,
        TlsProtocolId::TLS13,
        TEST_256_BIT_KEY
    );
}