1use crate::aws_lc::{
49 EVP_PKEY_CTX_kem_set_params, EVP_PKEY_decapsulate, EVP_PKEY_encapsulate,
50 EVP_PKEY_kem_new_raw_public_key, EVP_PKEY, EVP_PKEY_KEM,
51};
52use crate::buffer::Buffer;
53use crate::encoding::generated_encodings;
54use crate::error::{KeyRejected, Unspecified};
55use crate::ptr::LcPtr;
56use alloc::borrow::Cow;
57use core::cmp::Ordering;
58use zeroize::Zeroize;
59
// Byte lengths of the buffers used by each ML-KEM parameter set (they size the
// `u8` buffers handed to AWS-LC). The values match the encapsulation-key,
// decapsulation-key, ciphertext and shared-secret sizes listed in FIPS 203.

/// Every ML-KEM parameter set derives a shared secret of this many bytes.
const ML_KEM_SHARED_SECRET_LENGTH: usize = 32;

// ML-KEM-512
const ML_KEM_512_SHARED_SECRET_LENGTH: usize = ML_KEM_SHARED_SECRET_LENGTH;
const ML_KEM_512_PUBLIC_KEY_LENGTH: usize = 800;
const ML_KEM_512_SECRET_KEY_LENGTH: usize = 1632;
const ML_KEM_512_CIPHERTEXT_LENGTH: usize = 768;

// ML-KEM-768
const ML_KEM_768_SHARED_SECRET_LENGTH: usize = ML_KEM_SHARED_SECRET_LENGTH;
const ML_KEM_768_PUBLIC_KEY_LENGTH: usize = 1184;
const ML_KEM_768_SECRET_KEY_LENGTH: usize = 2400;
const ML_KEM_768_CIPHERTEXT_LENGTH: usize = 1088;

