use super::error::*;
use asn1_der::typed::{DerDecodable, DerEncodable, DerTypeView, Sequence};
use asn1_der::{Asn1DerError, Asn1DerErrorVariant, DerObject, Sink, VecBacking};
use ring::rand::SystemRandom;
use ring::signature::KeyPair;
use ring::signature::{self, RsaKeyPair, RSA_PKCS1_2048_8192_SHA256, RSA_PKCS1_SHA256};
use std::{fmt, sync::Arc};
use zeroize::Zeroize;
#[derive(Clone)]
pub struct Keypair(Arc<RsaKeyPair>);
impl std::fmt::Debug for Keypair {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("Keypair")
.field("public", self.0.public_key())
.finish()
}
}
impl Keypair {
    /// Decode an RSA keypair from a DER-encoded PKCS#1 `RSAPrivateKey`.
    ///
    /// The input buffer is zeroed before this function returns, whether or not
    /// decoding succeeds. Secret key material therefore does not linger in the
    /// caller's memory, even when the bytes turn out to be malformed.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodingError`] if `ring` rejects `der` as an RSA private key.
    pub fn try_decode_pkcs1(der: &mut [u8]) -> Result<Keypair, DecodingError> {
        let parsed = RsaKeyPair::from_der(der);
        // Wipe the secret bytes *before* the `?` below can return early on failure.
        der.zeroize();
        let kp = parsed
            .map_err(|e| DecodingError::failed_to_parse("RSA DER PKCS#1 RSAPrivateKey", e))?;
        Ok(Keypair(Arc::new(kp)))
    }

    /// Decode an RSA keypair from a DER-encoded PKCS#8 `PrivateKeyInfo`.
    ///
    /// The input buffer is zeroed before this function returns, whether or not
    /// decoding succeeds.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodingError`] if `ring` rejects `der` as an RSA PKCS#8 key.
    pub fn try_decode_pkcs8(der: &mut [u8]) -> Result<Keypair, DecodingError> {
        let parsed = RsaKeyPair::from_pkcs8(der);
        // Wipe the secret bytes *before* the `?` below can return early on failure.
        der.zeroize();
        let kp =
            parsed.map_err(|e| DecodingError::failed_to_parse("RSA PKCS#8 PrivateKeyInfo", e))?;
        Ok(Keypair(Arc::new(kp)))
    }

    /// Returns an owned copy of this keypair's public key bytes, as exposed by
    /// `ring`'s `KeyPair::public_key`.
    pub fn public(&self) -> PublicKey {
        PublicKey(self.0.public_key().as_ref().to_vec())
    }

    /// Sign `data` using `ring`'s `RSA_PKCS1_SHA256` scheme.
    ///
    /// The signature buffer is sized to the key's modulus length. A fresh
    /// `SystemRandom` is created for each call, because `ring`'s RSA signing API
    /// takes an RNG.
    ///
    /// # Errors
    ///
    /// Returns a [`SigningError`] if `ring` reports a signing failure.
    pub fn sign(&self, data: &[u8]) -> Result<Vec<u8>, SigningError> {
        let mut signature = vec![0; self.0.public().modulus_len()];
        let rng = SystemRandom::new();
        match self.0.sign(&RSA_PKCS1_SHA256, &rng, data, &mut signature) {
            Ok(()) => Ok(signature),
            Err(e) => Err(SigningError::new("RSA").source(e)),
        }
    }
}
/// An RSA public key, held as the raw bytes reported by `ring` for the keypair.
/// `encode_pkcs1` returns these bytes unchanged.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PublicKey(Vec<u8>);

impl PublicKey {
    /// Check `sig` against `msg` using `ring`'s `RSA_PKCS1_2048_8192_SHA256`
    /// verification algorithm.
    ///
    /// Every failure (a bad signature, a malformed key, or a key size the
    /// algorithm rejects) is reported as `false`.
    pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool {
        signature::UnparsedPublicKey::new(&RSA_PKCS1_2048_8192_SHA256, &self.0)
            .verify(msg, sig)
            .is_ok()
    }

    /// Returns an owned copy of the stored key bytes.
    pub fn encode_pkcs1(&self) -> Vec<u8> {
        self.0.to_vec()
    }

    /// Encode the key as a DER X.509 `SubjectPublicKeyInfo` with the
    /// `rsaEncryption` algorithm identifier.
    ///
    /// # Panics
    ///
    /// Panics if the DER encoder reports an error. Writing into an in-memory
    /// `Vec` is not expected to fail.
    pub fn encode_x509(&self) -> Vec<u8> {
        let algorithm_identifier = Asn1RsaEncryption {
            algorithm: Asn1OidRsaEncryption,
            parameters: (),
        };
        let spki = Asn1SubjectPublicKeyInfo {
            algorithmIdentifier: algorithm_identifier,
            subjectPublicKey: Asn1SubjectPublicKey(self.clone()),
        };

        let mut der: Vec<u8> = Vec::new();
        spki.encode(&mut der)
            .expect("RSA X.509 public key encoding failed.");
        der
    }

