use crate::{Algorithm, EcdsaCurve, Error, Result};
use core::fmt;
use encoding::{CheckedSum, Decode, Encode, Reader, Writer};
use sec1::consts::{U32, U48, U66};
/// SEC1-encoded elliptic curve point for a NIST P-256 public key
/// (`EncodedPoint` parameterized by `U32`, the 32-byte field-element size).
pub type EcdsaNistP256PublicKey = sec1::EncodedPoint<U32>;
/// SEC1-encoded elliptic curve point for a NIST P-384 public key
/// (`EncodedPoint` parameterized by `U48`, the 48-byte field-element size).
pub type EcdsaNistP384PublicKey = sec1::EncodedPoint<U48>;
/// SEC1-encoded elliptic curve point for a NIST P-521 public key
/// (`EncodedPoint` parameterized by `U66`, the 66-byte field-element size).
pub type EcdsaNistP521PublicKey = sec1::EncodedPoint<U66>;
/// ECDSA public key for one of the supported NIST curves, stored as a
/// SEC1-encoded elliptic curve point.
///
/// The variant determines the curve reported by [`EcdsaPublicKey::curve`]
/// and the algorithm reported by [`EcdsaPublicKey::algorithm`].
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum EcdsaPublicKey {
    /// Point on the NIST P-256 curve.
    NistP256(EcdsaNistP256PublicKey),
    /// Point on the NIST P-384 curve.
    NistP384(EcdsaNistP384PublicKey),
    /// Point on the NIST P-521 curve.
    NistP521(EcdsaNistP521PublicKey),
}
impl EcdsaPublicKey {
const MAX_SIZE: usize = 133;
pub fn from_sec1_bytes(bytes: &[u8]) -> Result<Self> {
match bytes {
[tag, rest @ ..] => {
let point_size = match sec1::point::Tag::from_u8(*tag)? {
sec1::point::Tag::CompressedEvenY | sec1::point::Tag::CompressedOddY => {
rest.len()
}
sec1::point::Tag::Uncompressed => rest.len() / 2,
_ => return Err(Error::AlgorithmUnknown),
};
match point_size {
32 => Ok(Self::NistP256(EcdsaNistP256PublicKey::from_bytes(bytes)?)),
48 => Ok(Self::NistP384(EcdsaNistP384PublicKey::from_bytes(bytes)?)),
66 => Ok(Self::NistP521(EcdsaNistP521PublicKey::from_bytes(bytes)?)),
_ => Err(encoding::Error::Length.into()),
}
}
_ => Err(encoding::Error::Length.into()),
}
}
pub fn as_sec1_bytes(&self) -> &[u8] {
match self {
EcdsaPublicKey::NistP256(point) => point.as_bytes(),
EcdsaPublicKey::NistP384(point) => point.as_bytes(),
EcdsaPublicKey::NistP521(point) => point.as_bytes(),
}
}
pub fn algorithm(&self) -> Algorithm {
Algorithm::Ecdsa {
curve: self.curve(),
}
}
pub fn curve(&self) -> EcdsaCurve {
match self {
EcdsaPublicKey::NistP256(_) => EcdsaCurve::NistP256,
EcdsaPublicKey::NistP384(_) => EcdsaCurve::NistP384,
EcdsaPublicKey::NistP521(_) => EcdsaCurve::NistP521,
}
}
}
impl AsRef<[u8]> for EcdsaPublicKey {
    /// View the key as its SEC1-encoded point bytes.
    fn as_ref(&self) -> &[u8] {
        Self::as_sec1_bytes(self)
    }
}
impl Decode for EcdsaPublicKey {
type Error = Error;
fn decode(reader: &mut impl Reader) -> Result<Self> {
let curve = EcdsaCurve::decode(reader)?;
let mut buf = [0u8; Self::MAX_SIZE];
let key = Self::from_sec1_bytes(reader.read_byten(&mut buf)?)?;
if key.curve() == curve {
Ok(key)
} else {
Err(Error::AlgorithmUnknown)
}
}
}
impl Encode for EcdsaPublicKey {
    /// Total encoded size: the curve identifier, a 4-byte length field, and
    /// the SEC1 point bytes.
    fn encoded_len(&self) -> encoding::Result<usize> {
        let curve_len = self.curve().encoded_len()?;
        let length_field = 4;
        let point_len = self.as_sec1_bytes().len();
        [curve_len, length_field, point_len].checked_sum()
    }