// ML-KEM-1024
const ML_KEM_1024_SHARED_SECRET_LENGTH: usize = ML_KEM_SHARED_SECRET_LENGTH;
const ML_KEM_1024_PUBLIC_KEY_LENGTH: usize = 1568;
const ML_KEM_1024_SECRET_KEY_LENGTH: usize = 3168;
const ML_KEM_1024_CIPHERTEXT_LENGTH: usize = 1568;
74
75pub const ML_KEM_512: Algorithm<AlgorithmId> = Algorithm {
77 id: AlgorithmId::MlKem512,
78 decapsulate_key_size: ML_KEM_512_SECRET_KEY_LENGTH,
79 encapsulate_key_size: ML_KEM_512_PUBLIC_KEY_LENGTH,
80 ciphertext_size: ML_KEM_512_CIPHERTEXT_LENGTH,
81 shared_secret_size: ML_KEM_512_SHARED_SECRET_LENGTH,
82};
83
84pub const ML_KEM_768: Algorithm<AlgorithmId> = Algorithm {
86 id: AlgorithmId::MlKem768,
87 decapsulate_key_size: ML_KEM_768_SECRET_KEY_LENGTH,
88 encapsulate_key_size: ML_KEM_768_PUBLIC_KEY_LENGTH,
89 ciphertext_size: ML_KEM_768_CIPHERTEXT_LENGTH,
90 shared_secret_size: ML_KEM_768_SHARED_SECRET_LENGTH,
91};
92
93pub const ML_KEM_1024: Algorithm<AlgorithmId> = Algorithm {
95 id: AlgorithmId::MlKem1024,
96 decapsulate_key_size: ML_KEM_1024_SECRET_KEY_LENGTH,
97 encapsulate_key_size: ML_KEM_1024_PUBLIC_KEY_LENGTH,
98 ciphertext_size: ML_KEM_1024_CIPHERTEXT_LENGTH,
99 shared_secret_size: ML_KEM_1024_SHARED_SECRET_LENGTH,
100};
101
102use crate::aws_lc::{NID_MLKEM1024, NID_MLKEM512, NID_MLKEM768};
103
104pub trait AlgorithmIdentifier:
106 Copy + Clone + Debug + PartialEq + crate::sealed::Sealed + 'static
107{
108 fn nid(self) -> i32;
110}
111
112#[derive(PartialEq)]
114pub struct Algorithm<Id = AlgorithmId>
115where
116 Id: AlgorithmIdentifier,
117{
118 pub(crate) id: Id,
119 pub(crate) decapsulate_key_size: usize,
120 pub(crate) encapsulate_key_size: usize,
121 pub(crate) ciphertext_size: usize,
122 pub(crate) shared_secret_size: usize,
123}
124
125impl<Id> Algorithm<Id>
126where
127 Id: AlgorithmIdentifier,
128{
129 #[must_use]
131 pub fn id(&self) -> Id {
132 self.id
133 }
134
135 #[inline]
136 #[allow(dead_code)]
137 pub(crate) fn decapsulate_key_size(&self) -> usize {
138 self.decapsulate_key_size
139 }
140
141 #[inline]
142 pub(crate) fn encapsulate_key_size(&self) -> usize {
143 self.encapsulate_key_size
144 }
145
146 #[inline]
147 pub(crate) fn ciphertext_size(&self) -> usize {
148 self.ciphertext_size
149 }
150
151 #[inline]
152 pub(crate) fn shared_secret_size(&self) -> usize {
153 self.shared_secret_size
154 }
155}
156
157impl<Id> Debug for Algorithm<Id>
158where
159 Id: AlgorithmIdentifier,
160{
161 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
162 Debug::fmt(&self.id, f)
163 }
164}
165
166pub struct DecapsulationKey<Id = AlgorithmId>
168where
169 Id: AlgorithmIdentifier,
170{
171 algorithm: &'static Algorithm<Id>,
172 evp_pkey: LcPtr<EVP_PKEY>,
173}
174
175#[non_exhaustive]
177#[derive(Clone, Copy, Debug, PartialEq)]
178pub enum AlgorithmId {
179 MlKem512,
181
182 MlKem768,
184
185 MlKem1024,
187}
188
189impl AlgorithmIdentifier for AlgorithmId {
190 fn nid(self) -> i32 {
191 match self {
192 AlgorithmId::MlKem512 => NID_MLKEM512,
193 AlgorithmId::MlKem768 => NID_MLKEM768,
194 AlgorithmId::MlKem1024 => NID_MLKEM1024,
195 }
196 }
197}
198
199impl crate::sealed::Sealed for AlgorithmId {}
200
201impl<Id> DecapsulationKey<Id>
202where
203 Id: AlgorithmIdentifier,
204{
205 pub fn generate(alg: &'static Algorithm<Id>) -> Result<Self, Unspecified> {
210 let kyber_key = kem_key_generate(alg.id.nid())?;
211 Ok(DecapsulationKey {
212 algorithm: alg,
213 evp_pkey: kyber_key,
214 })
215 }
216
217 #[must_use]
219 pub fn algorithm(&self) -> &'static Algorithm<Id> {
220 self.algorithm
221 }
222
223 #[allow(clippy::missing_panics_doc)]
228 pub fn encapsulation_key(&self) -> Result<EncapsulationKey<Id>, Unspecified> {
229 let evp_pkey = self.evp_pkey.clone();
230
231 Ok(EncapsulationKey {
232 algorithm: self.algorithm,
233 evp_pkey,
234 })
235 }
236
237 #[allow(clippy::needless_pass_by_value)]
245 pub fn decapsulate(&self, ciphertext: Ciphertext<'_>) -> Result<SharedSecret, Unspecified> {
246 let mut shared_secret_len = self.algorithm.shared_secret_size();
247 let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
248
249 let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
250
251 let ciphertext = ciphertext.as_ref();
252
253 if 1 != unsafe {
254 EVP_PKEY_decapsulate(
255 *ctx.as_mut(),
256 shared_secret.as_mut_ptr(),
257 &mut shared_secret_len,
258 ciphertext.as_ptr() as *mut u8,
260 ciphertext.len(),
261 )
262 } {
263 return Err(Unspecified);
264 }
265
266 debug_assert_eq!(shared_secret_len, shared_secret.len());
271 shared_secret.truncate(shared_secret_len);
272
273 Ok(SharedSecret(shared_secret.into_boxed_slice()))
274 }
275}
276
// SAFETY: NOTE(review) — these impls assert that `DecapsulationKey` may be moved
// to and shared between threads. The struct holds a `&'static Algorithm` and an
// `LcPtr<EVP_PKEY>`. Every operation in this file takes `&self` and creates its
// own `EVP_PKEY_CTX` via `create_EVP_PKEY_CTX`, so the shared `EVP_PKEY` appears
// to be used read-only. Soundness presumably relies on AWS-LC permitting
// concurrent read-only use of an `EVP_PKEY`; confirm against AWS-LC's
// thread-safety guarantees.
unsafe impl<Id> Send for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}

