libp2p_request_response/
cbor.rs

1// Copyright 2023 Protocol Labs
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21/// A request-response behaviour using [`cbor4ii::serde`] for serializing and
22/// deserializing the messages.
23///
24/// # Example
25///
26/// ```
27/// # use libp2p_request_response::{cbor, ProtocolSupport, self as request_response};
28/// # use libp2p_swarm::StreamProtocol;
29/// #[derive(Debug, serde::Serialize, serde::Deserialize)]
30/// struct GreetRequest {
31///     name: String,
32/// }
33///
34/// #[derive(Debug, serde::Serialize, serde::Deserialize)]
35/// struct GreetResponse {
36///     message: String,
37/// }
38///
39/// let behaviour = cbor::Behaviour::<GreetRequest, GreetResponse>::new(
40///     [(
41///         StreamProtocol::new("/my-cbor-protocol"),
42///         ProtocolSupport::Full,
43///     )],
44///     request_response::Config::default(),
45/// );
46/// ```
47pub type Behaviour<Req, Resp> = crate::Behaviour<codec::Codec<Req, Resp>>;
48
49mod codec {
50    use std::{collections::TryReserveError, convert::Infallible, io, marker::PhantomData};
51
52    use async_trait::async_trait;
53    use cbor4ii::core::error::DecodeError;
54    use futures::prelude::*;
55    use libp2p_swarm::StreamProtocol;
56    use serde::{de::DeserializeOwned, Serialize};
57
58    pub struct Codec<Req, Resp> {
59        /// Max request size in bytes.
60        request_size_maximum: u64,
61        /// Max response size in bytes.
62        response_size_maximum: u64,
63        phantom: PhantomData<(Req, Resp)>,
64    }
65
66    impl<Req, Resp> Default for Codec<Req, Resp> {
67        fn default() -> Self {
68            Codec {
69                request_size_maximum: 1024 * 1024,
70                response_size_maximum: 10 * 1024 * 1024,
71                phantom: PhantomData,
72            }
73        }
74    }
75
76    impl<Req, Resp> Clone for Codec<Req, Resp> {
77        fn clone(&self) -> Self {
78            Self {
79                request_size_maximum: self.request_size_maximum,
80                response_size_maximum: self.response_size_maximum,
81                phantom: PhantomData,
82            }
83        }
84    }
85
86    impl<Req, Resp> Codec<Req, Resp> {
87        /// Sets the limit for request size in bytes.
88        pub fn set_request_size_maximum(mut self, request_size_maximum: u64) -> Self {
89            self.request_size_maximum = request_size_maximum;
90            self
91        }
92
93        /// Sets the limit for response size in bytes.
94        pub fn set_response_size_maximum(mut self, response_size_maximum: u64) -> Self {
95            self.response_size_maximum = response_size_maximum;
96            self
97        }
98    }
99
100    #[async_trait]
101    impl<Req, Resp> crate::Codec for Codec<Req, Resp>
102    where
103        Req: Send + Serialize + DeserializeOwned,
104        Resp: Send + Serialize + DeserializeOwned,
105    {
106        type Protocol = StreamProtocol;
107        type Request = Req;
108        type Response = Resp;
109
110        async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Req>
111        where
112            T: AsyncRead + Unpin + Send,
113        {
114            let mut vec = Vec::new();
115
116            io.take(self.request_size_maximum)
117                .read_to_end(&mut vec)
118                .await?;
119
120            cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
121        }
122
123        async fn read_response<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Resp>
124        where
125            T: AsyncRead + Unpin + Send,
126        {
127            let mut vec = Vec::new();
128
129            io.take(self.response_size_maximum)
130                .read_to_end(&mut vec)
131                .await?;
132
133            cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
134        }
135
136        async fn write_request<T>(
137            &mut self,
138            _: &Self::Protocol,
139            io: &mut T,
140            req: Self::Request,
141        ) -> io::Result<()>
142        where
143            T: AsyncWrite + Unpin + Send,
144        {
145            let data: Vec<u8> =
146                cbor4ii::serde::to_vec(Vec::new(), &req).map_err(encode_into_io_error)?;
147
148            io.write_all(data.as_ref()).await?;
149
150            Ok(())
151        }
152
153        async fn write_response<T>(
154            &mut self,
155            _: &Self::Protocol,
156            io: &mut T,
157            resp: Self::Response,
158        ) -> io::Result<()>
159        where
160            T: AsyncWrite + Unpin + Send,
161        {
162            let data: Vec<u8> =
163                cbor4ii::serde::to_vec(Vec::new(), &resp).map_err(encode_into_io_error)?;
164
165            io.write_all(data.as_ref()).await?;
166
167            Ok(())
168        }
169    }
170
171    fn decode_into_io_error(err: cbor4ii::serde::DecodeError<Infallible>) -> io::Error {
172        match err {
173            // TODO: remove when Rust 1.82 is MSRV
174            #[allow(unreachable_patterns)]
175            cbor4ii::serde::DecodeError::Core(DecodeError::Read(e)) => {
176                io::Error::new(io::ErrorKind::Other, e)
177            }
178            cbor4ii::serde::DecodeError::Core(e @ DecodeError::Unsupported { .. }) => {
179                io::Error::new(io::ErrorKind::Unsupported, e)
180            }
181            cbor4ii::serde::DecodeError::Core(e @ DecodeError::Eof { .. }) => {
182                io::Error::new(io::ErrorKind::UnexpectedEof, e)
183            }
184            cbor4ii::serde::DecodeError::Core(e) => io::Error::new(io::ErrorKind::InvalidData, e),
185            cbor4ii::serde::DecodeError::Custom(e) => {
186                io::Error::new(io::ErrorKind::Other, e.to_string())
187            }
188        }
189    }
190
191    fn encode_into_io_error(err: cbor4ii::serde::EncodeError<TryReserveError>) -> io::Error {
192        io::Error::new(io::ErrorKind::Other, err)
193    }
194}
195
#[cfg(test)]
mod tests {
    use futures::AsyncWriteExt;
    use futures_ringbuf::Endpoint;
    use libp2p_swarm::StreamProtocol;
    use serde::{Deserialize, Serialize};

