use core::marker::PhantomData;
use hybrid_array::typenum::U32;
use rand_core::CryptoRngCore;
use crate::crypto::{rand, G, H, J};
use crate::param::{DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, KemParams};
use crate::pke::{DecryptionKey, EncryptionKey};
use crate::util::B32;
use crate::{Encoded, EncodedSizeUser};
#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};
pub use ::kem::{Decapsulate, Encapsulate};
pub(crate) type SharedKey = B32;
/// ML-KEM decapsulation (private) key for the parameter set `P`.
///
/// Bundles the PKE decryption key, a copy of the matching encapsulation key
/// (including its hash `h`), and the seed `z` from which the fallback key is
/// derived when decapsulation rejects a ciphertext.
///
/// NOTE(review): the derived `Debug` formats every field, including the secret
/// `dk_pke` and `z`, and the derived `PartialEq` is not constant-time. Avoid
/// logging this type or comparing it on secret-dependent paths.
#[derive(Clone, Debug, PartialEq)]
pub struct DecapsulationKey<P: KemParams> {
    /// PKE decryption key (secret).
    dk_pke: DecryptionKey<P>,
    /// Encapsulation key paired with `dk_pke` (public).
    ek: EncapsulationKey<P>,
    /// Secret seed used to derive the implicit-rejection key `J(z, c)`.
    z: B32,
}
#[cfg(feature = "zeroize")]
impl<P: KemParams> Drop for DecapsulationKey<P> {
    /// Wipe the secret components before the key's memory is released.
    ///
    /// Only `dk_pke` and `z` are zeroized; `ek` is the public half of the pair.
    fn drop(&mut self) {
        let Self { dk_pke, z, .. } = self;
        dk_pke.zeroize();
        z.zeroize();
    }
}
#[cfg(feature = "zeroize")]
/// Marker: dropping a `DecapsulationKey` zeroizes its secret components
/// (performed by its `Drop` impl).
impl<P: KemParams> ZeroizeOnDrop for DecapsulationKey<P> {}
impl<P: KemParams> EncodedSizeUser for DecapsulationKey<P> {
    type EncodedSize = DecapsulationKeySize<P>;

    /// Decode a decapsulation key from its serialized form.
    ///
    /// `P::split_dk` splits the encoding into the PKE decryption key, the PKE
    /// encryption key, the stored hash `h`, and the seed `z`.
    ///
    /// NOTE(review): the stored `h` is copied as-is rather than recomputed from
    /// the decoded encryption key. FIPS 203's decapsulation-key input checks
    /// include such a hash check; it is not performed here.
    // The encoded-part names intentionally mirror each other (dk/ek).
    #[allow(clippy::similar_names)]
    fn from_bytes(enc: &Encoded<Self>) -> Self {
        let (dk_bytes, ek_bytes, h, z) = P::split_dk(enc);
        let ek = EncapsulationKey {
            ek_pke: EncryptionKey::from_bytes(ek_bytes),
            h: h.clone(),
        };
        let dk_pke = DecryptionKey::from_bytes(dk_bytes);
        Self {
            dk_pke,
            ek,
            z: z.clone(),
        }
    }

    /// Serialize the key by handing its encoded parts to `P::concat_dk`, in
    /// the same order that `P::split_dk` returns them.
    fn as_bytes(&self) -> Encoded<Self> {
        P::concat_dk(
            self.dk_pke.as_bytes(),
            self.ek.as_bytes(),
            self.ek.h.clone(),
            self.z.clone(),
        )
    }
}
/// Branch-free byte comparison: returns `0xff` if `x == y`, otherwise `0x00`.
///
/// The all-ones / all-zeros result is intended to be used as a bit mask.
fn constant_time_eq(x: u8, y: u8) -> u8 {
    let diff = x ^ y;
    // For any non-zero byte `d`, either `d` or `-d` (mod 256) has its top bit
    // set, so `differs` is 1 exactly when the bytes differ and 0 when equal.
    let differs = (diff | diff.wrapping_neg()) >> 7;
    // 1 - 1 = 0x00 (bytes differ); 0 - 1 wraps to 0xff (bytes equal).
    differs.wrapping_sub(1)
}
/// Decapsulation with implicit rejection.
///
/// Decrypts the ciphertext to a candidate message `m'`, re-derives `(K', r')`
/// with `G(m', h)`, re-encrypts `m'` under `r'`, and compares the result with
/// the received ciphertext. If they match, `K'` is returned. Otherwise the
/// fallback key `Kbar = J(z, c)` is returned. The choice between the two keys
/// is made with a byte mask rather than a branch.
impl<P> ::kem::Decapsulate<EncodedCiphertext<P>, SharedKey> for DecapsulationKey<P>
where
    P: KemParams,
{
    type Error = ();
    /// Recover the shared key from `encapsulated_key`.
    ///
    /// Never returns `Err`: a ciphertext that fails the re-encryption check
    /// yields the fallback key `Kbar` instead of an error.
    fn decapsulate(&self, encapsulated_key: &EncodedCiphertext<P>) -> Result<SharedKey, ()> {
        // Candidate message m' recovered with the PKE decryption key.
        let mp = self.dk_pke.decrypt(encapsulated_key);
        // (K', r') = G(m', h), where h is the hash stored in the encapsulation key.
        let (Kp, rp) = G(&[&mp, &self.ek.h]);
        // Fallback key Kbar = J(z, c); computed on every call, regardless of
        // the outcome of the comparison below.
        let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]);
        // Re-encrypt the candidate message with the re-derived randomness r'.
        let cp = self.ek.ek_pke.encrypt(&mp, &rp);
        // `equal` is 0xff if every compared byte matches, else 0x00. The fold has
        // no early exit, so all bytes are visited.
        // NOTE(review): the compiler is not obligated to keep this branch-free.
        let equal = cp
            .iter()
            .zip(encapsulated_key.iter())
            .map(|(&x, &y)| constant_time_eq(x, y))
            .fold(0xff, |x, y| x & y);
        // Masked select: each byte comes from K' when `equal == 0xff` and from
        // Kbar when `equal == 0x00`.
        Ok(Kp
            .iter()
            .zip(Kbar.iter())
            .map(|(x, y)| (equal & x) | (!equal & y))
            .collect())
    }
}
impl<P: KemParams> DecapsulationKey<P> {
    /// Borrow the encapsulation (public) key that pairs with this key.
    pub fn encapsulation_key(&self) -> &EncapsulationKey<P> {
        &self.ek
    }

