libp2p_request_response/
cbor.rs1pub type Behaviour<Req, Resp> = crate::Behaviour<codec::Codec<Req, Resp>>;
48
49mod codec {
50 use std::{collections::TryReserveError, convert::Infallible, io, marker::PhantomData};
51
52 use async_trait::async_trait;
53 use cbor4ii::core::error::DecodeError;
54 use futures::prelude::*;
55 use libp2p_swarm::StreamProtocol;
56 use serde::{de::DeserializeOwned, Serialize};
57
58 pub struct Codec<Req, Resp> {
59 request_size_maximum: u64,
61 response_size_maximum: u64,
63 phantom: PhantomData<(Req, Resp)>,
64 }
65
66 impl<Req, Resp> Default for Codec<Req, Resp> {
67 fn default() -> Self {
68 Codec {
69 request_size_maximum: 1024 * 1024,
70 response_size_maximum: 10 * 1024 * 1024,
71 phantom: PhantomData,
72 }
73 }
74 }
75
76 impl<Req, Resp> Clone for Codec<Req, Resp> {
77 fn clone(&self) -> Self {
78 Self {
79 request_size_maximum: self.request_size_maximum,
80 response_size_maximum: self.response_size_maximum,
81 phantom: PhantomData,
82 }
83 }
84 }
85
86 impl<Req, Resp> Codec<Req, Resp> {
87 pub fn set_request_size_maximum(mut self, request_size_maximum: u64) -> Self {
89 self.request_size_maximum = request_size_maximum;
90 self
91 }
92
93 pub fn set_response_size_maximum(mut self, response_size_maximum: u64) -> Self {
95 self.response_size_maximum = response_size_maximum;
96 self
97 }
98 }
99
100 #[async_trait]
101 impl<Req, Resp> crate::Codec for Codec<Req, Resp>
102 where
103 Req: Send + Serialize + DeserializeOwned,
104 Resp: Send + Serialize + DeserializeOwned,
105 {
106 type Protocol = StreamProtocol;
107 type Request = Req;
108 type Response = Resp;
109
110 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Req>
111 where
112 T: AsyncRead + Unpin + Send,
113 {
114 let mut vec = Vec::new();
115
116 io.take(self.request_size_maximum)
117 .read_to_end(&mut vec)
118 .await?;
119
120 cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
121 }
122
123 async fn read_response<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Resp>
124 where
125 T: AsyncRead + Unpin + Send,
126 {
127 let mut vec = Vec::new();
128
129 io.take(self.response_size_maximum)
130 .read_to_end(&mut vec)
131 .await?;
132
133 cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
134 }
135
136 async fn write_request<T>(
137 &mut self,
138 _: &Self::Protocol,
139 io: &mut T,
140 req: Self::Request,
141 ) -> io::Result<()>
142 where
143 T: AsyncWrite + Unpin + Send,
144 {
145 let data: Vec<u8> =
146 cbor4ii::serde::to_vec(Vec::new(), &req).map_err(encode_into_io_error)?;
147
148 io.write_all(data.as_ref()).await?;
149
150 Ok(())
151 }
152
153 async fn write_response<T>(
154 &mut self,
155 _: &Self::Protocol,
156 io: &mut T,
157 resp: Self::Response,
158 ) -> io::Result<()>
159 where
160 T: AsyncWrite + Unpin + Send,
161 {
162 let data: Vec<u8> =
163 cbor4ii::serde::to_vec(Vec::new(), &resp).map_err(encode_into_io_error)?;
164
165 io.write_all(data.as_ref()).await?;
166
167 Ok(())
168 }
169 }
170
171 fn decode_into_io_error(err: cbor4ii::serde::DecodeError<Infallible>) -> io::Error {
172 match err {
173 #[allow(unreachable_patterns)]
175 cbor4ii::serde::DecodeError::Core(DecodeError::Read(e)) => {
176 io::Error::new(io::ErrorKind::Other, e)
177 }
178 cbor4ii::serde::DecodeError::Core(e @ DecodeError::Unsupported { .. }) => {
179 io::Error::new(io::ErrorKind::Unsupported, e)
180 }
181 cbor4ii::serde::DecodeError::Core(e @ DecodeError::Eof { .. }) => {
182 io::Error::new(io::ErrorKind::UnexpectedEof, e)
183 }
184 cbor4ii::serde::DecodeError::Core(e) => io::Error::new(io::ErrorKind::InvalidData, e),
185 cbor4ii::serde::DecodeError::Custom(e) => {
186 io::Error::new(io::ErrorKind::Other, e.to_string())
187 }
188 }
189 }
190
191 fn encode_into_io_error(err: cbor4ii::serde::EncodeError<TryReserveError>) -> io::Error {
192 io::Error::new(io::ErrorKind::Other, err)
193 }
194}
195
#[cfg(test)]
mod tests {
    use futures::AsyncWriteExt;
    use futures_ringbuf::Endpoint;
    use libp2p_swarm::StreamProtocol;
    use serde::{Deserialize, Serialize};

    use crate::{cbor::codec::Codec, Codec as _};

    // NOTE: `Endpoint::pair(124, 124)` sets the pipe's buffer sizes.
    // Each test writes a message and closes the writer before anything is
    // read, so every encoded test message is kept small enough to fit.

    #[async_std::test]
    async fn test_codec() {
        let expected_request = TestRequest {
            payload: "test_payload".to_string(),
        };
        let expected_response = TestResponse {
            payload: "test_payload".to_string(),
        };
        let protocol = StreamProtocol::new("/test_cbor/1");
        let mut codec = Codec::default();

        let (mut a, mut b) = Endpoint::pair(124, 124);
        codec
            .write_request(&protocol, &mut a, expected_request.clone())
            .await
            .expect("Should write request");
        a.close().await.unwrap();

        let actual_request = codec
            .read_request(&protocol, &mut b)
            .await
            .expect("Should read request");
        b.close().await.unwrap();

        assert_eq!(actual_request, expected_request);

        let (mut a, mut b) = Endpoint::pair(124, 124);
        codec
            .write_response(&protocol, &mut a, expected_response.clone())
            .await
            .expect("Should write response");
        a.close().await.unwrap();

        let actual_response = codec
            .read_response(&protocol, &mut b)
            .await
            .expect("Should read response");
        b.close().await.unwrap();

        assert_eq!(actual_response, expected_response);
    }

    #[async_std::test]
    async fn oversized_request_is_rejected() {
        let request = TestRequest {
            payload: "test_payload".to_string(),
        };
        let protocol = StreamProtocol::new("/test_cbor/1");
        // The CBOR encoding of `request` is far longer than 8 bytes,
        // because the payload string alone is 12 bytes.
        let mut codec = Codec::<TestRequest, TestResponse>::default().set_request_size_maximum(8);

        let (mut a, mut b) = Endpoint::pair(124, 124);
        codec
            .write_request(&protocol, &mut a, request)
            .await
            .expect("Should write request");
        a.close().await.unwrap();

        codec
            .read_request(&protocol, &mut b)
            .await
            .expect_err("Request above the size limit must not be decoded");
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestRequest {
        payload: String,
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestResponse {
        payload: String,
    }
}