use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
use serde_bytes::ByteBuf;
use serde_cbor::Error as CborError;
use serde_cbor::Value as CborValue;
use crate::crypto::{Decryption, Encryption, Entropy};
use crate::error::CoseError;
use crate::header_map::{map_to_empty_or_serialized, HeaderMap};
/// COSE header label for the initialization vector (label 5, `IV`).
const IV: i8 = 5;
/// COSE header label under which the algorithm identifier is stored (label 1).
///
/// NOTE(review): RFC 8152 assigns label 1 to `alg`; `kty` is a COSE_Key parameter,
/// so this name is misleading. Kept as-is because other code in this file uses it.
const KTY: i8 = 1;
/// Symmetric cipher families that can be used to build a [`CoseEncrypt0`].
///
/// The concrete COSE algorithm is chosen from the cipher family together with the
/// length of the key supplied at encryption time (see `cose_alg`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CipherConfiguration {
    /// AES in Galois/Counter Mode with a 96-bit IV and a 128-bit authentication tag.
    Gcm,
}

impl CipherConfiguration {
    /// Maps this cipher family and a raw `key` to the matching COSE algorithm.
    ///
    /// Returns `None` when the key length (16, 24 or 32 bytes for AES-GCM) is not
    /// supported.
    fn cose_alg(&self, key: &[u8]) -> Option<CoseAlgorithm> {
        match self {
            CipherConfiguration::Gcm => match key.len() {
                16 => Some(CoseAlgorithm::AesGcm96_128_128),
                24 => Some(CoseAlgorithm::AesGcm96_128_192),
                32 => Some(CoseAlgorithm::AesGcm96_128_256),
                _ => None,
            },
        }
    }
}
/// COSE content-encryption algorithms supported by this module.
///
/// Variant names spell out IV bits, tag bits and key bits: `AesGcm96_128_256` is
/// AES-GCM with a 96-bit IV, a 128-bit tag and a 256-bit key.
pub(crate) enum CoseAlgorithm {
    AesGcm96_128_128,
    AesGcm96_128_192,
    AesGcm96_128_256,
}

impl CoseAlgorithm {
    /// Numeric COSE algorithm identifier written into the protected header.
    fn value(&self) -> usize {
        match *self {
            Self::AesGcm96_128_128 => 1,
            Self::AesGcm96_128_192 => 2,
            Self::AesGcm96_128_256 => 3,
        }
    }

    /// Inverse of `value`: looks up the algorithm for a numeric identifier,
    /// yielding `None` for identifiers this module does not support.
    fn from_value(value: i8) -> Option<CoseAlgorithm> {
        match value {
            1 => Some(Self::AesGcm96_128_128),
            2 => Some(Self::AesGcm96_128_192),
            3 => Some(Self::AesGcm96_128_256),
            _ => None,
        }
    }

    /// Length in bytes of the authentication tag produced by the AEAD cipher.
    fn tag_size(&self) -> usize {
        match self {
            Self::AesGcm96_128_128 | Self::AesGcm96_128_192 | Self::AesGcm96_128_256 => 16,
        }
    }

    /// Length in bytes of the IV the algorithm expects, if it has a fixed one.
    fn iv_len(&self) -> Option<usize> {
        let bytes = match self {
            Self::AesGcm96_128_128 | Self::AesGcm96_128_192 | Self::AesGcm96_128_256 => 12,
        };
        Some(bytes)
    }
}
/// The `Enc_structure` of RFC 8152 §5.3. It is serialized and handed to the AEAD
/// cipher alongside the payload; the RFC uses it as the additional authenticated data.
#[derive(Debug, Clone, Deserialize)]
struct EncStructure {
    /// Context string naming the message type; `new_encrypt0` sets `"Encrypt0"`.
    context: String,
    /// Serialized protected header bytes.
    protected: ByteBuf,
    /// Externally supplied AAD; `new_encrypt0` leaves it empty.
    external_aad: ByteBuf,
}

impl Serialize for EncStructure {
    /// Encodes the structure as the array `[context, protected, external_aad]`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let EncStructure {
            context,
            protected,
            external_aad,
        } = self;
        let mut array = serializer.serialize_seq(Some(3))?;
        array.serialize_element(context)?;
        array.serialize_element(protected)?;
        array.serialize_element(external_aad)?;
        array.end()
    }
}