    /// Decode a DER X.509 `SubjectPublicKeyInfo` that carries an
    /// `rsaEncryption` key.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodingError`] if `pk` is not a structurally valid RSA
    /// `SubjectPublicKeyInfo`.
    pub fn try_decode_x509(pk: &[u8]) -> Result<PublicKey, DecodingError> {
        let spki = Asn1SubjectPublicKeyInfo::decode(pk)
            .map_err(|e| DecodingError::failed_to_parse("RSA X.509", e))?;
        Ok(spki.subjectPublicKey.0)
    }
}
impl fmt::Debug for PublicKey {
    /// Formats as `PublicKey(PKCS1): ` followed by the key bytes in lowercase
    /// hexadecimal, with exactly two digits per byte.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PublicKey(PKCS1): ")?;
        // `{:02x}` zero-pads each byte. Plain `{:x}` would print 0x05 as "5",
        // which makes the dump ambiguous: [0x01, 0x23] and [0x12, 0x03] would
        // both render as "123".
        for byte in &self.0 {
            write!(f, "{byte:02x}")?;
        }
        Ok(())
    }
}
/// A borrowed DER `OBJECT IDENTIFIER` whose content octets are left undecoded.
#[derive(Copy, Clone)]
struct Asn1RawOid<'a> {
    object: DerObject<'a>,
}

impl Asn1RawOid<'_> {
    /// The raw OID content octets, compared byte-for-byte by callers.
    pub(crate) fn oid(&self) -> &[u8] {
        self.object.value()
    }

    /// Write `value` to `sink` as a DER object carrying the OID tag.
    pub(crate) fn write<S: Sink>(value: &[u8], sink: &mut S) -> Result<(), Asn1DerError> {
        let mut octets = value.iter();
        DerObject::write(Self::TAG, value.len(), &mut octets, sink)
    }
}

impl<'a> DerTypeView<'a> for Asn1RawOid<'a> {
    /// Universal tag 6: OBJECT IDENTIFIER.
    const TAG: u8 = 0x06;

    fn object(&self) -> DerObject<'a> {
        self.object
    }
}

impl DerEncodable for Asn1RawOid<'_> {
    /// Re-emits the wrapped DER object unchanged.
    fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
        self.object.encode(sink)
    }
}

impl<'a> DerDecodable<'a> for Asn1RawOid<'a> {
    /// Accepts any object whose tag is the OID tag; the content is not validated.
    fn load(object: DerObject<'a>) -> Result<Self, Asn1DerError> {
        if object.tag() == Self::TAG {
            Ok(Self { object })
        } else {
            Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData(
                "DER object tag is not the object identifier tag.",
            )))
        }
    }
}
/// Marker type for the `rsaEncryption` algorithm identifier, OID 1.2.840.113549.1.1.1.
#[derive(Clone)]
struct Asn1OidRsaEncryption;

impl Asn1OidRsaEncryption {
    /// DER content octets of OID 1.2.840.113549.1.1.1.
    const OID: [u8; 9] = [0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01];
}

impl DerEncodable for Asn1OidRsaEncryption {
    fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
        let content: &[u8] = &Self::OID;
        Asn1RawOid::write(content, sink)
    }
}

impl DerDecodable<'_> for Asn1OidRsaEncryption {
    /// Accepts only an OID whose content octets equal [`Self::OID`].
    fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
        let raw = Asn1RawOid::load(object)?;
        if raw.oid() != Self::OID.as_slice() {
            return Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData(
                "DER object is not the 'rsaEncryption' identifier.",
            )));
        }
        Ok(Self)
    }
}
/// The `AlgorithmIdentifier` SEQUENCE for RSA: the `rsaEncryption` OID
/// followed by unit `()` parameters, encoded by `asn1_der` as NULL.
struct Asn1RsaEncryption {
    algorithm: Asn1OidRsaEncryption,
    parameters: (),
}

impl DerEncodable for Asn1RsaEncryption {
    /// Encodes each element into its own scratch buffer, then writes the SEQUENCE.
    fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
        let (mut oid_scratch, mut params_scratch) = (Vec::new(), Vec::new());
        let oid_object = self.algorithm.der_object(VecBacking(&mut oid_scratch))?;
        let params_object = self.parameters.der_object(VecBacking(&mut params_scratch))?;
        Sequence::write(&[oid_object, params_object], sink)
    }
}

impl DerDecodable<'_> for Asn1RsaEncryption {
    /// Reads element 0 as the `rsaEncryption` OID and element 1 as the parameters.
    fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
        let seq: Sequence = Sequence::load(object)?;
        let algorithm: Asn1OidRsaEncryption = seq.get_as(0)?;
        let parameters: () = seq.get_as(1)?;
        Ok(Self {
            algorithm,
            parameters,
        })
    }
}
/// The `subjectPublicKey` BIT STRING of an X.509 `SubjectPublicKeyInfo`.
/// It wraps the key's stored DER bytes.
struct Asn1SubjectPublicKey(PublicKey);