    /// Generate a fresh key from `rng`.
    ///
    /// Draws the seed `d` first, then `z`, and derives the key with
    /// [`Self::generate_deterministic`].
    pub(crate) fn generate(rng: &mut impl CryptoRngCore) -> Self {
        // The draw order matters: swapping it would change which key a given
        // RNG stream produces.
        let seed_d: B32 = rand(rng);
        let seed_z: B32 = rand(rng);
        Self::generate_deterministic(&seed_d, &seed_z)
    }

    /// Derive a key deterministically from the seeds `d` and `z`.
    ///
    /// `d` drives PKE key generation; `z` is stored verbatim and later used to
    /// derive the implicit-rejection key during decapsulation.
    #[must_use]
    // `dk_pke` / `ek_pke` intentionally mirror each other.
    #[allow(clippy::similar_names)]
    pub(crate) fn generate_deterministic(d: &B32, z: &B32) -> Self {
        let (dk_pke, ek_pke) = DecryptionKey::generate(d);
        Self {
            dk_pke,
            ek: EncapsulationKey::new(ek_pke),
            z: z.clone(),
        }
    }
}
/// ML-KEM encapsulation (public) key for the parameter set `P`.
#[derive(Clone, Debug, PartialEq)]
pub struct EncapsulationKey<P: KemParams> {
    /// PKE encryption key.
    ek_pke: EncryptionKey<P>,
    /// Hash of the encoded `ek_pke`. `EncapsulationKey::new` computes it, while
    /// decoding a `DecapsulationKey` copies the stored value.
    h: B32,
}
impl<P: KemParams> EncapsulationKey<P> {
    /// Wrap a PKE encryption key, computing and caching `h = H(encoded ek_pke)`.
    fn new(ek_pke: EncryptionKey<P>) -> Self {
        Self {
            h: H(ek_pke.as_bytes()),
            ek_pke,
        }
    }

    /// Encapsulate using the caller-supplied value `m`.
    ///
    /// Derives `(K, r) = G(m, h)`, encrypts `m` under `r`, and returns the
    /// ciphertext together with `K` as the shared key.
    fn encapsulate_deterministic_inner(&self, m: &B32) -> (EncodedCiphertext<P>, SharedKey) {
        let (shared_key, coins) = G(&[m, &self.h]);
        let ciphertext = self.ek_pke.encrypt(m, &coins);
        (ciphertext, shared_key)
    }
}
impl<P: KemParams> EncodedSizeUser for EncapsulationKey<P> {
    type EncodedSize = EncapsulationKeySize<P>;

    /// Decode an encapsulation key; `h` is recomputed from the decoded key.
    fn from_bytes(enc: &Encoded<Self>) -> Self {
        let ek_pke = EncryptionKey::from_bytes(enc);
        Self::new(ek_pke)
    }

    /// Encode as the serialized PKE encryption key; `h` is not included.
    fn as_bytes(&self) -> Encoded<Self> {
        let Self { ek_pke, .. } = self;
        ek_pke.as_bytes()
    }
}
impl<P: KemParams> ::kem::Encapsulate<EncodedCiphertext<P>, SharedKey> for EncapsulationKey<P> {
    type Error = ();