unsafe impl<Id> Sync for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}
280
281impl<Id> Debug for DecapsulationKey<Id>
282where
283 Id: AlgorithmIdentifier,
284{
285 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
286 f.debug_struct("DecapsulationKey")
287 .field("algorithm", &self.algorithm)
288 .finish_non_exhaustive()
289 }
290}
291
292use paste::paste;
293
// Defines the `EncapsulationKeyBytes` encoding type returned by
// `EncapsulationKey::key_bytes`. In this file it is built from a `Vec<u8>` via
// `EncapsulationKeyBytes::new` and read back through `as_ref()`.
// NOTE(review): the macro presumably depends on the `paste` import above — confirm.
generated_encodings!(EncapsulationKeyBytes);
295
296pub struct EncapsulationKey<Id = AlgorithmId>
299where
300 Id: AlgorithmIdentifier,
301{
302 algorithm: &'static Algorithm<Id>,
303 evp_pkey: LcPtr<EVP_PKEY>,
304}
305
306impl<Id> EncapsulationKey<Id>
307where
308 Id: AlgorithmIdentifier,
309{
310 #[must_use]
312 pub fn algorithm(&self) -> &'static Algorithm<Id> {
313 self.algorithm
314 }
315
316 pub fn encapsulate(&self) -> Result<(Ciphertext<'static>, SharedSecret), Unspecified> {
322 let mut ciphertext_len = self.algorithm.ciphertext_size();
323 let mut shared_secret_len = self.algorithm.shared_secret_size();
324 let mut ciphertext: Vec<u8> = vec![0u8; ciphertext_len];
325 let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
326
327 let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
328
329 if 1 != unsafe {
330 EVP_PKEY_encapsulate(
331 *ctx.as_mut(),
332 ciphertext.as_mut_ptr(),
333 &mut ciphertext_len,
334 shared_secret.as_mut_ptr(),
335 &mut shared_secret_len,
336 )
337 } {
338 return Err(Unspecified);
339 }
340
341 debug_assert_eq!(ciphertext_len, ciphertext.len());
347 ciphertext.truncate(ciphertext_len);
348 debug_assert_eq!(shared_secret_len, shared_secret.len());
349 shared_secret.truncate(shared_secret_len);
350
351 Ok((
352 Ciphertext::new(ciphertext),
353 SharedSecret::new(shared_secret.into_boxed_slice()),
354 ))
355 }
356
357 pub fn key_bytes(&self) -> Result<EncapsulationKeyBytes<'static>, Unspecified> {
362 let mut encapsulate_bytes = vec![0u8; self.algorithm.encapsulate_key_size()];
363 let encapsulate_key_size = self
364 .evp_pkey
365 .marshal_raw_public_to_buffer(&mut encapsulate_bytes)?;
366
367 debug_assert_eq!(encapsulate_key_size, encapsulate_bytes.len());
368 encapsulate_bytes.truncate(encapsulate_key_size);
369
370 Ok(EncapsulationKeyBytes::new(encapsulate_bytes))
371 }
372
373 pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
384 match bytes.len().cmp(&alg.encapsulate_key_size()) {
385 Ordering::Less => Err(KeyRejected::too_small()),
386 Ordering::Greater => Err(KeyRejected::too_large()),
387 Ordering::Equal => Ok(()),
388 }?;
389 let pubkey = LcPtr::new(unsafe {
390 EVP_PKEY_kem_new_raw_public_key(alg.id.nid(), bytes.as_ptr(), bytes.len())
391 })?;
392 Ok(EncapsulationKey {
393 algorithm: alg,
394 evp_pkey: pubkey,
395 })
396 }
397}
398
// SAFETY: NOTE(review) — these impls assert that `EncapsulationKey` may be moved
// to and shared between threads. The struct holds a `&'static Algorithm` and an
// `LcPtr<EVP_PKEY>`. The operations in this file take `&self` and create their
// own `EVP_PKEY_CTX` per call, so the shared `EVP_PKEY` appears to be used
// read-only. Soundness presumably relies on AWS-LC permitting concurrent
// read-only use of an `EVP_PKEY`; confirm against AWS-LC's thread-safety
// guarantees.
unsafe impl<Id> Send for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}

