use base64ct::{Base64UrlUnpadded, Encoding};
use serde::{de::DeserializeOwned, Serialize};
use core::{marker::PhantomData, num::NonZeroUsize};
#[cfg(feature = "ciborium")]
use crate::error::CborSerError;
use crate::{
alloc::{Cow, String, ToOwned, Vec},
token::CompleteHeader,
Claims, CreationError, Header, SignedToken, Token, UntrustedToken, ValidationError,
};
/// Signature produced by an [`Algorithm`], convertible to and from raw bytes.
///
/// The byte form is what gets base64url-encoded into the last segment of a token
/// and what is parsed back when a token is validated.
pub trait AlgorithmSignature: Sized {
/// Fixed signature length in bytes, if the signature type has one.
///
/// When this is `Some`, `Validator::validate_for_signed_token` rejects a token whose
/// signature has a different length before `try_from_slice` is ever called.
/// Defaults to `None`, which disables that length check.
const LENGTH: Option<NonZeroUsize> = None;
/// Attempts to restore a signature from its byte representation.
///
/// # Errors
///
/// Returns an error if the implementation considers `slice` malformed. During
/// validation such an error is reported as `ValidationError::MalformedSignature`.
fn try_from_slice(slice: &[u8]) -> anyhow::Result<Self>;
/// Returns the byte representation of this signature, borrowed or owned.
fn as_bytes(&self) -> Cow<'_, [u8]>;
}
/// Signing / verification algorithm used to create and validate tokens.
///
/// Implementors automatically get the higher-level token operations from
/// [`AlgorithmExt`] via its blanket implementation.
pub trait Algorithm {
/// Key type used to produce signatures.
type SigningKey;
/// Key type used to check signatures.
type VerifyingKey;
/// Signature type produced by [`Self::sign`] and consumed by [`Self::verify_signature`].
type Signature: AlgorithmSignature;
/// Returns the name of this algorithm.
///
/// The name is written into the header of created tokens and compared against
/// the algorithm recorded in a token when that token is validated.
fn name(&self) -> Cow<'static, str>;
/// Signs `message` with `signing_key` and returns the resulting signature.
fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature;
/// Checks `signature` over `message` against `verifying_key`.
///
/// Returns `true` if the signature is valid and `false` otherwise.
fn verify_signature(
&self,
signature: &Self::Signature,
verifying_key: &Self::VerifyingKey,
message: &[u8],
) -> bool;
}
/// Wrapper around an [`Algorithm`] that reports a different name.
///
/// Signing and verification are delegated to the wrapped algorithm unchanged;
/// only the value returned from `Algorithm::name` differs.
#[derive(Debug, Clone, Copy)]
pub struct Renamed<A> {
// The wrapped algorithm that performs the actual signing / verification.
inner: A,
// Name reported instead of the wrapped algorithm's own name.
name: &'static str,
}
impl<A: Algorithm> Renamed<A> {
    /// Wraps `algorithm` so that it reports `new_name` as its name.
    pub fn new(algorithm: A, new_name: &'static str) -> Self {
        let inner = algorithm;
        let name = new_name;
        Self { inner, name }
    }
}
impl<A: Algorithm> Algorithm for Renamed<A> {
    // Key and signature types are taken verbatim from the wrapped algorithm.
    type SigningKey = A::SigningKey;
    type VerifyingKey = A::VerifyingKey;
    type Signature = A::Signature;

