tokio_serde/
lib.rs

1//! This crate provides the utilities needed to easily implement a Tokio
2//! transport using [serde] for serialization and deserialization of frame
3//! values.
4//!
5//! # Introduction
6//!
7//! This crate provides [transport] combinators that transform a stream of
8//! frames encoded as bytes into a stream of frame values. It is expected that
9//! the framing happens at another layer. One option is to use a [length
10//! delimited] framing transport.
11//!
12//! The crate provides two traits that must be implemented: [`Serializer`] and
13//! [`Deserializer`]. Implementations of these traits are then passed to
14//! [`Framed`] along with the upstream [`Stream`] or
15//! [`Sink`] that handles the byte encoded frames.
16//!
17//! By doing this, a transformation pipeline is built. For reading, it looks
18//! something like this:
19//!
20//! * `tokio_serde::Framed`
21//! * `tokio_util::codec::FramedRead`
22//! * `tokio::net::TcpStream`
23//!
24//! The write half looks like:
25//!
26//! * `tokio_serde::Framed`
27//! * `tokio_util::codec::FramedWrite`
28//! * `tokio::net::TcpStream`
29//!
30//! # Examples
31//!
32//! For an example, see how JSON support is implemented:
33//!
34//! * [server](https://github.com/carllerche/tokio-serde/blob/master/examples/server.rs)
35//! * [client](https://github.com/carllerche/tokio-serde/blob/master/examples/client.rs)
36//!
37//! [serde]: https://serde.rs
38//! [serde-json]: https://github.com/serde-rs/json
39//! [transport]: https://tokio.rs/docs/going-deeper/transports/
//! [length delimited]: https://docs.rs/tokio-util/latest/tokio_util/codec/length_delimited/index.html
41//! [`Serializer`]: trait.Serializer.html
42//! [`Deserializer`]: trait.Deserializer.html
43//! [`Framed`]: struct.Framed.html
44//! [`Stream`]: https://docs.rs/futures/0.3/futures/stream/trait.Stream.html
45//! [`Sink`]: https://docs.rs/futures/0.3/futures/sink/trait.Sink.html
46
47#![cfg_attr(docsrs, feature(doc_cfg))]
48
49use bytes::{Bytes, BytesMut};
50use futures_core::{ready, Stream, TryStream};
51use futures_sink::Sink;
52use pin_project::pin_project;
53use std::{
54    marker::PhantomData,
55    pin::Pin,
56    task::{Context, Poll},
57};
58
59/// Serializes a value into a destination buffer
60///
61/// Implementations of `Serializer` are able to take values of type `T` and
62/// convert them to a byte representation. The specific byte format, i.e. JSON,
63/// protobuf, binpack, ... is an implementation detail.
64///
65/// The `serialize` function takes `&mut self`, allowing for `Serializer`
66/// instances to be created with runtime configuration settings.
67///
68/// # Examples
69///
70/// An integer serializer that allows the width to be configured.
71///
72/// ```
73/// use tokio_serde::Serializer;
74/// use bytes::{Buf, Bytes, BytesMut, BufMut};
75/// use std::pin::Pin;
76///
77/// struct IntSerializer {
78///     width: usize,
79/// }
80///
81/// #[derive(Debug)]
82/// enum Error {
83///     Overflow,
84/// }
85///
86/// impl Serializer<u64> for IntSerializer {
87///     type Error = Error;
88///
89///     fn serialize(self: Pin<&mut Self>, item: &u64) -> Result<Bytes, Self::Error> {
90///         assert!(self.width <= 8);
91///
92///         let max = (1 << (self.width * 8)) - 1;
93///
94///         if *item > max {
95///             return Err(Error::Overflow);
96///         }
97///
98///         let mut ret = BytesMut::with_capacity(self.width);
99///         ret.put_uint(*item, self.width);
100///         Ok(ret.into())
101///     }
102/// }
103///
104/// let mut serializer = IntSerializer { width: 3 };
105///
106/// let buf = Pin::new(&mut serializer).serialize(&5).unwrap();
107/// assert_eq!(buf, &b"\x00\x00\x05"[..]);
108/// ```
109pub trait Serializer<T> {
110    type Error;
111
112    /// Serializes `item` into a new buffer
113    ///
114    /// The serialization format is specific to the various implementations of
115    /// `Serializer`. If the serialization is successful, a buffer containing
116    /// the serialized item is returned. If the serialization is unsuccessful,
117    /// an error is returned.
118    ///
119    /// Implementations of this function should not mutate `item` via any sort
120    /// of internal mutability strategy.
121    ///
122    /// See the trait level docs for more detail.
123    fn serialize(self: Pin<&mut Self>, item: &T) -> Result<Bytes, Self::Error>;
124}
125
126/// Deserializes a value from a source buffer
127///
128/// Implementatinos of `Deserializer` take a byte buffer and return a value by
129/// parsing the contents of the buffer according to the implementation's format.
130/// The specific byte format, i.e. JSON, protobuf, binpack, is an implementation
131/// detail
132///
133/// The `deserialize` function takes `&mut self`, allowing for `Deserializer`
134/// instances to be created with runtime configuration settings.
135///
136/// It is expected that the supplied buffer represents a full value and only
137/// that value. If after deserializing a value there are remaining bytes the
138/// buffer, the deserializer will return an error.
139///
140/// # Examples
141///
142/// An integer deserializer that allows the width to be configured.
143///
144/// ```
145/// use tokio_serde::Deserializer;
146/// use bytes::{BytesMut, Buf};
147/// use std::pin::Pin;
148///
149/// struct IntDeserializer {
150///     width: usize,
151/// }
152///
153/// #[derive(Debug)]
154/// enum Error {
155///     Underflow,
156///     Overflow
157/// }
158///
159/// impl Deserializer<u64> for IntDeserializer {
160///     type Error = Error;
161///
162///     fn deserialize(self: Pin<&mut Self>, buf: &BytesMut) -> Result<u64, Self::Error> {
163///         assert!(self.width <= 8);
164///
165///         if buf.len() > self.width {
166///             return Err(Error::Overflow);
167///         }
168///
169///         if buf.len() < self.width {
170///             return Err(Error::Underflow);
171///         }
172///
173///         let ret = std::io::Cursor::new(buf).get_uint(self.width);
174///         Ok(ret)
175///     }
176/// }
177///
178/// let mut deserializer = IntDeserializer { width: 3 };
179///
180/// let i = Pin::new(&mut deserializer).deserialize(&b"\x00\x00\x05"[..].into()).unwrap();
181/// assert_eq!(i, 5);
182/// ```
183pub trait Deserializer<T> {
184    type Error;
185
186    /// Deserializes a value from `buf`
187    ///
188    /// The serialization format is specific to the various implementations of
189    /// `Deserializer`. If the deserialization is successful, the value is
190    /// returned. If the deserialization is unsuccessful, an error is returned.
191    ///
192    /// See the trait level docs for more detail.
193    fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<T, Self::Error>;
194}
195
196/// Adapts a transport to a value sink by serializing the values and to a stream of values by deserializing them.
197///
198/// It is expected that the buffers yielded by the supplied transport be framed. In
199/// other words, each yielded buffer must represent exactly one serialized
200/// value.
201///
202/// The provided transport will receive buffer values containing the
203/// serialized value. Each buffer contains exactly one value. This sink will be
204/// responsible for writing these buffers to an `AsyncWrite` using some sort of
205/// framing strategy.
206///
207/// The specific framing strategy is left up to the
208/// implementor. One option would be to use [length_delimited] provided by
209/// [tokio-util].
210///
211/// [length_delimited]: http://docs.rs/tokio-util/0.2/tokio_util/codec/length_delimited/index.html
212/// [tokio-util]: http://crates.io/crates/tokio-util
213#[pin_project]
214#[derive(Debug)]
215pub struct Framed<Transport, Item, SinkItem, Codec> {
216    #[pin]
217    inner: Transport,
218    #[pin]
219    codec: Codec,
220    item: PhantomData<(Item, SinkItem)>,
221}
222
223impl<Transport, Item, SinkItem, Codec> Framed<Transport, Item, SinkItem, Codec> {
224    /// Creates a new `Framed` with the given transport and codec.
225    pub fn new(inner: Transport, codec: Codec) -> Self {
226        Self {
227            inner,
228            codec,
229            item: PhantomData,
230        }
231    }
232
233    /// Returns a reference to the underlying transport wrapped by `Framed`.
234    ///
235    /// Note that care should be taken to not tamper with the underlying transport as
236    /// it may corrupt the sequence of frames otherwise being worked with.
237    pub fn get_ref(&self) -> &Transport {
238        &self.inner
239    }
240
241    /// Returns a mutable reference to the underlying transport wrapped by
242    /// `Framed`.
243    ///
244    /// Note that care should be taken to not tamper with the underlying transport as
245    /// it may corrupt the sequence of frames otherwise being worked with.
246    pub fn get_mut(&mut self) -> &mut Transport {
247        &mut self.inner
248    }
249
250    /// Consumes the `Framed`, returning its underlying transport.
251    ///
252    /// Note that care should be taken to not tamper with the underlying transport as
253    /// it may corrupt the sequence of frames otherwise being worked with.
254    pub fn into_inner(self) -> Transport {
255        self.inner
256    }
257}
258
259impl<Transport, Item, SinkItem, Codec> Stream for Framed<Transport, Item, SinkItem, Codec>
260where
261    Transport: TryStream<Ok = BytesMut>,
262    Transport::Error: From<Codec::Error>,
263    BytesMut: From<Transport::Ok>,
264    Codec: Deserializer<Item>,
265{
266    type Item = Result<Item, Transport::Error>;
267
268    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
269        match ready!(self.as_mut().project().inner.try_poll_next(cx)) {
270            Some(bytes) => Poll::Ready(Some(Ok(self
271                .as_mut()
272                .project()
273                .codec
274                .deserialize(&bytes?)?))),
275            None => Poll::Ready(None),
276        }
277    }
278}
279
280impl<Transport, Item, SinkItem, Codec> Sink<SinkItem> for Framed<Transport, Item, SinkItem, Codec>
281where
282    Transport: Sink<Bytes>,
283    Codec: Serializer<SinkItem>,
284    Codec::Error: Into<Transport::Error>,
285{
286    type Error = Transport::Error;
287
288    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
289        self.project().inner.poll_ready(cx)
290    }
291
292    fn start_send(mut self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
293        let res = self.as_mut().project().codec.serialize(&item);
294        let bytes = res.map_err(Into::into)?;
295
296        self.as_mut().project().inner.start_send(bytes)?;
297
298        Ok(())
299    }
300
301    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
302        self.project().inner.poll_flush(cx)
303    }
304
305    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
306        ready!(self.as_mut().poll_flush(cx))?;
307        self.project().inner.poll_close(cx)
308    }
309}
310
311pub type SymmetricallyFramed<Transport, Value, Codec> = Framed<Transport, Value, Value, Codec>;
312
313#[cfg(any(
314    feature = "json",
315    feature = "bincode",
316    feature = "messagepack",
317    feature = "cbor"
318))]
319pub mod formats {
320    #[cfg(feature = "bincode")]
321    pub use self::bincode::*;
322    #[cfg(feature = "cbor")]
323    pub use self::cbor::*;
324    #[cfg(feature = "json")]
325    pub use self::json::*;
326    #[cfg(feature = "messagepack")]
327    pub use self::messagepack::*;
328
329    use super::{Deserializer, Serializer};
330    use bytes::{Bytes, BytesMut};
331    use educe::Educe;
332    use serde::{Deserialize, Serialize};
333    use std::{marker::PhantomData, pin::Pin};
334
#[cfg(feature = "bincode")]
mod bincode {
    use super::*;
    use bincode_crate::config::Options;
    use std::io;

    /// Wraps a bincode failure in an I/O error of kind `InvalidData`.
    fn invalid_data<E>(err: E) -> io::Error
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        io::Error::new(io::ErrorKind::InvalidData, err)
    }

    /// Bincode codec using [bincode](https://docs.rs/bincode) crate.
    #[cfg_attr(docsrs, doc(cfg(feature = "bincode")))]
    #[derive(Educe)]
    #[educe(Debug)]
    pub struct Bincode<Item, SinkItem, O = bincode_crate::DefaultOptions> {
        #[educe(Debug(ignore))]
        options: O,
        #[educe(Debug(ignore))]
        ghost: PhantomData<(Item, SinkItem)>,
    }

    /// Builds a codec configured with bincode's `DefaultOptions`.
    impl<Item, SinkItem> Default for Bincode<Item, SinkItem> {
        fn default() -> Self {
            Self::from(bincode_crate::DefaultOptions::default())
        }
    }

    /// Builds a codec from an explicit set of bincode options.
    impl<Item, SinkItem, O> From<O> for Bincode<Item, SinkItem, O>
    where
        O: Options,
    {
        fn from(options: O) -> Self {
            Bincode {
                ghost: PhantomData,
                options,
            }
        }
    }

    /// Bincode codec whose incoming and outgoing values share the type `T`.
    #[cfg_attr(docsrs, doc(cfg(feature = "bincode")))]
    pub type SymmetricalBincode<T, O = bincode_crate::DefaultOptions> = Bincode<T, T, O>;

    impl<Item, SinkItem, O> Deserializer<Item> for Bincode<Item, SinkItem, O>
    where
        for<'a> Item: Deserialize<'a>,
        O: Options + Clone,
    {
        type Error = io::Error;

        fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<Item, Self::Error> {
            // bincode's option methods consume `self`, so work on a copy.
            let options = self.options.clone();
            options.deserialize(src).map_err(invalid_data)
        }
    }

    impl<Item, SinkItem, O> Serializer<SinkItem> for Bincode<Item, SinkItem, O>
    where
        SinkItem: Serialize,
        O: Options + Clone,
    {
        type Error = io::Error;

        fn serialize(self: Pin<&mut Self>, item: &SinkItem) -> Result<Bytes, Self::Error> {
            let options = self.options.clone();
            match options.serialize(item) {
                Ok(encoded) => Ok(Bytes::from(encoded)),
                Err(err) => Err(invalid_data(err)),
            }
        }
    }
}
407
#[cfg(feature = "json")]
mod json {
    use super::*;

