async_bincode/
writer.rs

macro_rules! make_writer {
    ($write_trait:path, $poll_close_method:ident) => {
        pub use crate::writer::{AsyncDestination, BincodeWriterFor, SyncDestination};

        /// A wrapper around an asynchronous sink that accepts, serializes, and sends bincode-encoded
        /// values.
        ///
        /// To use, provide a writer that implements
        #[doc=concat!("[`", stringify!($write_trait), "`],")]
        /// and then use [`futures_sink::Sink`] to send values.
        ///
        /// Important: Only one element at a time is written to the output writer. It is recommended
        /// to use a `BufWriter` in front of the output to batch write operations to the underlying writer.
        ///
        /// Note that an `AsyncBincodeWriter` must be of the type [`AsyncDestination`] in order to be
        /// compatible with an [`AsyncBincodeReader`] on the remote end (recall that it requires the
        /// serialized size prefixed to the serialized data). The default is [`SyncDestination`], but these
        /// can be easily toggled between using [`AsyncBincodeWriter::for_async`] and
        /// [`AsyncBincodeWriter::for_sync`].
        #[derive(Debug)]
        pub struct AsyncBincodeWriter<W, T, D> {
            /// The wrapped asynchronous writer that encoded bytes are sent to.
            pub(crate) writer: W,
            /// Number of bytes at the front of `buffer` already handed to `writer`.
            pub(crate) written: usize,
            /// Encoded bytes produced by `start_send` and drained by `poll_ready`.
            pub(crate) buffer: Vec<u8>,
            pub(crate) from: std::marker::PhantomData<T>,
            pub(crate) dest: std::marker::PhantomData<D>,
        }

        impl<W, T, D> Unpin for AsyncBincodeWriter<W, T, D> where W: Unpin {}

        impl<W, T> Default for AsyncBincodeWriter<W, T, SyncDestination>
        where
            W: Default,
        {
            fn default() -> Self {
                Self::from(W::default())
            }
        }

        impl<W, T, D> AsyncBincodeWriter<W, T, D> {
            /// Gets a reference to the underlying writer.
            ///
            /// It is inadvisable to directly write to the underlying writer.
            pub fn get_ref(&self) -> &W {
                &self.writer
            }

            /// Gets a mutable reference to the underlying writer.
            ///
            /// It is inadvisable to directly write to the underlying writer.
            pub fn get_mut(&mut self) -> &mut W {
                &mut self.writer
            }

            /// Unwraps this `AsyncBincodeWriter`, returning the underlying writer.
            ///
            /// Note that any leftover serialized data that has not yet been sent is lost.
            pub fn into_inner(self) -> W {
                self.writer
            }
        }

        impl<W, T> From<W> for AsyncBincodeWriter<W, T, SyncDestination> {
            fn from(writer: W) -> Self {
                Self {
                    buffer: Vec::new(),
                    writer,
                    written: 0,
                    from: std::marker::PhantomData,
                    dest: std::marker::PhantomData,
                }
            }
        }

        impl<W, T> AsyncBincodeWriter<W, T, SyncDestination> {
            /// Make this writer include the serialized data's size before each serialized value.
            ///
            /// This is necessary for compatibility with [`AsyncBincodeReader`].
            pub fn for_async(self) -> AsyncBincodeWriter<W, T, AsyncDestination> {
                self.make_for()
            }
        }

        impl<W, T> AsyncBincodeWriter<W, T, AsyncDestination> {
            /// Make this writer only send bincode-encoded values.
            ///
            /// This is necessary for compatibility with stock `bincode` receivers.
            pub fn for_sync(self) -> AsyncBincodeWriter<W, T, SyncDestination> {
                self.make_for()
            }
        }

        impl<W, T, D> AsyncBincodeWriter<W, T, D> {
            /// Re-tags this writer with a different destination marker, keeping any pending bytes.
            pub(crate) fn make_for<D2>(self) -> AsyncBincodeWriter<W, T, D2> {
                AsyncBincodeWriter {
                    buffer: self.buffer,
                    writer: self.writer,
                    written: self.written,
                    from: self.from,
                    dest: std::marker::PhantomData,
                }
            }
        }

        impl<W, T> BincodeWriterFor<T> for AsyncBincodeWriter<W, T, AsyncDestination>
        where
            T: serde::Serialize,
        {
            /// Appends `item` to the pending buffer as a 4-byte big-endian (network order) length
            /// prefix followed by its bincode encoding.
            ///
            /// On error the buffer is restored to its previous length, so a partially encoded
            /// frame is never sent by a later `poll_ready`.
            fn append(&mut self, item: T) -> Result<(), bincode::error::EncodeError> {
                let frame_start = self.buffer.len();
                // Placeholder for the length prefix; patched once the encoded size is known.
                self.buffer.extend_from_slice(&[0u8; 4]);

                let encoded = bincode::serde::encode_into_std_write(
                    &item,
                    &mut self.buffer,
                    bincode::config::standard().with_limit::<{ u32::MAX as usize }>(),
                )
                .and_then(|written| {
                    // Checked conversion rather than `as`: an oversized value must fail instead of
                    // producing a truncated (and therefore corrupt) length prefix.
                    <u32 as std::convert::TryFrom<usize>>::try_from(written).map_err(|_| {
                        bincode::error::EncodeError::Io {
                            inner: std::io::Error::new(
                                std::io::ErrorKind::InvalidData,
                                "encoded value is too large for a u32 length prefix",
                            ),
                            index: written,
                        }
                    })
                });

                match encoded {
                    Ok(len) => {
                        self.buffer[frame_start..frame_start + 4]
                            .copy_from_slice(&len.to_be_bytes());
                        Ok(())
                    }
                    Err(e) => {
                        // Drop the placeholder and any partially encoded bytes.
                        self.buffer.truncate(frame_start);
                        Err(e)
                    }
                }
            }
        }

        impl<W, T> BincodeWriterFor<T> for AsyncBincodeWriter<W, T, SyncDestination>
        where
            T: serde::Serialize,
        {
            /// Appends the plain bincode encoding of `item` (no length prefix) to the pending buffer.
            ///
            /// On error the buffer is restored to its previous length, so partially encoded bytes
            /// are never sent by a later `poll_ready`.
            fn append(&mut self, item: T) -> Result<(), bincode::error::EncodeError> {
                let start = self.buffer.len();
                match bincode::serde::encode_into_std_write(
                    &item,
                    &mut self.buffer,
                    bincode::config::standard(),
                ) {
                    Ok(_) => Ok(()),
                    Err(e) => {
                        self.buffer.truncate(start);
                        Err(e)
                    }
                }
            }
        }

        impl<W, T, D> futures_sink::Sink<T> for AsyncBincodeWriter<W, T, D>
        where
            T: serde::Serialize,
            W: $write_trait + Unpin,
            Self: BincodeWriterFor<T>,
        {
            type Error = bincode::error::EncodeError;

            /// Writes out every pending buffered byte; ready once the buffer is fully drained.
            fn poll_ready(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                // `Self: Unpin` (since `W: Unpin`), so unpin to borrow fields separately.
                let this = self.get_mut();

                // Write stuff out if we need to.
                while this.written < this.buffer.len() {
                    let n = futures_core::ready!(std::pin::Pin::new(&mut this.writer)
                        .poll_write(cx, &this.buffer[this.written..]))
                    .map_err(|inner| bincode::error::EncodeError::Io {
                        inner,
                        index: this.written,
                    })?;
                    if n == 0 {
                        // A writer that accepts zero bytes of a non-empty buffer will never make
                        // progress; retrying would loop forever.
                        return std::task::Poll::Ready(Err(bincode::error::EncodeError::Io {
                            inner: std::io::Error::from(std::io::ErrorKind::WriteZero),
                            index: this.written,
                        }));
                    }
                    this.written += n;
                }

                // Everything was sent: reset the buffer for the next item.
                this.buffer.clear();
                this.written = 0;
                std::task::Poll::Ready(Ok(()))
            }

            fn start_send(mut self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
                // NOTE: in theory we could have a short-circuit here that tries to have bincode write
                // directly into self.writer. this would be way more efficient in the common case as we
                // don't have to do the extra buffering. the idea would be to serialize first, and *if*
                // it errors, see how many bytes were written, serialize again into a Vec, and then
                // keep only the bytes following the number that were written in our buffer.
                // unfortunately, bincode will not tell us that number at the moment, and instead just
                // fail.
                self.append(item)
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                futures_core::ready!(self.as_mut().poll_ready(cx))?;
                std::pin::Pin::new(&mut self.writer)
                    .poll_flush(cx)
                    .map_err(|inner| bincode::error::EncodeError::Io {
                        inner,
                        index: self.written,
                    })
            }

            fn poll_close(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context,
            ) -> std::task::Poll<Result<(), Self::Error>> {
                // Drain any buffered bytes first. If the underlying close below returns `Pending`,
                // the buffer is already empty on re-entry, so `poll_ready` returns immediately
                // without calling `poll_write` again.
                futures_core::ready!(self.as_mut().poll_ready(cx))?;

                // `Sink::poll_close` must flush remaining output. The buffered bytes were written
                // above, and the underlying writer's close/shutdown is relied on to flush itself,
                // so no separate `poll_flush` call is made here.
                std::pin::Pin::new(&mut self.writer)
                    .$poll_close_method(cx)
                    .map_err(|inner| bincode::error::EncodeError::Io {
                        inner,
                        index: self.written,
                    })
            }
        }
    };
}

/// A marker that indicates that the wrapping type is compatible with `AsyncBincodeReader`.
///
/// Writers tagged with this destination emit each value as a 4-byte big-endian length prefix
/// followed by the bincode encoding.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AsyncDestination;

/// A marker that indicates that the wrapping type is compatible with stock `bincode` receivers.
///
/// Writers tagged with this destination emit the plain bincode encoding of each value, with no
/// length prefix.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SyncDestination;

232#[doc(hidden)]
233pub trait BincodeWriterFor<T> {
234    fn append(&mut self, item: T) -> Result<(), bincode::error::EncodeError>;
235}