use crate::{
    buffer::Buffer,
    encoding::generated_encodings,
    error::{KeyRejected, Unspecified},
    ptr::LcPtr,
};
use alloc::borrow::Cow;
use aws_lc::{
    EVP_PKEY_CTX_kem_set_params, EVP_PKEY_CTX_new_id, EVP_PKEY_decapsulate, EVP_PKEY_encapsulate,
    EVP_PKEY_get_raw_private_key, EVP_PKEY_get_raw_public_key, EVP_PKEY_kem_new_raw_public_key,
    EVP_PKEY_keygen, EVP_PKEY_keygen_init, EVP_PKEY, EVP_PKEY_KEM,
};
use core::{
    cmp::Ordering,
    fmt::{Debug, Formatter},
    ptr::null_mut,
};
use zeroize::Zeroize;
#[cfg(not(feature = "fips"))]
pub(crate) mod semistable {
#![allow(unused)]
use super::{Algorithm, AlgorithmId};
const ML_KEM_512_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_512_PUBLIC_KEY_LENGTH: usize = 800;
const ML_KEM_512_SECRET_KEY_LENGTH: usize = 1632;
const ML_KEM_512_CIPHERTEXT_LENGTH: usize = 768;
const ML_KEM_768_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_768_PUBLIC_KEY_LENGTH: usize = 1184;
const ML_KEM_768_SECRET_KEY_LENGTH: usize = 2400;
const ML_KEM_768_CIPHERTEXT_LENGTH: usize = 1088;
const ML_KEM_1024_SHARED_SECRET_LENGTH: usize = 32;
const ML_KEM_1024_PUBLIC_KEY_LENGTH: usize = 1568;
const ML_KEM_1024_SECRET_KEY_LENGTH: usize = 3168;
const ML_KEM_1024_CIPHERTEXT_LENGTH: usize = 1568;
pub const ML_KEM_512: Algorithm<AlgorithmId> = Algorithm {
id: AlgorithmId::MlKem512,
decapsulate_key_size: ML_KEM_512_SECRET_KEY_LENGTH,
encapsulate_key_size: ML_KEM_512_PUBLIC_KEY_LENGTH,
ciphertext_size: ML_KEM_512_CIPHERTEXT_LENGTH,
shared_secret_size: ML_KEM_512_SHARED_SECRET_LENGTH,
};
pub const ML_KEM_768: Algorithm<AlgorithmId> = Algorithm {
id: AlgorithmId::MlKem768,
decapsulate_key_size: ML_KEM_768_SECRET_KEY_LENGTH,
encapsulate_key_size: ML_KEM_768_PUBLIC_KEY_LENGTH,
ciphertext_size: ML_KEM_768_CIPHERTEXT_LENGTH,
shared_secret_size: ML_KEM_768_SHARED_SECRET_LENGTH,
};
pub const ML_KEM_1024: Algorithm<AlgorithmId> = Algorithm {
id: AlgorithmId::MlKem1024,
decapsulate_key_size: ML_KEM_1024_SECRET_KEY_LENGTH,
encapsulate_key_size: ML_KEM_1024_PUBLIC_KEY_LENGTH,
ciphertext_size: ML_KEM_1024_CIPHERTEXT_LENGTH,
shared_secret_size: ML_KEM_1024_SHARED_SECRET_LENGTH,
};
}
#[cfg(feature = "fips")]
mod missing_nid {
pub const NID_MLKEM512: i32 = 988;
pub const NID_MLKEM768: i32 = 989;
pub const NID_MLKEM1024: i32 = 990;
}
#[cfg(feature = "fips")]
use self::missing_nid::{NID_MLKEM1024, NID_MLKEM512, NID_MLKEM768};
#[cfg(not(feature = "fips"))]
use aws_lc::{NID_MLKEM1024, NID_MLKEM512, NID_MLKEM768};
pub trait AlgorithmIdentifier:
Copy + Clone + Debug + PartialEq + crate::sealed::Sealed + 'static
{
fn nid(self) -> i32;
}
#[derive(PartialEq)]
pub struct Algorithm<Id = AlgorithmId>
where
Id: AlgorithmIdentifier,
{
pub(crate) id: Id,
pub(crate) decapsulate_key_size: usize,
pub(crate) encapsulate_key_size: usize,
pub(crate) ciphertext_size: usize,
pub(crate) shared_secret_size: usize,
}
impl<Id> Algorithm<Id>
where
Id: AlgorithmIdentifier,
{
#[must_use]
pub fn id(&self) -> Id {
self.id
}
#[inline]
pub(crate) fn decapsulate_key_size(&self) -> usize {
self.decapsulate_key_size
}
#[inline]
pub(crate) fn encapsulate_key_size(&self) -> usize {
self.encapsulate_key_size
}
#[inline]
pub(crate) fn ciphertext_size(&self) -> usize {
self.ciphertext_size
}
#[inline]
pub(crate) fn shared_secret_size(&self) -> usize {
self.shared_secret_size
}
}
impl<Id> Debug for Algorithm<Id>
where
Id: AlgorithmIdentifier,
{
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
Debug::fmt(&self.id, f)
}
}
pub struct DecapsulationKey<Id = AlgorithmId>
where
Id: AlgorithmIdentifier,
{
algorithm: &'static Algorithm<Id>,
evp_pkey: LcPtr<EVP_PKEY>,
}
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AlgorithmId {
MlKem512,
MlKem768,
MlKem1024,
}
impl AlgorithmIdentifier for AlgorithmId {
fn nid(self) -> i32 {
match self {
AlgorithmId::MlKem512 => NID_MLKEM512,
AlgorithmId::MlKem768 => NID_MLKEM768,
AlgorithmId::MlKem1024 => NID_MLKEM1024,
}
}
}
impl crate::sealed::Sealed for AlgorithmId {}
impl<Id> DecapsulationKey<Id>
where
Id: AlgorithmIdentifier,
{
pub fn generate(alg: &'static Algorithm<Id>) -> Result<Self, Unspecified> {
let mut secret_key_size = alg.decapsulate_key_size();
let mut priv_key_bytes = vec![0u8; secret_key_size];
let kyber_key = kem_key_generate(alg.id.nid())?;
if 1 != unsafe {
EVP_PKEY_get_raw_private_key(
*kyber_key.as_const(),
priv_key_bytes.as_mut_ptr(),
&mut secret_key_size,
)
} {
return Err(Unspecified);
}
Ok(DecapsulationKey {
algorithm: alg,
evp_pkey: kyber_key,
})
}
#[must_use]
pub fn algorithm(&self) -> &'static Algorithm<Id> {
self.algorithm
}
#[allow(clippy::missing_panics_doc)]
pub fn encapsulation_key(&self) -> Result<EncapsulationKey<Id>, Unspecified> {
let evp_pkey = self.evp_pkey.clone();
Ok(EncapsulationKey {
algorithm: self.algorithm,
evp_pkey,
})
}
#[allow(clippy::needless_pass_by_value)]
pub fn decapsulate(&self, ciphertext: Ciphertext<'_>) -> Result<SharedSecret, Unspecified> {
let mut shared_secret_len = self.algorithm.shared_secret_size();
let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
let ciphertext = ciphertext.as_ref();
if 1 != unsafe {
EVP_PKEY_decapsulate(
*ctx.as_mut(),
shared_secret.as_mut_ptr(),
&mut shared_secret_len,
ciphertext.as_ptr() as *mut u8,
ciphertext.len(),
)
} {
return Err(Unspecified);
}
debug_assert_eq!(shared_secret_len, shared_secret.len());
shared_secret.truncate(shared_secret_len);
Ok(SharedSecret(shared_secret.into_boxed_slice()))
}
}
unsafe impl<Id> Send for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}
unsafe impl<Id> Sync for DecapsulationKey<Id> where Id: AlgorithmIdentifier {}
impl<Id> Debug for DecapsulationKey<Id>
where
Id: AlgorithmIdentifier,
{
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
f.debug_struct("DecapsulationKey")
.field("algorithm", &self.algorithm)
.finish_non_exhaustive()
}
}
use paste::paste;
// Declares the `EncapsulationKeyBytes` type returned by `EncapsulationKey::key_bytes`.
// NOTE(review): the macro comes from `crate::encoding`; presumably its expansion is
// what uses the `paste` and `Buffer` imports in this file — confirm before removing either.
generated_encodings!(EncapsulationKeyBytes);
pub struct EncapsulationKey<Id = AlgorithmId>
where
Id: AlgorithmIdentifier,
{
algorithm: &'static Algorithm<Id>,
evp_pkey: LcPtr<EVP_PKEY>,
}
impl<Id> EncapsulationKey<Id>
where
Id: AlgorithmIdentifier,
{
#[must_use]
pub fn algorithm(&self) -> &'static Algorithm<Id> {
self.algorithm
}
pub fn encapsulate(&self) -> Result<(Ciphertext<'static>, SharedSecret), Unspecified> {
let mut ciphertext_len = self.algorithm.ciphertext_size();
let mut shared_secret_len = self.algorithm.shared_secret_size();
let mut ciphertext: Vec<u8> = vec![0u8; ciphertext_len];
let mut shared_secret: Vec<u8> = vec![0u8; shared_secret_len];
let mut ctx = self.evp_pkey.create_EVP_PKEY_CTX()?;
if 1 != unsafe {
EVP_PKEY_encapsulate(
*ctx.as_mut(),
ciphertext.as_mut_ptr(),
&mut ciphertext_len,
shared_secret.as_mut_ptr(),
&mut shared_secret_len,
)
} {
return Err(Unspecified);
}
debug_assert_eq!(ciphertext_len, ciphertext.len());
ciphertext.truncate(ciphertext_len);
debug_assert_eq!(shared_secret_len, shared_secret.len());
shared_secret.truncate(shared_secret_len);
Ok((
Ciphertext::new(ciphertext),
SharedSecret::new(shared_secret.into_boxed_slice()),
))
}
pub fn key_bytes(&self) -> Result<EncapsulationKeyBytes<'static>, Unspecified> {
let mut encapsulate_key_size = self.algorithm.encapsulate_key_size();
let mut encapsulate_bytes = vec![0u8; encapsulate_key_size];
if 1 != unsafe {
EVP_PKEY_get_raw_public_key(
*self.evp_pkey.as_const(),
encapsulate_bytes.as_mut_ptr(),
&mut encapsulate_key_size,
)
} {
return Err(Unspecified);
}
debug_assert_eq!(encapsulate_key_size, encapsulate_bytes.len());
encapsulate_bytes.truncate(encapsulate_key_size);
Ok(EncapsulationKeyBytes::new(encapsulate_bytes))
}
pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
match bytes.len().cmp(&alg.encapsulate_key_size()) {
Ordering::Less => Err(KeyRejected::too_small()),
Ordering::Greater => Err(KeyRejected::too_large()),
Ordering::Equal => Ok(()),
}?;
let pubkey = LcPtr::new(unsafe {
EVP_PKEY_kem_new_raw_public_key(alg.id.nid(), bytes.as_ptr(), bytes.len())
})?;
Ok(EncapsulationKey {
algorithm: alg,
evp_pkey: pubkey,
})
}
}
unsafe impl<Id> Send for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}
unsafe impl<Id> Sync for EncapsulationKey<Id> where Id: AlgorithmIdentifier {}
impl<Id> Debug for EncapsulationKey<Id>
where
Id: AlgorithmIdentifier,
{
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
f.debug_struct("EncapsulationKey")
.field("algorithm", &self.algorithm)
.finish_non_exhaustive()
}
}
pub struct Ciphertext<'a>(Cow<'a, [u8]>);
impl<'a> Ciphertext<'a> {
fn new(value: Vec<u8>) -> Ciphertext<'a> {
Self(Cow::Owned(value))
}
}
impl Drop for Ciphertext<'_> {
fn drop(&mut self) {
if let Cow::Owned(ref mut v) = self.0 {
v.zeroize();
}
}
}
impl AsRef<[u8]> for Ciphertext<'_> {
fn as_ref(&self) -> &[u8] {
match self.0 {
Cow::Borrowed(v) => v,
Cow::Owned(ref v) => v.as_ref(),
}
}
}
impl<'a> From<&'a [u8]> for Ciphertext<'a> {
fn from(value: &'a [u8]) -> Self {
Self(Cow::Borrowed(value))
}
}
pub struct SharedSecret(Box<[u8]>);
impl SharedSecret {
fn new(value: Box<[u8]>) -> Self {
Self(value)
}
}
impl Drop for SharedSecret {
fn drop(&mut self) {
self.0.zeroize();
}
}
impl AsRef<[u8]> for SharedSecret {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
#[inline]
fn kem_key_generate(nid: i32) -> Result<LcPtr<EVP_PKEY>, Unspecified> {
let mut ctx = LcPtr::new(unsafe { EVP_PKEY_CTX_new_id(EVP_PKEY_KEM, null_mut()) })?;
if 1 != unsafe { EVP_PKEY_CTX_kem_set_params(*ctx.as_mut(), nid) }
|| 1 != unsafe { EVP_PKEY_keygen_init(*ctx.as_mut()) }
{
return Err(Unspecified);
}
let mut key_raw: *mut EVP_PKEY = null_mut();
if 1 != unsafe { EVP_PKEY_keygen(*ctx.as_mut(), &mut key_raw) } {
return Err(Unspecified);
}
Ok(LcPtr::new(key_raw)?)
}
// Unit tests for this module. The tests that need the ML-KEM algorithm constants are
// compiled only when the `fips` feature is disabled, matching the gating of `semistable`.
#[cfg(test)]
mod tests {
use super::{Ciphertext, SharedSecret};
#[cfg(not(feature = "fips"))]
use crate::error::KeyRejected;
#[cfg(not(feature = "fips"))]
use super::{DecapsulationKey, EncapsulationKey};
#[cfg(not(feature = "fips"))]
use crate::kem::semistable::{ML_KEM_1024, ML_KEM_512, ML_KEM_768};
// `Ciphertext` exposes the same bytes whether it borrows them or owns them.
#[test]
fn ciphertext() {
let ciphertext_bytes = vec![42u8; 4];
let ciphertext = Ciphertext::from(ciphertext_bytes.as_ref());
assert_eq!(ciphertext.as_ref(), &[42, 42, 42, 42]);
drop(ciphertext);
let ciphertext_bytes = vec![42u8; 4];
let ciphertext = Ciphertext::<'static>::new(ciphertext_bytes);
assert_eq!(ciphertext.as_ref(), &[42, 42, 42, 42]);
}
// `SharedSecret` exposes the bytes it was constructed from.
#[test]
fn shared_secret() {
let secret_bytes = vec![42u8; 4];
let shared_secret = SharedSecret::new(secret_bytes.into_boxed_slice());
assert_eq!(shared_secret.as_ref(), &[42, 42, 42, 42]);
}
// An encapsulation key survives a round trip through its raw byte encoding.
#[test]
#[cfg(not(feature = "fips"))]
fn test_kem_serialize() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);
let pub_key = priv_key.encapsulation_key().unwrap();
let pubkey_raw_bytes = pub_key.key_bytes().unwrap();
let pub_key_from_bytes =
EncapsulationKey::new(algorithm, pubkey_raw_bytes.as_ref()).unwrap();
assert_eq!(
pub_key.key_bytes().unwrap().as_ref(),
pub_key_from_bytes.key_bytes().unwrap().as_ref()
);
assert_eq!(pub_key.algorithm(), pub_key_from_bytes.algorithm());
}
}
// `EncapsulationKey::new` rejects inputs one byte longer or shorter than expected.
#[test]
#[cfg(not(feature = "fips"))]
fn test_kem_wrong_sizes() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let too_long_bytes = vec![0u8; algorithm.encapsulate_key_size() + 1];
let long_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_long_bytes);
assert_eq!(
long_pub_key_from_bytes.err(),
Some(KeyRejected::too_large())
);
let too_short_bytes = vec![0u8; algorithm.encapsulate_key_size() - 1];
let short_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_short_bytes);
assert_eq!(
short_pub_key_from_bytes.err(),
Some(KeyRejected::too_small())
);
}
}
// Encapsulating with the derived public key and decapsulating with the private key
// must yield the same shared secret.
#[test]
#[cfg(not(feature = "fips"))]
fn test_kem_e2e() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);
let pub_key = priv_key.encapsulation_key().unwrap();
let (alice_ciphertext, alice_secret) =
pub_key.encapsulate().expect("encapsulate successful");
let bob_secret = priv_key
.decapsulate(alice_ciphertext)
.expect("decapsulate successful");
assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
}
}
// Same as `test_kem_e2e`, but the public key is first serialized to bytes, dropped,
// and rebuilt with `EncapsulationKey::new`.
#[test]
#[cfg(not(feature = "fips"))]
fn test_serialized_kem_e2e() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);
let pub_key = priv_key.encapsulation_key().unwrap();
let pub_key_bytes = pub_key.key_bytes().unwrap();
drop(pub_key);
let retrieved_pub_key =
EncapsulationKey::new(algorithm, pub_key_bytes.as_ref()).unwrap();
let (ciphertext, bob_secret) = retrieved_pub_key
.encapsulate()
.expect("encapsulate successful");
let alice_secret = priv_key
.decapsulate(ciphertext)
.expect("decapsulate successful");
assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
}
}
// The `Debug` output of both key types shows only the algorithm name.
#[test]
#[cfg(not(feature = "fips"))]
fn test_debug_fmt() {
let private = DecapsulationKey::generate(&ML_KEM_512).expect("successful generation");
assert_eq!(
format!("{private:?}"),
"DecapsulationKey { algorithm: MlKem512, .. }"
);
assert_eq!(
format!(
"{:?}",
private.encapsulation_key().expect("public key retrievable")
),
"EncapsulationKey { algorithm: MlKem512, .. }"
);
}
}