    // Reports the replacement name rather than the wrapped algorithm's name.
    fn name(&self) -> Cow<'static, str> {
        Cow::from(self.name)
    }

    // Signing is delegated to the wrapped algorithm.
    fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
        A::sign(&self.inner, signing_key, message)
    }

    // Verification is delegated to the wrapped algorithm.
    fn verify_signature(
        &self,
        signature: &Self::Signature,
        verifying_key: &Self::VerifyingKey,
        message: &[u8],
    ) -> bool {
        let Self { inner, .. } = self;
        inner.verify_signature(signature, verifying_key, message)
    }
}
/// Token creation and validation helpers available on every [`Algorithm`].
///
/// The trait is implemented for all `Algorithm` types through a blanket
/// implementation, so it is not meant to be implemented manually.
pub trait AlgorithmExt: Algorithm {
/// Creates a token with JSON-serialized claims and returns it as a string.
///
/// The result consists of three base64url-encoded (unpadded) segments joined
/// by `.`: the header, the claims and the signature over the first two segments.
///
/// # Errors
///
/// Returns `CreationError::Header` if the header cannot be serialized to JSON and
/// `CreationError::Claims` if the claims cannot be serialized to JSON.
fn token<T>(
&self,
header: &Header<impl Serialize>,
claims: &Claims<T>,
signing_key: &Self::SigningKey,
) -> Result<String, CreationError>
where
T: Serialize;
/// Creates a token whose claims are serialized as CBOR instead of JSON.
///
/// The header is still serialized as JSON and has its content type set to
/// `"CBOR"`. The segment layout is the same as for [`Self::token`].
///
/// # Errors
///
/// Returns `CreationError::Header` if the header cannot be serialized to JSON and
/// `CreationError::CborClaims` if the claims cannot be serialized to CBOR.
#[cfg(feature = "ciborium")]
#[cfg_attr(docsrs, doc(cfg(feature = "ciborium")))]
fn compact_token<T>(
&self,
header: &Header<impl Serialize>,
claims: &Claims<T>,
signing_key: &Self::SigningKey,
) -> Result<String, CreationError>
where
T: Serialize;
/// Creates a [`Validator`] that checks tokens with this algorithm and `verifying_key`.
///
/// `T` is the type the token claims are deserialized into.
fn validator<'a, T>(&'a self, verifying_key: &'a Self::VerifyingKey) -> Validator<'a, Self, T>;
/// Validates `token` and returns it with deserialized claims.
///
/// Equivalent to `self.validator(verifying_key).validate(token)`.
///
/// # Errors
///
/// Returns a `ValidationError` under the same conditions as `Validator::validate`.
#[deprecated = "Use `.validator().validate()` for added flexibility"]
fn validate_integrity<T>(
&self,
token: &UntrustedToken<'_>,
verifying_key: &Self::VerifyingKey,
) -> Result<Token<T>, ValidationError>
where
T: DeserializeOwned;
/// Validates `token` and returns it together with its parsed signature.
///
/// Equivalent to `self.validator(verifying_key).validate_for_signed_token(token)`.
///
/// # Errors
///
/// Returns a `ValidationError` under the same conditions as
/// `Validator::validate_for_signed_token`.
#[deprecated = "Use `.validator().validate_for_signed_token()` for added flexibility"]
fn validate_for_signed_token<T>(
&self,
token: &UntrustedToken<'_>,
verifying_key: &Self::VerifyingKey,
) -> Result<SignedToken<Self, T>, ValidationError>
where
T: DeserializeOwned;
}
impl<A: Algorithm> AlgorithmExt for A {
    // Builds `base64url(header JSON) . base64url(claims JSON) . base64url(signature)`,
    // where the signature covers everything before the second dot.
    fn token<T>(
        &self,
        header: &Header<impl Serialize>,
        claims: &Claims<T>,
        signing_key: &Self::SigningKey,
    ) -> Result<String, CreationError>
    where
        T: Serialize,
    {
        let header_json = serde_json::to_string(&CompleteHeader {
            algorithm: self.name(),
            content_type: None,
            inner: header,
        })
        .map_err(CreationError::Header)?;
        let mut token_bytes = Vec::new();
        encode_base64_buf(&header_json, &mut token_bytes);

        let claims_json = serde_json::to_string(claims).map_err(CreationError::Claims)?;
        token_bytes.push(b'.');
        encode_base64_buf(&claims_json, &mut token_bytes);

        // The signed message is the encoded header and claims joined by a dot.
        let signature = self.sign(signing_key, &token_bytes);
        token_bytes.push(b'.');
        encode_base64_buf(signature.as_bytes(), &mut token_bytes);

        // SAFETY: `token_bytes` is only ever filled with base64url output produced by
        // `encode_base64_buf` and with `.` separators, all of which are ASCII.
        let token = unsafe { String::from_utf8_unchecked(token_bytes) };
        Ok(token)
    }

