pub type Behaviour<Req, Resp> = crate::Behaviour<codec::Codec<Req, Resp>>;
mod codec {
use async_trait::async_trait;
use cbor4ii::core::error::DecodeError;
use futures::prelude::*;
use libp2p_swarm::StreamProtocol;
use serde::{de::DeserializeOwned, Serialize};
use std::{collections::TryReserveError, convert::Infallible, io, marker::PhantomData};
const REQUEST_SIZE_MAXIMUM: u64 = 1024 * 1024;
const RESPONSE_SIZE_MAXIMUM: u64 = 10 * 1024 * 1024;
pub struct Codec<Req, Resp> {
phantom: PhantomData<(Req, Resp)>,
}
impl<Req, Resp> Default for Codec<Req, Resp> {
fn default() -> Self {
Codec {
phantom: PhantomData,
}
}
}
impl<Req, Resp> Clone for Codec<Req, Resp> {
fn clone(&self) -> Self {
Self::default()
}
}
#[async_trait]
impl<Req, Resp> crate::Codec for Codec<Req, Resp>
where
Req: Send + Serialize + DeserializeOwned,
Resp: Send + Serialize + DeserializeOwned,
{
type Protocol = StreamProtocol;
type Request = Req;
type Response = Resp;
async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Req>
where
T: AsyncRead + Unpin + Send,
{
let mut vec = Vec::new();
io.take(REQUEST_SIZE_MAXIMUM).read_to_end(&mut vec).await?;
cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
}
async fn read_response<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Resp>
where
T: AsyncRead + Unpin + Send,
{
let mut vec = Vec::new();
io.take(RESPONSE_SIZE_MAXIMUM).read_to_end(&mut vec).await?;
cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
}
async fn write_request<T>(
&mut self,
_: &Self::Protocol,
io: &mut T,
req: Self::Request,
) -> io::Result<()>
where
T: AsyncWrite + Unpin + Send,
{
let data: Vec<u8> =
cbor4ii::serde::to_vec(Vec::new(), &req).map_err(encode_into_io_error)?;
io.write_all(data.as_ref()).await?;
Ok(())
}
async fn write_response<T>(
&mut self,
_: &Self::Protocol,
io: &mut T,
resp: Self::Response,
) -> io::Result<()>
where
T: AsyncWrite + Unpin + Send,
{
let data: Vec<u8> =
cbor4ii::serde::to_vec(Vec::new(), &resp).map_err(encode_into_io_error)?;
io.write_all(data.as_ref()).await?;
Ok(())
}
}
fn decode_into_io_error(err: cbor4ii::serde::DecodeError<Infallible>) -> io::Error {
match err {
cbor4ii::serde::DecodeError::Core(DecodeError::Read(e)) => {
io::Error::new(io::ErrorKind::Other, e)
}
cbor4ii::serde::DecodeError::Core(e @ DecodeError::Unsupported { .. }) => {
io::Error::new(io::ErrorKind::Unsupported, e)
}
cbor4ii::serde::DecodeError::Core(e @ DecodeError::Eof { .. }) => {
io::Error::new(io::ErrorKind::UnexpectedEof, e)
}
cbor4ii::serde::DecodeError::Core(e) => io::Error::new(io::ErrorKind::InvalidData, e),
cbor4ii::serde::DecodeError::Custom(e) => {
io::Error::new(io::ErrorKind::Other, e.to_string())
}
}
}
fn encode_into_io_error(err: cbor4ii::serde::EncodeError<TryReserveError>) -> io::Error {
io::Error::new(io::ErrorKind::Other, err)
}
}
#[cfg(test)]
mod tests {
    use crate::cbor::codec::Codec;
    use crate::Codec as _;
    use futures::AsyncWriteExt;
    use futures_ringbuf::Endpoint;
    use libp2p_swarm::StreamProtocol;
    use serde::{Deserialize, Serialize};

    /// Request message used to exercise the CBOR round trip.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestRequest {
        payload: String,
    }

    /// Response message used to exercise the CBOR round trip.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestResponse {
        payload: String,
    }

    /// Writes a request, then a response, through fresh in-memory stream pairs and checks
    /// that reading each back yields the value that was written.
    #[async_std::test]
    async fn test_codec() {
        let protocol = StreamProtocol::new("/test_cbor/1");
        let mut codec = Codec::default();

        let sent_request = TestRequest {
            payload: String::from("test_payload"),
        };
        let sent_response = TestResponse {
            payload: String::from("test_payload"),
        };

        // Request round trip over its own endpoint pair.
        let (mut writer, mut reader) = Endpoint::pair(124, 124);
        codec
            .write_request(&protocol, &mut writer, sent_request.clone())
            .await
            .expect("Should write request");
        writer.close().await.unwrap();
        let received_request = codec
            .read_request(&protocol, &mut reader)
            .await
            .expect("Should read request");
        reader.close().await.unwrap();
        assert_eq!(received_request, sent_request);

        // Response round trip over a second, independent endpoint pair.
        let (mut writer, mut reader) = Endpoint::pair(124, 124);
        codec
            .write_response(&protocol, &mut writer, sent_response.clone())
            .await
            .expect("Should write response");
        writer.close().await.unwrap();
        let received_response = codec
            .read_response(&protocol, &mut reader)
            .await
            .expect("Should read response");
        reader.close().await.unwrap();
        assert_eq!(received_response, sent_response);
    }
}