impl DerEncodable for Asn1SubjectPublicKey {
    /// Encodes the key bytes as a BIT STRING (universal tag 3), prefixed with a
    /// `0` unused-bits octet.
    fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
        let pk_der = &(self.0).0;
        let mut bit_string = Vec::with_capacity(pk_der.len() + 1);
        // The key is a whole number of octets, so no bits of the final octet are unused.
        bit_string.push(0u8);
        bit_string.extend_from_slice(pk_der);
        DerObject::write(3, bit_string.len(), &mut bit_string.iter(), sink)
    }
}

impl DerDecodable<'_> for Asn1SubjectPublicKey {
    /// Decodes a BIT STRING whose leading unused-bits octet is `0`.
    ///
    /// Fails if the tag is not BIT STRING (3). It also fails if the content is
    /// empty, which is invalid DER for a BIT STRING, or if it declares non-zero
    /// unused bits, which an octet-aligned key encoding never has.
    fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
        if object.tag() != 3 {
            return Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData(
                "DER object tag is not the bit string tag.",
            )));
        }
        // The first content octet of a BIT STRING counts the unused bits in the last octet.
        match object.value().split_first() {
            Some((&0, key)) => Ok(Self(PublicKey(key.to_vec()))),
            _ => Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData(
                "DER bit string is empty or has a non-zero unused-bits octet.",
            ))),
        }
    }
}
/// An X.509 `SubjectPublicKeyInfo` SEQUENCE for an RSA key: an algorithm
/// identifier followed by the public key BIT STRING.
// Field names use ASN.1-style camelCase rather than Rust snake_case, hence the lint exemption.
#[allow(non_snake_case)]
struct Asn1SubjectPublicKeyInfo {
    algorithmIdentifier: Asn1RsaEncryption,
    subjectPublicKey: Asn1SubjectPublicKey,
}

impl DerEncodable for Asn1SubjectPublicKeyInfo {
    /// Encodes both elements into scratch buffers, then writes them as one SEQUENCE.
    fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
        let (mut algo_scratch, mut key_scratch) = (Vec::new(), Vec::new());
        let algo_object = self
            .algorithmIdentifier
            .der_object(VecBacking(&mut algo_scratch))?;
        let key_object = self
            .subjectPublicKey
            .der_object(VecBacking(&mut key_scratch))?;
        Sequence::write(&[algo_object, key_object], sink)
    }
}

impl DerDecodable<'_> for Asn1SubjectPublicKeyInfo {
    /// Reads element 0 as the algorithm identifier and element 1 as the key.
    fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
        let seq: Sequence = Sequence::load(object)?;
        let algorithm_identifier: Asn1RsaEncryption = seq.get_as(0)?;
        let subject_public_key: Asn1SubjectPublicKey = seq.get_as(1)?;
        Ok(Self {
            algorithmIdentifier: algorithm_identifier,
            subjectPublicKey: subject_public_key,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    const KEY1: &[u8] = include_bytes!("test/rsa-2048.pk8");
    const KEY2: &[u8] = include_bytes!("test/rsa-3072.pk8");
    const KEY3: &[u8] = include_bytes!("test/rsa-4096.pk8");

    /// A keypair picked at random from the PKCS#8 fixtures, for property tests.
    #[derive(Clone, Debug)]
    struct SomeKeypair(Keypair);

    impl Arbitrary for SomeKeypair {
        fn arbitrary(g: &mut Gen) -> SomeKeypair {
            let fixture = g.choose(&[KEY1, KEY2, KEY3]).unwrap();
            let mut der = fixture.to_vec();
            let keypair = Keypair::try_decode_pkcs8(&mut der).unwrap();
            SomeKeypair(keypair)
        }
    }

    #[test]
    fn rsa_from_pkcs8() {
        for fixture in [KEY1, KEY2, KEY3] {
            let mut der = fixture.to_vec();
            assert!(Keypair::try_decode_pkcs8(&mut der).is_ok());
        }
    }

    #[test]
    fn rsa_x509_encode_decode() {
        fn prop(SomeKeypair(kp): SomeKeypair) -> Result<bool, String> {
            let original = kp.public();
            let decoded = PublicKey::try_decode_x509(&original.encode_x509())
                .map_err(|e| e.to_string())?;
            Ok(decoded == original)
        }
        QuickCheck::new().tests(10).quickcheck(prop as fn(_) -> _);
    }

    #[test]
    fn rsa_sign_verify() {
        fn prop(SomeKeypair(kp): SomeKeypair, msg: Vec<u8>) -> Result<bool, SigningError> {
            let sig = kp.sign(&msg)?;
            Ok(kp.public().verify(&msg, &sig))
        }
        QuickCheck::new()
            .tests(10)
            .quickcheck(prop as fn(_, _) -> _);
    }
}