    use crate::{cbor::codec::Codec, Codec as _};

    /// Round-trips one request and one response through the CBOR codec over
    /// in-memory ring-buffer endpoints and checks both arrive unchanged.
    #[async_std::test]
    async fn test_codec() {
        let payload = "test_payload";
        let protocol = StreamProtocol::new("/test_cbor/1");
        let mut codec = Codec::default();

        // Request leg: write and close one end, then read from the other.
        let sent_request = TestRequest {
            payload: payload.to_string(),
        };
        let (mut writer, mut reader) = Endpoint::pair(124, 124);
        codec
            .write_request(&protocol, &mut writer, sent_request.clone())
            .await
            .expect("Should write request");
        writer.close().await.unwrap();

        let received_request = codec
            .read_request(&protocol, &mut reader)
            .await
            .expect("Should read request");
        reader.close().await.unwrap();
        assert_eq!(received_request, sent_request);

        // Response leg: same procedure on a fresh pair of endpoints.
        let sent_response = TestResponse {
            payload: payload.to_string(),
        };
        let (mut writer, mut reader) = Endpoint::pair(124, 124);
        codec
            .write_response(&protocol, &mut writer, sent_response.clone())
            .await
            .expect("Should write response");
        writer.close().await.unwrap();

        let received_response = codec
            .read_response(&protocol, &mut reader)
            .await
            .expect("Should read response");
        reader.close().await.unwrap();
        assert_eq!(received_response, sent_response);
    }

    /// Request fixture carried through the codec.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestRequest {
        payload: String,
    }

    /// Response fixture carried through the codec.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestResponse {
        payload: String,
    }
}
256}