use std::io::Error as IoError;
use std::marker::PhantomData;
use crate::{Decoder, Encoder};
use bytes::{Buf, BufMut, BytesMut};
use serde::{Deserialize, Serialize};
use serde_cbor::Error as CborError;
/// A codec that serializes `Enc` values to CBOR and deserializes CBOR
/// into `Dec` values.
///
/// The struct holds no runtime data: both fields are `PhantomData` markers
/// that only pin the item types used by the `Encoder` and `Decoder`
/// implementations.
#[derive(Debug, PartialEq)]
pub struct CborCodec<Enc, Dec> {
    /// Marker for the type accepted by `encode`.
    enc: PhantomData<Enc>,
    /// Marker for the type produced by `decode`.
    dec: PhantomData<Dec>,
}
/// Error type shared by the `Encoder` and `Decoder` implementations of
/// `CborCodec`.
#[derive(Debug)]
pub enum CborCodecError {
    /// Wraps a `std::io::Error`.
    Io(IoError),
    /// Wraps an error reported by `serde_cbor` while encoding or decoding.
    Cbor(CborError),
}
impl std::fmt::Display for CborCodecError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CborCodecError::Io(e) => write!(f, "I/O error: {}", e),
CborCodecError::Cbor(e) => write!(f, "CBOR error: {}", e),
}
}
}
impl std::error::Error for CborCodecError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
CborCodecError::Io(ref e) => Some(e),
CborCodecError::Cbor(ref e) => Some(e),
}
}
}
impl From<IoError> for CborCodecError {
fn from(e: IoError) -> CborCodecError {
CborCodecError::Io(e)
}
}
impl From<CborError> for CborCodecError {
fn from(e: CborError) -> CborCodecError {
CborCodecError::Cbor(e)
}
}
impl<Enc, Dec> CborCodec<Enc, Dec>
where
    for<'de> Dec: Deserialize<'de> + 'static,
    for<'de> Enc: Serialize + 'static,
{
    /// Creates a new `CborCodec` for the given encode/decode item types.
    ///
    /// The codec carries no state, so this only fills in the type markers.
    pub fn new() -> Self {
        Self {
            enc: PhantomData,
            dec: PhantomData,
        }
    }
}
impl<Enc, Dec> Clone for CborCodec<Enc, Dec>
where
for<'de> Dec: Deserialize<'de> + 'static,
for<'de> Enc: Serialize + 'static,
{
fn clone(&self) -> CborCodec<Enc, Dec> {
CborCodec::new()
}
}
/// Decoder implementation: reads one CBOR-encoded `Dec` value from the front
/// of the buffer.
impl<Enc, Dec> Decoder for CborCodec<Enc, Dec>
where
    for<'de> Dec: Deserialize<'de> + 'static,
    for<'de> Enc: Serialize + 'static,
{
    type Item = Dec;
    type Error = CborCodecError;

    /// Attempts to deserialize a single item from the start of `buf`.
    ///
    /// * `Ok(Some(item))` - an item was decoded; the bytes it occupied have
    ///   been removed from the front of `buf`.
    /// * `Ok(None)` - the deserializer reported EOF; `buf` is left untouched.
    /// * `Err(_)` - any other deserialization error; `buf` is left untouched.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let mut deserializer = serde_cbor::Deserializer::from_slice(&buf[..]);

        match Dec::deserialize(&mut deserializer) {
            Ok(item) => {
                // Drop exactly the bytes the deserializer consumed so that
                // any trailing data stays available for the next call.
                let consumed = deserializer.byte_offset();
                buf.advance(consumed);
                Ok(Some(item))
            }
            Err(err) if err.is_eof() => Ok(None),
            Err(err) => Err(err.into()),
        }
    }
}
/// Encoder implementation: serializes an `Enc` value to CBOR and appends it
/// to the output buffer.
impl<Enc, Dec> Encoder for CborCodec<Enc, Dec>
where
    for<'de> Dec: Deserialize<'de> + 'static,
    for<'de> Enc: Serialize + 'static,
{
    type Item = Enc;
    type Error = CborCodecError;

    /// Serializes `data` and appends the resulting bytes to `buf`.
    ///
    /// The value is serialized into a temporary vector first, so `buf` is
    /// only modified once serialization has succeeded.
    ///
    /// # Errors
    ///
    /// Returns `CborCodecError::Cbor` if `serde_cbor` fails to serialize
    /// `data`.
    fn encode(&mut self, data: Self::Item, buf: &mut BytesMut) -> Result<(), Self::Error> {
        let encoded = serde_cbor::to_vec(&data)?;

        // Make room up front, then copy the serialized bytes in.
        buf.reserve(encoded.len());
        buf.put_slice(&encoded);

        Ok(())
    }
}
impl<Enc, Dec> Default for CborCodec<Enc, Dec>
where
for<'de> Dec: Deserialize<'de> + 'static,
for<'de> Enc: Serialize + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod test {
    use bytes::BytesMut;
    use serde::{Deserialize, Serialize};

    use super::CborCodec;
    use crate::{Decoder, Encoder};

    /// Small payload type used to exercise the codec in both directions.
    #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
    struct TestStruct {
        pub name: String,
        pub data: u16,
    }

    /// A value survives an encode/decode round trip and the buffer is fully
    /// consumed afterwards.
    #[test]
    fn cbor_codec_encode_decode() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 16,
        };
        codec.encode(item1.clone(), &mut buff).unwrap();

        let item2 = codec.decode(&mut buff).unwrap().unwrap();
        assert_eq!(item1, item2);

        // Nothing is left to decode.
        assert_eq!(codec.decode(&mut buff).unwrap(), None);
        assert_eq!(buff.len(), 0);
    }

    /// A truncated message yields `None`, while the complete message still
    /// decodes and drains its buffer.
    #[test]
    fn cbor_codec_partial_decode() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 34,
        };
        codec.encode(item1, &mut buff).unwrap();

        // Only the first four bytes of the message are available.
        let mut start = buff.clone().split_to(4);
        assert_eq!(codec.decode(&mut start).unwrap(), None);

        codec.decode(&mut buff).unwrap().unwrap();
        assert_eq!(buff.len(), 0);
    }

    /// Hitting EOF leaves the partial buffer untouched, and decoding succeeds
    /// once the rest of the message has been appended to that same buffer.
    #[test]
    fn cbor_codec_eof_reached() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 34,
        };
        codec.encode(item1.clone(), &mut buff).unwrap();

        // Split the encoded message into a 4-byte head and the remaining tail.
        let mut buff_start = buff.clone().split_to(4);
        let buff_end = buff.clone().split_off(4);

        // Attempt to decode the head alone: EOF, and nothing is consumed.
        assert_eq!(codec.decode(&mut buff_start).unwrap(), None);
        assert_eq!(buff_start.len(), 4);

        // Append the tail and decode the *reassembled* buffer. Decoding the
        // original `buff` here would not exercise the resume-after-EOF path.
        buff_start.extend(buff_end.iter());
        let item2 = codec.decode(&mut buff_start).unwrap().unwrap();
        assert_eq!(item1, item2);
        assert_eq!(buff_start.len(), 0);
    }

    /// Decoding from the middle of a message is an error and must not
    /// consume any bytes from the buffer.
    #[test]
    fn cbor_codec_decode_error() {
        let mut codec = CborCodec::<TestStruct, TestStruct>::new();
        let mut buff = BytesMut::new();

        let item1 = TestStruct {
            name: "Test name".to_owned(),
            data: 34,
        };
        codec.encode(item1.clone(), &mut buff).unwrap();

        // Skip the first four bytes so the input no longer starts at a
        // message boundary.
        let mut buff_end = buff.clone().split_off(4);
        let buff_end_length = buff_end.len();

        assert!(codec.decode(&mut buff_end).is_err());
        assert_eq!(buff_end.len(), buff_end_length);
    }
}