    /// JSON codec using [serde_json](https://docs.rs/serde_json) crate.
    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
    #[derive(Educe)]
    #[educe(Debug, Default)]
    pub struct Json<Item, SinkItem> {
        #[educe(Debug(ignore))]
        ghost: PhantomData<(Item, SinkItem)>,
    }

    /// JSON codec whose incoming and outgoing values share the type `T`.
    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
    pub type SymmetricalJson<T> = Json<T, T>;

    impl<Item, SinkItem> Deserializer<Item> for Json<Item, SinkItem>
    where
        for<'a> Item: Deserialize<'a>,
    {
        type Error = serde_json::Error;

        fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<Item, Self::Error> {
            // The whole frame is already in memory, so parse the slice directly
            // instead of wrapping it in an `io::Read` adapter. serde_json
            // documents `from_reader` as slower than `from_slice`.
            serde_json::from_slice(src)
        }
    }

    impl<Item, SinkItem> Serializer<SinkItem> for Json<Item, SinkItem>
    where
        SinkItem: Serialize,
    {
        type Error = serde_json::Error;

        fn serialize(self: Pin<&mut Self>, item: &SinkItem) -> Result<Bytes, Self::Error> {
            serde_json::to_vec(item).map(Into::into)
        }
    }
}
447
#[cfg(feature = "messagepack")]
mod messagepack {
    use super::*;
    use bytes::Buf;
    use std::io;