    /// Write the curve identifier followed by the SEC1 point bytes.
    fn encode(&self, writer: &mut impl Writer) -> encoding::Result<()> {
        self.curve().encode(writer)?;
        self.as_sec1_bytes().encode(writer)?;
        Ok(())
    }
}
impl fmt::Display for EcdsaPublicKey {
    /// Display the key using this type's `UpperHex` formatting.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:X}", self)
    }
}
impl fmt::LowerHex for EcdsaPublicKey {
    /// Write the SEC1 point bytes as lowercase hexadecimal, two digits per byte.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.as_sec1_bytes()
            .iter()
            .try_for_each(|octet| write!(f, "{:02x}", octet))
    }
}
impl fmt::UpperHex for EcdsaPublicKey {
    /// Write the SEC1 point bytes as uppercase hexadecimal, two digits per byte.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.as_sec1_bytes()
            .iter()
            .try_for_each(|octet| write!(f, "{:02X}", octet))
    }
}
// Generates conversions between `EcdsaPublicKey` and a curve crate's
// `ecdsa::VerifyingKey`. Every generated impl is gated on the given Cargo feature.
//
// - `$curve_crate`: root of the curve crate (e.g. `p256`).
// - `$cargo_feature`: name of the feature that enables that crate.
// - `$variant`: the `EcdsaPublicKey` variant that holds points for that curve.
macro_rules! impl_ecdsa_for_curve {
    ($curve_crate:ident, $cargo_feature:expr, $variant:ident) => {
        #[cfg(feature = $cargo_feature)]
        impl TryFrom<&EcdsaPublicKey> for $curve_crate::ecdsa::VerifyingKey {
            type Error = Error;

            /// Build a verifying key from the matching variant.
            ///
            /// Returns `Error::AlgorithmUnknown` for keys on another curve, and
            /// `Error::Crypto` if the point is rejected by the curve crate.
            fn try_from(public_key: &EcdsaPublicKey) -> Result<Self> {
                if let EcdsaPublicKey::$variant(point) = public_key {
                    Self::from_encoded_point(point).map_err(|_| Error::Crypto)
                } else {
                    Err(Error::AlgorithmUnknown)
                }
            }
        }

        #[cfg(feature = $cargo_feature)]
        impl TryFrom<EcdsaPublicKey> for $curve_crate::ecdsa::VerifyingKey {
            type Error = Error;

            /// By-value conversion; delegates to the by-reference impl.
            fn try_from(public_key: EcdsaPublicKey) -> Result<Self> {
                Self::try_from(&public_key)
            }
        }

        #[cfg(feature = $cargo_feature)]
        impl From<&$curve_crate::ecdsa::VerifyingKey> for EcdsaPublicKey {
            fn from(verifying_key: &$curve_crate::ecdsa::VerifyingKey) -> Self {
                // `false` selects the uncompressed SEC1 encoding.
                let point = verifying_key.to_encoded_point(false);
                Self::$variant(point)
            }
        }

        #[cfg(feature = $cargo_feature)]
        impl From<$curve_crate::ecdsa::VerifyingKey> for EcdsaPublicKey {
            /// By-value conversion; delegates to the by-reference impl.
            fn from(verifying_key: $curve_crate::ecdsa::VerifyingKey) -> Self {
                Self::from(&verifying_key)
            }
        }
    };
}

impl_ecdsa_for_curve!(p256, "p256", NistP256);
impl_ecdsa_for_curve!(p384, "p384", NistP384);
impl_ecdsa_for_curve!(p521, "p521", NistP521);