    // Same layout as `token`, but the claims segment carries CBOR rather than JSON
    // and the header advertises this via the `"CBOR"` content type.
    #[cfg(feature = "ciborium")]
    fn compact_token<T>(
        &self,
        header: &Header<impl Serialize>,
        claims: &Claims<T>,
        signing_key: &Self::SigningKey,
    ) -> Result<String, CreationError>
    where
        T: Serialize,
    {
        let header_json = serde_json::to_string(&CompleteHeader {
            algorithm: self.name(),
            content_type: Some("CBOR".to_owned()),
            inner: header,
        })
        .map_err(CreationError::Header)?;
        let mut token_bytes = Vec::new();
        encode_base64_buf(&header_json, &mut token_bytes);

        let mut cbor_claims = Vec::new();
        if let Err(err) = ciborium::into_writer(claims, &mut cbor_claims) {
            return Err(CreationError::CborClaims(match err {
                CborSerError::Value(message) => CborSerError::Value(message),
                // The writer is an in-memory `Vec`, so an I/O failure is treated
                // as impossible here.
                CborSerError::Io(_) => unreachable!(),
            }));
        }
        token_bytes.push(b'.');
        encode_base64_buf(&cbor_claims, &mut token_bytes);

        // The signed message is the encoded header and claims joined by a dot.
        let signature = self.sign(signing_key, &token_bytes);
        token_bytes.push(b'.');
        encode_base64_buf(signature.as_bytes(), &mut token_bytes);

        // SAFETY: `token_bytes` is only ever filled with base64url output produced by
        // `encode_base64_buf` and with `.` separators, all of which are ASCII.
        let token = unsafe { String::from_utf8_unchecked(token_bytes) };
        Ok(token)
    }

    // Bundles this algorithm with a verifying key; `T` is the claims type.
    fn validator<'a, T>(&'a self, verifying_key: &'a Self::VerifyingKey) -> Validator<'a, Self, T> {
        let algorithm = self;
        Validator {
            algorithm,
            verifying_key,
            _claims: PhantomData,
        }
    }

    // Deprecated shortcut for `validator(..).validate(..)`.
    fn validate_integrity<T>(
        &self,
        token: &UntrustedToken<'_>,
        verifying_key: &Self::VerifyingKey,
    ) -> Result<Token<T>, ValidationError>
    where
        T: DeserializeOwned,
    {
        let validator = self.validator::<T>(verifying_key);
        validator.validate(token)
    }

