// aws_lc_rs/rsa/encryption/oaep.rs
#![allow(clippy::module_name_repetitions)]
use super::{EncryptionAlgorithmId, PrivateDecryptingKey, PublicEncryptingKey};
use crate::{
error::Unspecified,
fips::indicator_check,
ptr::{DetachableLcPtr, LcPtr},
};
use aws_lc::{
EVP_PKEY_CTX_set0_rsa_oaep_label, EVP_PKEY_CTX_set_rsa_mgf1_md, EVP_PKEY_CTX_set_rsa_oaep_md,
EVP_PKEY_CTX_set_rsa_padding, EVP_PKEY_decrypt, EVP_PKEY_decrypt_init, EVP_PKEY_encrypt,
EVP_PKEY_encrypt_init, EVP_sha1, EVP_sha256, EVP_sha384, EVP_sha512, OPENSSL_malloc, EVP_MD,
EVP_PKEY_CTX, RSA_PKCS1_OAEP_PADDING,
};
use core::{fmt::Debug, mem::size_of_val, ptr::null_mut};
use mirai_annotations::verify_unreachable;
/// RSA-OAEP parameter set that uses SHA-1 for both the OAEP digest and the MGF1 digest.
pub const OAEP_SHA1_MGF1SHA1: OaepAlgorithm = OaepAlgorithm {
    id: EncryptionAlgorithmId::OaepSha1Mgf1sha1,
    oaep_hash_fn: EVP_sha1,
    mgf1_hash_fn: EVP_sha1,
};
/// RSA-OAEP parameter set that uses SHA-256 for both the OAEP digest and the MGF1 digest.
pub const OAEP_SHA256_MGF1SHA256: OaepAlgorithm = OaepAlgorithm {
    id: EncryptionAlgorithmId::OaepSha256Mgf1sha256,
    oaep_hash_fn: EVP_sha256,
    mgf1_hash_fn: EVP_sha256,
};
/// RSA-OAEP parameter set that uses SHA-384 for both the OAEP digest and the MGF1 digest.
pub const OAEP_SHA384_MGF1SHA384: OaepAlgorithm = OaepAlgorithm {
    id: EncryptionAlgorithmId::OaepSha384Mgf1sha384,
    oaep_hash_fn: EVP_sha384,
    mgf1_hash_fn: EVP_sha384,
};
/// RSA-OAEP parameter set that uses SHA-512 for both the OAEP digest and the MGF1 digest.
pub const OAEP_SHA512_MGF1SHA512: OaepAlgorithm = OaepAlgorithm {
    id: EncryptionAlgorithmId::OaepSha512Mgf1sha512,
    oaep_hash_fn: EVP_sha512,
    mgf1_hash_fn: EVP_sha512,
};
/// Signature of the AWS-LC accessor that yields the digest used as the OAEP hash
/// (for example `EVP_sha256`).
type OaepHashFn = unsafe extern "C" fn() -> *const EVP_MD;
/// Signature of the AWS-LC accessor that yields the digest used by the MGF1 mask
/// generation function (for example `EVP_sha256`).
type Mgf1HashFn = unsafe extern "C" fn() -> *const EVP_MD;
/// One RSA-OAEP parameter set: an identifier together with the digest accessors
/// for the OAEP hash and for MGF1.
pub struct OaepAlgorithm {
    /// Identifier returned by `id()` and printed by the `Debug` implementation.
    id: EncryptionAlgorithmId,
    /// Accessor for the digest handed to `EVP_PKEY_CTX_set_rsa_oaep_md`.
    oaep_hash_fn: OaepHashFn,
    /// Accessor for the digest handed to `EVP_PKEY_CTX_set_rsa_mgf1_md`.
    mgf1_hash_fn: Mgf1HashFn,
}
impl OaepAlgorithm {
    /// Returns the identifier of this OAEP parameter set.
    #[must_use]
    pub fn id(&self) -> EncryptionAlgorithmId {
        self.id
    }

    /// Returns the accessor for the digest used as the OAEP hash.
    #[inline]
    fn oaep_hash_fn(&self) -> OaepHashFn {
        self.oaep_hash_fn
    }