impl EncStructure {
    /// Builds the Enc_structure for a COSE_Encrypt0 message over the already
    /// serialized `protected` header, with no external AAD.
    ///
    /// This never fails; the `Result` return type is kept for existing callers.
    fn new_encrypt0(protected: &[u8]) -> Result<Self, CborError> {
        let context = "Encrypt0".to_owned();
        let protected = ByteBuf::from(Vec::from(protected));
        Ok(Self {
            context,
            protected,
            external_aad: ByteBuf::new(),
        })
    }

    /// CBOR-encodes the structure into the byte string the AEAD cipher consumes.
    fn as_bytes(&self) -> Result<Vec<u8>, CborError> {
        serde_cbor::to_vec(self)
    }
}
/// A COSE_Encrypt0 message (RFC 8152 §5.2): encrypted content without recipient
/// structures.
///
/// On the wire this is a three-element CBOR array: the protected header (as a byte
/// string), the unprotected header map, and the ciphertext with the tag appended.
#[derive(Debug, Clone, Deserialize)]
pub struct CoseEncrypt0 {
    /// CBOR-encoded protected header map, kept as raw bytes.
    protected: ByteBuf,
    /// Header parameters not covered by the Enc_structure (the IV lives here).
    unprotected: HeaderMap,
    /// Ciphertext immediately followed by the authentication tag.
    ciphertext: ByteBuf,
}

impl Serialize for CoseEncrypt0 {
    /// Emits the `[protected, unprotected, ciphertext]` array layout.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let CoseEncrypt0 {
            protected,
            unprotected,
            ciphertext,
        } = self;
        let mut array = serializer.serialize_seq(Some(3))?;
        array.serialize_element(protected)?;
        array.serialize_element(unprotected)?;
        array.serialize_element(ciphertext)?;
        array.end()
    }
}
impl CoseEncrypt0 {
    /// Encrypts `payload` with `key` and wraps the result in a COSE_Encrypt0 message.
    ///
    /// A fresh random IV is drawn from `C` and stored in the unprotected header. The
    /// algorithm identifier is stored in the protected header. The serialized protected
    /// header is bound into the AEAD input through the `Encrypt0` Enc_structure. The
    /// authentication tag is appended to the ciphertext.
    ///
    /// # Errors
    ///
    /// * `CoseError::UnsupportedError` if no supported algorithm accepts `key`'s length.
    /// * `CoseError::SerializationError` if the headers or Enc_structure cannot be encoded.
    /// * `CoseError::EncryptionError` if the AEAD encryption fails.
    /// * Any error from `C::rand_bytes`, propagated through `?`.
    pub fn new<C: Encryption + Entropy>(
        payload: &[u8],
        cipher_config: CipherConfiguration,
        key: &[u8],
    ) -> Result<Self, CoseError> {
        let cose_alg = cipher_config.cose_alg(key).ok_or_else(|| {
            CoseError::UnsupportedError("Unsupported encryption algorithm".to_string())
        })?;

        let iv_len = cose_alg
            .iv_len()
            .expect("every supported COSE algorithm has a fixed IV length");
        let mut iv = vec![0; iv_len];
        C::rand_bytes(&mut iv)?;

        let mut protected = HeaderMap::new();
        protected.insert(KTY.into(), CborValue::Integer(cose_alg.value() as i128));
        let protected_bytes =
            map_to_empty_or_serialized(&protected).map_err(CoseError::SerializationError)?;
        let enc_structure =
            EncStructure::new_encrypt0(&protected_bytes).map_err(CoseError::SerializationError)?;

        let mut tag = vec![0; cose_alg.tag_size()];
        let aad = enc_structure
            .as_bytes()
            .map_err(CoseError::SerializationError)?;
        let mut ciphertext =
            C::encrypt_aead(cose_alg.into(), key, Some(&iv[..]), &aad, payload, &mut tag)
                .map_err(|e| CoseError::EncryptionError(Box::new(e)))?;
        // COSE_Encrypt0 carries the authentication tag appended to the ciphertext.
        ciphertext.append(&mut tag);

        // The IV travels in the unprotected header. It is not needed locally any more,
        // so it is moved rather than copied.
        let mut unprotected = HeaderMap::new();
        unprotected.insert(IV.into(), CborValue::Bytes(iv));

        Ok(CoseEncrypt0 {
            protected: ByteBuf::from(protected_bytes),
            unprotected,
            ciphertext: ByteBuf::from(ciphertext),
        })
    }