    // Deprecated shortcut for `validator(..).validate_for_signed_token(..)`.
    fn validate_for_signed_token<T>(
        &self,
        token: &UntrustedToken<'_>,
        verifying_key: &Self::VerifyingKey,
    ) -> Result<SignedToken<Self, T>, ValidationError>
    where
        T: DeserializeOwned,
    {
        let validator = self.validator::<T>(verifying_key);
        validator.validate_for_signed_token(token)
    }
}
/// Validator for tokens, pairing an [`Algorithm`] with one of its verifying keys.
///
/// `T` is the type that token claims are deserialized into. Instances are created
/// with `AlgorithmExt::validator`.
#[derive(Debug)]
pub struct Validator<'a, A: Algorithm + ?Sized, T> {
// Algorithm used to check the token's algorithm name and its signature.
algorithm: &'a A,
// Key passed to `Algorithm::verify_signature`.
verifying_key: &'a A::VerifyingKey,
// Zero-sized marker for the claims type; no `T` value is stored.
_claims: PhantomData<fn() -> T>,
}
// Implemented by hand so that no `Clone` bound is placed on `A` or `T`:
// the validator only holds references and a marker.
impl<A: Algorithm + ?Sized, T> Clone for Validator<'_, A, T> {
// Cloning is a plain copy, relying on the `Copy` implementation of `Validator`.
fn clone(&self) -> Self {
*self
}
}
// Validators are copyable regardless of `A` and `T`; no `Copy` bound is placed on either.
impl<A: Algorithm + ?Sized, T> Copy for Validator<'_, A, T> {}
impl<A: Algorithm + ?Sized, T: DeserializeOwned> Validator<'_, A, T> {
    /// Validates `token` and returns it with claims deserialized into `T`.
    ///
    /// This performs the same checks as [`Self::validate_for_signed_token`] and
    /// then discards the parsed signature.
    ///
    /// # Errors
    ///
    /// Returns a `ValidationError` under the same conditions as
    /// [`Self::validate_for_signed_token`].
    pub fn validate<H: Clone>(
        self,
        token: &UntrustedToken<'_, H>,
    ) -> Result<Token<T, H>, ValidationError> {
        let signed = self.validate_for_signed_token(token)?;
        Ok(signed.token)
    }

    /// Validates `token` and returns it together with its parsed signature.
    ///
    /// # Errors
    ///
    /// - `ValidationError::AlgorithmMismatch` if the token's algorithm differs from
    ///   the name of the validator's algorithm.
    /// - `ValidationError::InvalidSignatureLen` if the signature type declares a fixed
    ///   length and the token's signature has a different one.
    /// - `ValidationError::MalformedSignature` if the signature bytes cannot be parsed.
    /// - Any error produced while deserializing the claims.
    /// - `ValidationError::InvalidSignature` if the signature does not verify.
    pub fn validate_for_signed_token<H: Clone>(
        self,
        token: &UntrustedToken<'_, H>,
    ) -> Result<SignedToken<A, T, H>, ValidationError> {
        // The token must have been created for the very algorithm we verify with.
        let expected_alg = self.algorithm.name();
        let actual_alg = token.algorithm();
        if expected_alg != actual_alg {
            return Err(ValidationError::AlgorithmMismatch {
                expected: expected_alg.into_owned(),
                actual: actual_alg.to_owned(),
            });
        }

        // Reject signatures of the wrong size before attempting to parse them.
        let raw_signature = token.signature_bytes();
        match A::Signature::LENGTH {
            Some(expected_len) if raw_signature.len() != expected_len.get() => {
                return Err(ValidationError::InvalidSignatureLen {
                    expected: expected_len.get(),
                    actual: raw_signature.len(),
                });
            }
            _ => {}
        }
        let signature = A::Signature::try_from_slice(raw_signature)
            .map_err(ValidationError::MalformedSignature)?;

        // Claims are deserialized before the signature check, so malformed claims
        // are reported without running signature verification.
        let claims = token.deserialize_claims_unchecked::<T>()?;

        let is_valid =
            self.algorithm
                .verify_signature(&signature, self.verifying_key, &token.signed_data);
        if !is_valid {
            return Err(ValidationError::InvalidSignature);
        }

        let token = Token::new(token.header().clone(), claims);
        Ok(SignedToken { signature, token })
    }
}
/// Appends the unpadded base64url encoding of `source` to the end of `buffer`.
///
/// Existing contents of `buffer` are left untouched.
///
/// # Panics
///
/// Panics if the encoder rejects the destination slice, which would mean the
/// encoded length was computed incorrectly.
fn encode_base64_buf(source: impl AsRef<[u8]>, buffer: &mut Vec<u8>) {
    let bytes = source.as_ref();
    let start = buffer.len();
    let encoded_len = Base64UrlUnpadded::encoded_len(bytes);

    // Grow the buffer by exactly the encoded size, then encode into the new tail.
    buffer.resize(start + encoded_len, 0);
    let destination = &mut buffer[start..];
    Base64UrlUnpadded::encode(bytes, destination)
        .expect("miscalculated base64-encoded length; this should never happen");
}