    /// Returns the accessor for the digest used by MGF1.
    #[inline]
    fn mgf1_hash_fn(&self) -> Mgf1HashFn {
        self.mgf1_hash_fn
    }
}
impl Debug for OaepAlgorithm {
    /// Formats the algorithm as its identifier only; the digest accessor
    /// function pointers are deliberately left out of the output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { id, .. } = self;
        Debug::fmt(id, f)
    }
}
/// An RSA public key wrapped for RSA-OAEP encryption.
pub struct OaepPublicEncryptingKey {
    /// The underlying public key used to create the encryption context.
    public_key: PublicEncryptingKey,
}
impl OaepPublicEncryptingKey {
    /// Wraps `public_key` so it can be used for RSA-OAEP encryption.
    ///
    /// # Errors
    /// This implementation performs no checks and always returns `Ok`; the
    /// `Result` return type is part of the existing interface.
    pub fn new(public_key: PublicEncryptingKey) -> Result<Self, Unspecified> {
        Ok(Self { public_key })
    }

    /// Encrypts `plaintext` with RSA-OAEP using the digests selected by `algorithm`
    /// and an optional `label`.
    ///
    /// `ciphertext` is the output buffer: its length is passed to AWS-LC as the
    /// available capacity, and the returned slice is the prefix of `ciphertext`
    /// that was written.
    ///
    /// # Errors
    /// Returns `Unspecified` if the key context cannot be created or initialised,
    /// if the OAEP parameters cannot be applied, or if the encryption call fails.
    pub fn encrypt<'ciphertext>(
        &self,
        algorithm: &'static OaepAlgorithm,
        plaintext: &[u8],
        ciphertext: &'ciphertext mut [u8],
        label: Option<&[u8]>,
    ) -> Result<&'ciphertext mut [u8], Unspecified> {
        let mut pkey_ctx = self.public_key.0.create_EVP_PKEY_CTX()?;

        // SAFETY: `pkey_ctx` is the context created just above and is still owned here.
        if 1 != unsafe { EVP_PKEY_encrypt_init(*pkey_ctx.as_mut()) } {
            return Err(Unspecified);
        }

        configure_oaep_crypto_operation(
            &mut pkey_ctx,
            algorithm.oaep_hash_fn(),
            algorithm.mgf1_hash_fn(),
            label,
        )?;

        // In/out parameter: capacity of `ciphertext` on entry, bytes written on success.
        let mut out_len = ciphertext.len();

        // SAFETY: every pointer comes from a live slice and is paired with that
        // slice's own length; `out_len` starts at the capacity of `ciphertext`.
        if 1 != indicator_check!(unsafe {
            EVP_PKEY_encrypt(
                *pkey_ctx.as_mut(),
                ciphertext.as_mut_ptr(),
                &mut out_len,
                plaintext.as_ptr(),
                plaintext.len(),
            )
        }) {
            return Err(Unspecified);
        };

        Ok(&mut ciphertext[..out_len])
    }

    /// Returns the RSA key size in bytes, as reported by the wrapped public key.
    #[must_use]
    pub fn key_size_bytes(&self) -> usize {
        self.public_key.key_size_bytes()
    }

    /// Returns the RSA key size in bits, as reported by the wrapped public key.
    #[must_use]
    pub fn key_size_bits(&self) -> usize {
        self.public_key.key_size_bits()
    }

    /// Returns the largest plaintext length, in bytes, that fits under `algorithm`.
    ///
    /// The value is `key_size_bytes() - 2 * hash_len - 2`, where `hash_len` is the
    /// digest output size for `algorithm`. If the key is too short to leave any
    /// room after that overhead, the result saturates at `0` instead of
    /// underflowing.
    #[must_use]
    pub fn max_plaintext_size(&self, algorithm: &'static OaepAlgorithm) -> usize {
        // The wildcard arm is flagged as unreachable when every variant is already
        // listed; the allow keeps the defensive arm without tripping the lint.
        #[allow(unreachable_patterns)]
        let hash_len: usize = match algorithm.id() {
            EncryptionAlgorithmId::OaepSha1Mgf1sha1 => 20,
            EncryptionAlgorithmId::OaepSha256Mgf1sha256 => 32,
            EncryptionAlgorithmId::OaepSha384Mgf1sha384 => 48,
            EncryptionAlgorithmId::OaepSha512Mgf1sha512 => 64,
            _ => verify_unreachable!(),
        };

        // OAEP overhead: two digest-sized fields plus two fixed bytes.
        let overhead = 2 * hash_len + 2;
        self.key_size_bytes().saturating_sub(overhead)
    }

    /// Returns the ciphertext length in bytes, which this type reports as equal
    /// to `key_size_bytes()`.
    #[must_use]
    pub fn ciphertext_size(&self) -> usize {
        self.key_size_bytes()
    }
}
impl Debug for OaepPublicEncryptingKey {
    /// Prints only the type name followed by `..`; no field of the key is shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("OaepPublicEncryptingKey");
        builder.finish_non_exhaustive()
    }
}
/// An RSA private key wrapped for RSA-OAEP decryption.
pub struct OaepPrivateDecryptingKey {
    /// The underlying private key used to create the decryption context.
    private_key: PrivateDecryptingKey,
}
impl OaepPrivateDecryptingKey {
    /// Wraps `private_key` so it can be used for RSA-OAEP decryption.
    ///
    /// # Errors
    /// This implementation performs no checks and always returns `Ok`; the
    /// `Result` return type is part of the existing interface.
    pub fn new(private_key: PrivateDecryptingKey) -> Result<Self, Unspecified> {
        let wrapped = Self { private_key };
        Ok(wrapped)
    }