    /// Encapsulate a fresh shared key, using a random `m` drawn from `rng`.
    ///
    /// Never returns `Err`.
    fn encapsulate(
        &self,
        rng: &mut impl CryptoRngCore,
    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
        let message: B32 = rand(rng);
        let encapsulated = self.encapsulate_deterministic_inner(&message);
        Ok(encapsulated)
    }
}
#[cfg(feature = "deterministic")]
impl<P: KemParams> crate::EncapsulateDeterministic<EncodedCiphertext<P>, SharedKey>
    for EncapsulationKey<P>
{
    type Error = ();

    /// Encapsulate using the caller-provided `m` instead of fresh randomness.
    ///
    /// The output is fully determined by `m` and this key, so `m` must not be
    /// predictable or reused outside of testing. Never returns `Err`.
    fn encapsulate_deterministic(
        &self,
        m: &B32,
    ) -> Result<(EncodedCiphertext<P>, SharedKey), Self::Error> {
        let encapsulated = self.encapsulate_deterministic_inner(m);
        Ok(encapsulated)
    }
}
/// Type-level handle for ML-KEM with parameter set `P`; carries no data and
/// exists to implement `crate::KemCore`.
pub struct Kem<P: KemParams> {
    _phantom: PhantomData<P>,
}
impl<P> crate::KemCore for Kem<P>
where
P: KemParams,
{
type SharedKeySize = U32;
type CiphertextSize = P::CiphertextSize;
type DecapsulationKey = DecapsulationKey<P>;
type EncapsulationKey = EncapsulationKey<P>;
fn generate(rng: &mut impl CryptoRngCore) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
let dk = Self::DecapsulationKey::generate(rng);
let ek = dk.encapsulation_key().clone();
(dk, ek)
}
#[cfg(feature = "deterministic")]
fn generate_deterministic(
d: &B32,
z: &B32,
) -> (Self::DecapsulationKey, Self::EncapsulationKey) {
let dk = Self::DecapsulationKey::generate_deterministic(d, z);
let ek = dk.encapsulation_key().clone();
(dk, ek)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MlKem1024Params, MlKem512Params, MlKem768Params};
    use ::kem::{Decapsulate, Encapsulate};

    /// Encapsulate to a fresh key pair and check both sides derive the same key.
    fn round_trip_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::thread_rng();
        let dk = DecapsulationKey::<P>::generate(&mut rng);
        let ek = dk.encapsulation_key();
        let (ct, k_send) = ek.encapsulate(&mut rng).unwrap();
        let k_recv = dk.decapsulate(&ct).unwrap();
        assert_eq!(k_send, k_recv);
    }

    #[test]
    fn round_trip() {
        round_trip_test::<MlKem512Params>();
        round_trip_test::<MlKem768Params>();
        round_trip_test::<MlKem1024Params>();
    }

    /// A tampered ciphertext must be implicitly rejected. Decapsulation still
    /// returns `Ok`, but the key is the fallback `J(z, c')` rather than the
    /// sender's key. This exercises the masked-select branch of `decapsulate`,
    /// which the round-trip test never reaches.
    fn implicit_rejection_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::thread_rng();
        let dk = DecapsulationKey::<P>::generate(&mut rng);
        let ek = dk.encapsulation_key();
        let (ct, k_send) = ek.encapsulate(&mut rng).unwrap();

        // Flip one bit so the re-encryption check in `decapsulate` fails.
        let mut ct_bad = ct.clone();
        ct_bad[0] ^= 0x01;

        let k_recv = dk.decapsulate(&ct_bad).unwrap();
        assert_ne!(k_send, k_recv);

        let k_bar = J(&[dk.z.as_slice(), ct_bad.as_ref()]);
        assert!(k_recv.iter().eq(k_bar.iter()));
    }

    #[test]
    fn implicit_rejection() {
        implicit_rejection_test::<MlKem512Params>();
        implicit_rejection_test::<MlKem768Params>();
        implicit_rejection_test::<MlKem1024Params>();
    }

    /// Encoding then decoding either key must reproduce it exactly.
    fn codec_test<P>()
    where
        P: KemParams,
    {
        let mut rng = rand::thread_rng();
        let dk_original = DecapsulationKey::<P>::generate(&mut rng);
        let ek_original = dk_original.encapsulation_key().clone();

        let dk_encoded = dk_original.as_bytes();
        let dk_decoded = DecapsulationKey::from_bytes(&dk_encoded);
        assert_eq!(dk_original, dk_decoded);

        let ek_encoded = ek_original.as_bytes();
        let ek_decoded = EncapsulationKey::from_bytes(&ek_encoded);
        assert_eq!(ek_original, ek_decoded);
    }

    #[test]
    fn codec() {
        codec_test::<MlKem512Params>();
        codec_test::<MlKem768Params>();
        codec_test::<MlKem1024Params>();
    }
}