    /// Wraps an rmp-serde encode/decode failure in an I/O error of kind
    /// `InvalidData`.
    fn invalid_data<E>(err: E) -> io::Error
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        io::Error::new(io::ErrorKind::InvalidData, err)
    }

    /// MessagePack codec using [rmp-serde](https://docs.rs/rmp-serde) crate.
    #[cfg_attr(docsrs, doc(cfg(feature = "messagepack")))]
    #[derive(Educe)]
    #[educe(Debug, Default)]
    pub struct MessagePack<Item, SinkItem> {
        #[educe(Debug(ignore))]
        ghost: PhantomData<(Item, SinkItem)>,
    }

    /// MessagePack codec whose incoming and outgoing values share the type `T`.
    #[cfg_attr(docsrs, doc(cfg(feature = "messagepack")))]
    pub type SymmetricalMessagePack<T> = MessagePack<T, T>;

    impl<Item, SinkItem> Deserializer<Item> for MessagePack<Item, SinkItem>
    where
        for<'a> Item: Deserialize<'a>,
    {
        type Error = io::Error;

        fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<Item, Self::Error> {
            let reader = std::io::Cursor::new(src).reader();
            rmp_serde::from_read(reader).map_err(invalid_data)
        }
    }

    impl<Item, SinkItem> Serializer<SinkItem> for MessagePack<Item, SinkItem>
    where
        SinkItem: Serialize,
    {
        type Error = io::Error;

        fn serialize(self: Pin<&mut Self>, item: &SinkItem) -> Result<Bytes, Self::Error> {
            rmp_serde::to_vec(item)
                .map(Bytes::from)
                .map_err(invalid_data)
        }
    }
}
491
#[cfg(feature = "cbor")]
mod cbor {
    use super::*;
    use std::io;