unsafe impl<Id> Sync for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}
402
403impl<Id> Debug for EncapsulationKey<Id>
404where
405 Id: AlgorithmIdentifier,
406{
407 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
408 f.debug_struct("EncapsulationKey")
409 .field("algorithm", &self.algorithm)
410 .finish_non_exhaustive()
411 }
412}
413
/// A KEM ciphertext: either borrowed from the caller (via `From<&[u8]>`) or
/// owned, as produced by encapsulation.
pub struct Ciphertext<'a>(Cow<'a, [u8]>);

impl<'a> Ciphertext<'a> {
    /// Takes ownership of a freshly produced ciphertext buffer.
    fn new(value: Vec<u8>) -> Ciphertext<'a> {
        let owned: Cow<'a, [u8]> = value.into();
        Ciphertext(owned)
    }
}
423
// Wipes owned ciphertext bytes when the value is dropped. Borrowed ciphertexts
// belong to the caller and are left untouched.
impl Drop for Ciphertext<'_> {
    fn drop(&mut self) {
        if let Cow::Owned(ref mut v) = self.0 {
            v.zeroize();
        }
    }
}
431
432impl AsRef<[u8]> for Ciphertext<'_> {
433 fn as_ref(&self) -> &[u8] {
434 match self.0 {
435 Cow::Borrowed(v) => v,
436 Cow::Owned(ref v) => v.as_ref(),
437 }
438 }
439}
440
441impl<'a> From<&'a [u8]> for Ciphertext<'a> {
442 fn from(value: &'a [u8]) -> Self {
443 Self(Cow::Borrowed(value))
444 }
445}
446
/// The secret agreed on by encapsulation and decapsulation.
pub struct SharedSecret(Box<[u8]>);

impl SharedSecret {
    /// Takes ownership of the secret bytes.
    fn new(value: Box<[u8]>) -> Self {
        SharedSecret(value)
    }
}
455
// Wipes the secret bytes when the value is dropped so they do not linger in
// freed heap memory.
impl Drop for SharedSecret {
    fn drop(&mut self) {
        self.0.zeroize();
    }
}
461
462impl AsRef<[u8]> for SharedSecret {
463 fn as_ref(&self) -> &[u8] {
464 self.0.as_ref()
465 }
466}
467
468#[inline]
470fn kem_key_generate(nid: i32) -> Result<LcPtr<EVP_PKEY>, Unspecified> {
471 let params_fn = |ctx| {
472 if 1 == unsafe { EVP_PKEY_CTX_kem_set_params(ctx, nid) } {
473 Ok(())
474 } else {
475 Err(())
476 }
477 };
478
479 LcPtr::<EVP_PKEY>::generate(EVP_PKEY_KEM, Some(params_fn))
480}
481
#[cfg(test)]
mod tests {
    use super::{Ciphertext, DecapsulationKey, EncapsulationKey, SharedSecret};
    use crate::error::KeyRejected;

    use crate::kem::{ML_KEM_1024, ML_KEM_512, ML_KEM_768};

    #[test]
    fn ciphertext() {
        let ciphertext_bytes = vec![42u8; 4];
        let ciphertext = Ciphertext::from(ciphertext_bytes.as_ref());
        assert_eq!(ciphertext.as_ref(), &[42, 42, 42, 42]);
        drop(ciphertext);

        let ciphertext_bytes = vec![42u8; 4];
        let ciphertext = Ciphertext::<'static>::new(ciphertext_bytes);
        assert_eq!(ciphertext.as_ref(), &[42, 42, 42, 42]);
    }

    #[test]
    fn shared_secret() {
        let secret_bytes = vec![42u8; 4];
        let shared_secret = SharedSecret::new(secret_bytes.into_boxed_slice());
        assert_eq!(shared_secret.as_ref(), &[42, 42, 42, 42]);
    }

    #[test]
    fn test_kem_serialize() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
            assert_eq!(priv_key.algorithm(), algorithm);

            let pub_key = priv_key.encapsulation_key().unwrap();
            let pubkey_raw_bytes = pub_key.key_bytes().unwrap();
            let pub_key_from_bytes =
                EncapsulationKey::new(algorithm, pubkey_raw_bytes.as_ref()).unwrap();