    /// Decrypts `ciphertext` with RSA-OAEP using the digests selected by
    /// `algorithm` and an optional `label`.
    ///
    /// `plaintext` is the output buffer: its length is passed to AWS-LC as the
    /// available capacity, and the returned slice is the prefix of `plaintext`
    /// that was written.
    ///
    /// # Errors
    /// Returns `Unspecified` if the key context cannot be created or initialised,
    /// if the OAEP parameters cannot be applied, or if the decryption call fails.
    pub fn decrypt<'plaintext>(
        &self,
        algorithm: &'static OaepAlgorithm,
        ciphertext: &[u8],
        plaintext: &'plaintext mut [u8],
        label: Option<&[u8]>,
    ) -> Result<&'plaintext mut [u8], Unspecified> {
        let mut ctx = self.private_key.0.create_EVP_PKEY_CTX()?;

        // SAFETY: `ctx` is the context created just above and is still owned here.
        let init_status = unsafe { EVP_PKEY_decrypt_init(*ctx.as_mut()) };
        if init_status != 1 {
            return Err(Unspecified);
        }

        let oaep_md = algorithm.oaep_hash_fn();
        let mgf1_md = algorithm.mgf1_hash_fn();
        configure_oaep_crypto_operation(&mut ctx, oaep_md, mgf1_md, label)?;

        // In/out parameter: capacity of `plaintext` on entry, bytes written on success.
        let mut written = plaintext.len();

        // SAFETY: every pointer comes from a live slice and is paired with that
        // slice's own length; `written` starts at the capacity of `plaintext`.
        let decrypt_status = indicator_check!(unsafe {
            EVP_PKEY_decrypt(
                *ctx.as_mut(),
                plaintext.as_mut_ptr(),
                &mut written,
                ciphertext.as_ptr(),
                ciphertext.len(),
            )
        });
        if decrypt_status != 1 {
            return Err(Unspecified);
        }

        Ok(&mut plaintext[..written])
    }

    /// Returns the RSA key size in bytes, as reported by the wrapped private key.
    #[must_use]
    pub fn key_size_bytes(&self) -> usize {
        let inner = &self.private_key;
        inner.key_size_bytes()
    }

    /// Returns the RSA key size in bits, as reported by the wrapped private key.
    #[must_use]
    pub fn key_size_bits(&self) -> usize {
        let inner = &self.private_key;
        inner.key_size_bits()
    }

    /// Returns the minimum output size in bytes, which this type reports as
    /// equal to `key_size_bytes()`.
    #[must_use]
    pub fn min_output_size(&self) -> usize {
        Self::key_size_bytes(self)
    }
}
impl Debug for OaepPrivateDecryptingKey {
    /// Prints only the type name followed by `..`; no field of the key is shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("OaepPrivateDecryptingKey");
        builder.finish_non_exhaustive()
    }
}
/// Applies the RSA-OAEP settings to `evp_pkey_ctx`: OAEP padding, the OAEP digest,
/// the MGF1 digest and the label.
///
/// A missing or empty `label` is installed as a null pointer with length `0`.
/// A non-empty `label` is copied into a buffer obtained from `OPENSSL_malloc`,
/// and that buffer is handed to the context through
/// `EVP_PKEY_CTX_set0_rsa_oaep_label`. The buffer is detached from its Rust-side
/// owner only after that call succeeds, so on failure it is released when
/// `label_ptr` is dropped.
///
/// # Errors
/// Returns `Unspecified` if any of the AWS-LC setters reports failure or if the
/// label buffer cannot be allocated.
fn configure_oaep_crypto_operation(
    evp_pkey_ctx: &mut LcPtr<EVP_PKEY_CTX>,
    oaep_hash_fn: OaepHashFn,
    mgf1_hash_fn: Mgf1HashFn,
    label: Option<&[u8]>,
) -> Result<(), Unspecified> {
    // SAFETY: `evp_pkey_ctx` is a live context borrowed exclusively for this call.
    if 1 != unsafe { EVP_PKEY_CTX_set_rsa_padding(*evp_pkey_ctx.as_mut(), RSA_PKCS1_OAEP_PADDING) }
    {
        return Err(Unspecified);
    };

    // SAFETY: as above; the digest accessor takes no arguments.
    if 1 != unsafe { EVP_PKEY_CTX_set_rsa_oaep_md(*evp_pkey_ctx.as_mut(), oaep_hash_fn()) } {
        return Err(Unspecified);
    };

    // SAFETY: as above; the digest accessor takes no arguments.
    if 1 != unsafe { EVP_PKEY_CTX_set_rsa_mgf1_md(*evp_pkey_ctx.as_mut(), mgf1_hash_fn()) } {
        return Err(Unspecified);
    };

    let label = label.unwrap_or(&[0u8; 0]);

    if label.is_empty() {
        // SAFETY: a null label with length 0 hands over no memory.
        if 1 != unsafe { EVP_PKEY_CTX_set0_rsa_oaep_label(*evp_pkey_ctx.as_mut(), null_mut(), 0) } {
            return Err(Unspecified);
        }
        return Ok(());
    }

    // The context takes ownership of the label on success, so the copy has to
    // live in memory that came from the library allocator.
    let mut label_ptr =
        DetachableLcPtr::<u8>::new(unsafe { OPENSSL_malloc(size_of_val(label)) }.cast())?;

    // Copy through raw pointers: the fresh allocation is uninitialised, so no
    // `&mut [u8]` is formed over it.
    //
    // SAFETY: `label` is valid for reads of `label.len()` bytes; the destination
    // was allocated above with `size_of_val(label)` bytes (equal to `label.len()`
    // for `u8`) and is presumed non-null because `new` returned `Ok`; a fresh
    // allocation cannot overlap the caller's slice; `u8` has alignment 1.
    unsafe {
        core::ptr::copy_nonoverlapping(label.as_ptr(), *label_ptr.as_mut(), label.len());
    }

    // SAFETY: `label_ptr` points to `label.len()` initialised bytes.
    if 1 != unsafe {
        EVP_PKEY_CTX_set0_rsa_oaep_label(*evp_pkey_ctx.as_mut(), *label_ptr, label.len())
    } {
        return Err(Unspecified);
    };

    // The context owns the buffer now; detach so it is not freed a second time.
    label_ptr.detach();

    Ok(())
}