    /// CBOR codec using [serde_cbor](https://docs.rs/serde_cbor) crate.
    #[cfg_attr(docsrs, doc(cfg(feature = "cbor")))]
    #[derive(Educe)]
    #[educe(Debug, Default)]
    pub struct Cbor<Item, SinkItem> {
        #[educe(Debug(ignore))]
        _mkr: PhantomData<(Item, SinkItem)>,
    }

    /// CBOR codec whose incoming and outgoing values share the type `T`.
    #[cfg_attr(docsrs, doc(cfg(feature = "cbor")))]
    pub type SymmetricalCbor<T> = Cbor<T, T>;

    impl<Item, SinkItem> Deserializer<Item> for Cbor<Item, SinkItem>
    where
        for<'a> Item: Deserialize<'a>,
    {
        type Error = io::Error;

        fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<Item, Self::Error> {
            let frame: &[u8] = &src[..];
            serde_cbor::from_slice(frame).map_err(into_io_error)
        }
    }

    impl<Item, SinkItem> Serializer<SinkItem> for Cbor<Item, SinkItem>
    where
        SinkItem: Serialize,
    {
        type Error = io::Error;

        fn serialize(self: Pin<&mut Self>, item: &SinkItem) -> Result<Bytes, Self::Error> {
            match serde_cbor::to_vec(item) {
                Ok(encoded) => Ok(Bytes::from(encoded)),
                Err(err) => Err(into_io_error(err)),
            }
        }
    }