            assert_eq!(
                pub_key.key_bytes().unwrap().as_ref(),
                pub_key_from_bytes.key_bytes().unwrap().as_ref()
            );
            assert_eq!(pub_key.algorithm(), pub_key_from_bytes.algorithm());
        }
    }

    #[test]
    fn test_kem_wrong_sizes() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let too_long_bytes = vec![0u8; algorithm.encapsulate_key_size() + 1];
            let long_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_long_bytes);
            assert_eq!(
                long_pub_key_from_bytes.err(),
                Some(KeyRejected::too_large())
            );

            let too_short_bytes = vec![0u8; algorithm.encapsulate_key_size() - 1];
            let short_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_short_bytes);
            assert_eq!(
                short_pub_key_from_bytes.err(),
                Some(KeyRejected::too_small())
            );
        }
    }

    #[test]
    fn test_kem_e2e() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
            assert_eq!(priv_key.algorithm(), algorithm);

            let pub_key = priv_key.encapsulation_key().unwrap();

            let (alice_ciphertext, alice_secret) =
                pub_key.encapsulate().expect("encapsulate successful");

            let bob_secret = priv_key
                .decapsulate(alice_ciphertext)
                .expect("decapsulate successful");

            assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
        }
    }

    #[test]
    fn test_serialized_kem_e2e() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
            assert_eq!(priv_key.algorithm(), algorithm);

            let pub_key = priv_key.encapsulation_key().unwrap();

            let pub_key_bytes = pub_key.key_bytes().unwrap();

            drop(pub_key);

            let retrieved_pub_key =
                EncapsulationKey::new(algorithm, pub_key_bytes.as_ref()).unwrap();
            let (ciphertext, bob_secret) = retrieved_pub_key
                .encapsulate()
                .expect("encapsulate successful");

            let alice_secret = priv_key
                .decapsulate(ciphertext)
                .expect("decapsulate successful");

            assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
        }
    }

    /// Key, ciphertext and secret buffers have exactly the advertised sizes.
    #[test]
    fn test_kem_output_sizes() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
            let pub_key = priv_key.encapsulation_key().unwrap();

            assert_eq!(
                pub_key.key_bytes().unwrap().as_ref().len(),
                algorithm.encapsulate_key_size()
            );

            let (ciphertext, secret) = pub_key.encapsulate().unwrap();
            assert_eq!(ciphertext.as_ref().len(), algorithm.ciphertext_size());
            assert_eq!(secret.as_ref().len(), algorithm.shared_secret_size());
        }
    }

    /// Ciphertexts of the wrong length (including empty) are rejected.
    #[test]
    fn test_kem_decapsulate_wrong_ciphertext_size() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let priv_key = DecapsulationKey::generate(algorithm).unwrap();

            let too_short = vec![0u8; algorithm.ciphertext_size() - 1];
            assert!(priv_key
                .decapsulate(Ciphertext::from(too_short.as_slice()))
                .is_err());

            let too_long = vec![0u8; algorithm.ciphertext_size() + 1];
            assert!(priv_key
                .decapsulate(Ciphertext::from(too_long.as_slice()))
                .is_err());

            let empty: &[u8] = &[];
            assert!(priv_key.decapsulate(Ciphertext::from(empty)).is_err());
        }
    }

    /// A tampered ciphertext of the right length still decapsulates
    /// (ML-KEM implicit rejection), but yields a different secret.
    #[test]
    fn test_kem_implicit_rejection() {
        for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
            let priv_key = DecapsulationKey::generate(algorithm).unwrap();
            let pub_key = priv_key.encapsulation_key().unwrap();

            let (ciphertext, secret) = pub_key.encapsulate().unwrap();
            let mut tampered = ciphertext.as_ref().to_vec();
            tampered[0] ^= 1;

            let rejected = priv_key
                .decapsulate(Ciphertext::from(tampered.as_slice()))
                .expect("implicit rejection still returns a secret");
            assert_ne!(rejected.as_ref(), secret.as_ref());
        }
    }

    #[test]
    fn test_debug_fmt() {
        let private = DecapsulationKey::generate(&ML_KEM_512).expect("successful generation");
        assert_eq!(
            format!("{private:?}"),
            "DecapsulationKey { algorithm: MlKem512, .. }"
        );
        assert_eq!(
            format!(
                "{:?}",
                private.encapsulation_key().expect("public key retrievable")
            ),
            "EncapsulationKey { algorithm: MlKem512, .. }"
        );
    }
}