    /// Decrypts the message with `key`.
    ///
    /// Returns the decoded protected header, a reference to the unprotected header and
    /// the recovered plaintext.
    ///
    /// # Errors
    ///
    /// * `CoseError::SerializationError` if the protected header is not a valid map.
    /// * `CoseError::SpecificationError` if the algorithm or IV header is missing or
    ///   mistyped, if the IV has the wrong length, or if the ciphertext is shorter than
    ///   the authentication tag.
    /// * `CoseError::UnsupportedError` if the algorithm identifier is not supported.
    /// * `CoseError::EncryptionError` if the AEAD decryption fails, e.g. on a tag
    ///   mismatch.
    pub fn decrypt<C: Decryption>(
        &self,
        key: &[u8],
    ) -> Result<(HeaderMap, &HeaderMap, Vec<u8>), CoseError> {
        // `TryFrom` is not in the 2015/2018 prelude, so import it locally.
        use std::convert::TryFrom;

        let protected: HeaderMap =
            HeaderMap::from_bytes(&self.protected).map_err(CoseError::SerializationError)?;
        let protected_enc_alg = match protected.get(&CborValue::Integer(KTY.into())) {
            Some(CborValue::Integer(val)) => *val,
            _ => {
                return Err(CoseError::SpecificationError(
                    "Protected Header contains invalid Encryption Algorithm specification"
                        .to_string(),
                ))
            }
        };
        // Convert with a range check. An `as i8` cast would truncate out-of-range
        // identifiers (e.g. 257 -> 1) and select an algorithm the sender never named.
        let cose_alg = i8::try_from(protected_enc_alg)
            .ok()
            .and_then(CoseAlgorithm::from_value)
            .ok_or_else(|| {
                CoseError::UnsupportedError("Unsupported encryption algorithm".to_string())
            })?;

        // RFC 8152 §5.3: the Enc_structure covers the protected header bytes exactly as
        // received. Re-serializing the decoded map could yield a different (but
        // equivalent) encoding and make authentic messages fail to decrypt.
        let enc_structure =
            EncStructure::new_encrypt0(&self.protected).map_err(CoseError::SerializationError)?;

        let iv = match self.unprotected.get(&CborValue::Integer(IV.into())) {
            Some(CborValue::Bytes(val)) => val,
            _ => {
                return Err(CoseError::SpecificationError(
                    "Unprotected Header contains invalid IV specification".to_string(),
                ))
            }
        };
        if matches!(cose_alg.iv_len(), Some(expected) if expected != iv.len()) {
            return Err(CoseError::SpecificationError(
                "Unprotected Header contains an IV of the wrong length".to_string(),
            ));
        }

        // The tag is appended to the ciphertext. A shorter buffer is malformed input
        // and must be rejected instead of underflowing the split index.
        let tag_start = self
            .ciphertext
            .len()
            .checked_sub(cose_alg.tag_size())
            .ok_or_else(|| {
                CoseError::SpecificationError(
                    "Ciphertext is shorter than the authentication tag".to_string(),
                )
            })?;
        let (ciphertext, tag) = self.ciphertext.split_at(tag_start);

        let aad = enc_structure
            .as_bytes()
            .map_err(CoseError::SerializationError)?;
        let payload = C::decrypt_aead(cose_alg.into(), key, Some(iv), &aad, ciphertext, tag)
            .map_err(|e| CoseError::EncryptionError(Box::new(e)))?;
        Ok((protected, &self.unprotected, payload))
    }

    /// Serializes the message to CBOR.
    ///
    /// When `tagged` is true, the array is wrapped in CBOR tag 16 (COSE_Encrypt0).
    ///
    /// # Errors
    ///
    /// `CoseError::SerializationError` if CBOR encoding fails.
    pub fn as_bytes(&self, tagged: bool) -> Result<Vec<u8>, CoseError> {
        let bytes = if tagged {
            serde_cbor::to_vec(&serde_cbor::tags::Tagged::new(Some(16), self))
        } else {
            serde_cbor::to_vec(self)
        };
        bytes.map_err(CoseError::SerializationError)
    }