    /// Converts a `serde_cbor` error into an `io::Error`, choosing the I/O
    /// error kind from the CBOR error category.
    fn into_io_error(cbor_err: serde_cbor::Error) -> io::Error {
        use serde_cbor::error::Category;
        use std::error::Error as _;

        let kind = match cbor_err.classify() {
            Category::Eof => io::ErrorKind::UnexpectedEof,
            Category::Syntax => io::ErrorKind::InvalidInput,
            Category::Data => io::ErrorKind::InvalidData,
            // Reuse the kind of the wrapped I/O error, falling back to `Other`
            // when no `io::Error` source is available.
            Category::Io => cbor_err
                .source()
                .and_then(|source| source.downcast_ref::<io::Error>())
                .map_or(io::ErrorKind::Other, io::Error::kind),
        };

        io::Error::new(kind, cbor_err)
    }
}
554}
555
#[cfg(test)]
mod tests {
    /// Expands to a test asserting that a codec still implements `Debug` and
    /// `Default` when instantiated with a payload type that implements neither.
    ///
    /// Leading attributes (the codec's `#[cfg(feature = "...")]` gate) are
    /// forwarded onto the generated test function.
    macro_rules! codec_trait_impls_test {
        ($(#[$attr:meta])* $name:ident => $codec:ident) => {
            $(#[$attr])*
            #[test]
            fn $name() {
                use impls::impls;
                use std::fmt::Debug;

                // Deliberately implements no traits at all.
                struct Nothing;
                type T = crate::formats::$codec<Nothing, Nothing>;

                assert!(impls!(T: Debug));
                assert!(impls!(T: Default));
            }
        };
    }

    codec_trait_impls_test!(#[cfg(feature = "bincode")] bincode_impls => Bincode);
    codec_trait_impls_test!(#[cfg(feature = "json")] json_impls => Json);
    codec_trait_impls_test!(#[cfg(feature = "messagepack")] messagepack_impls => MessagePack);
    codec_trait_impls_test!(#[cfg(feature = "cbor")] cbor_impls => Cbor);
}