    /// Parses a COSE_Encrypt0 message from CBOR `bytes`.
    ///
    /// Both untagged messages and messages carrying tag 16 are accepted.
    ///
    /// # Errors
    ///
    /// * `CoseError::SerializationError` if `bytes` is not a valid message or the
    ///   protected header does not decode as a header map.
    /// * `CoseError::TagError` if the message carries a tag other than 16.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, CoseError> {
        let tagged: serde_cbor::tags::Tagged<Self> =
            serde_cbor::from_slice(bytes).map_err(CoseError::SerializationError)?;
        match tagged.tag {
            None | Some(16) => (),
            Some(tag) => return Err(CoseError::TagError(Some(tag))),
        }
        // Validate the protected header eagerly so malformed messages fail at parse time.
        // NOTE(review): a zero-length protected bstr is rejected here as well.
        let protected = tagged.value.protected.as_slice();
        let _: HeaderMap =
            serde_cbor::from_slice(protected).map_err(CoseError::SerializationError)?;
        Ok(tagged.value)
    }
}
#[cfg(all(test, feature = "openssl"))]
mod tests {
    use super::*;
    use crate::crypto::Openssl;

    /// 128-bit key used by every test that expects encryption to succeed.
    const KEY: &[u8; 16] = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F";
    /// Payload encrypted by every test.
    const PLAINTEXT: &[u8; 16] = b"\x12\x34\x56\x78\x90\x12\x34\x56\x12\x34\x56\x78\x90\x12\x34\x56";

    /// Encrypts `PLAINTEXT` under `KEY` with AES-GCM, panicking on failure.
    fn encrypt_fixture() -> CoseEncrypt0 {
        CoseEncrypt0::new::<Openssl>(PLAINTEXT, CipherConfiguration::Gcm, KEY).unwrap()
    }

    /// Serializes `header` and installs it as the protected header of `msg`.
    fn set_protected(msg: &mut CoseEncrypt0, header: &HeaderMap) {
        let bytes = map_to_empty_or_serialized(header)
            .map_err(CoseError::SerializationError)
            .unwrap();
        msg.protected = ByteBuf::from(bytes);
    }

    #[test]
    fn test_encrypt_decrypt() {
        let original = encrypt_fixture();
        let (_, _, decrypted) = original.decrypt::<Openssl>(KEY).unwrap();
        assert_eq!(decrypted, PLAINTEXT);
        assert_ne!(
            PLAINTEXT.to_vec(),
            serde_cbor::to_vec(&original.ciphertext).unwrap()
        );

        let encoded = original.as_bytes(true).unwrap();
        let decoded = CoseEncrypt0::from_bytes(&encoded).unwrap();
        let (_, _, decrypted) = decoded.decrypt::<Openssl>(KEY).unwrap();
        assert_eq!(decrypted, PLAINTEXT);
        assert_ne!(
            PLAINTEXT.to_vec(),
            serde_cbor::to_vec(&decoded.ciphertext).unwrap()
        );
    }

    #[test]
    fn test_encrypt_unsupported_alg() {
        // 18 bytes matches none of the AES key sizes.
        let key = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F\x56\x56";
        let result = CoseEncrypt0::new::<Openssl>(PLAINTEXT, CipherConfiguration::Gcm, key);
        assert!(matches!(
            result.unwrap_err(),
            CoseError::UnsupportedError(_)
        ));
    }

    #[test]
    fn test_decrypt_invalid_alg_spec() {
        let mut msg = encrypt_fixture();
        let mut header = HeaderMap::new();
        header.insert(KTY.into(), CborValue::Text("invalid".to_string()));
        set_protected(&mut msg, &header);
        assert!(matches!(
            msg.decrypt::<Openssl>(KEY).unwrap_err(),
            CoseError::SpecificationError(_)
        ));
    }

    #[test]
    fn test_decrypt_unsupported_openssl_cipher() {
        let mut msg = encrypt_fixture();
        let mut header = HeaderMap::new();
        header.insert(KTY.into(), CborValue::Integer(42));
        set_protected(&mut msg, &header);
        assert!(matches!(
            msg.decrypt::<Openssl>(KEY).unwrap_err(),
            CoseError::UnsupportedError(_)
        ));
    }

    #[test]
    fn test_decrypt_invalid_iv() {
        let mut msg = encrypt_fixture();
        let mut header = HeaderMap::new();
        header.insert(IV.into(), CborValue::Integer(42));
        msg.unprotected = header;
        assert!(matches!(
            msg.decrypt::<Openssl>(KEY).unwrap_err(),
            CoseError::SpecificationError(_)
        ));
    }
}