async_tungstenite/
bytes.rs

1//! Provides abstractions to use `AsyncRead` and `AsyncWrite` with a `WebSocketStream`.
2
3use std::{
4    io,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use futures_core::stream::Stream;
10
11use crate::{tungstenite::Bytes, Message, WsError};
12
/// Adapts a `Sink` of WebSocket messages (such as a `WebSocketStream`) into an `AsyncWrite`.
///
/// Every write is forwarded as exactly one binary message, so many small writes produce many
/// small messages. Wrap the writer in a `BufWriter` to coalesce them.
#[cfg(feature = "futures-03-sink")]
#[derive(Debug)]
pub struct ByteWriter<S>(S);

#[cfg(feature = "futures-03-sink")]
impl<S> ByteWriter<S> {
    /// Wraps `s`, a `Sink` accepting WebSocket `Message`s, in a `ByteWriter`.
    #[inline]
    pub fn new(s: S) -> Self {
        ByteWriter(s)
    }

    /// Consumes the writer and hands back the wrapped `Sink`.
    #[inline]
    pub fn into_inner(self) -> S {
        let ByteWriter(sink) = self;
        sink
    }
}
35
/// Shared write logic for the `futures_io` and `tokio` `AsyncWrite` impls of `ByteWriter`.
///
/// Waits until the wrapped sink reports readiness. It then queues the whole of `buf` as a single
/// binary message via `start_send` and reports `buf.len()` bytes written. The message is only
/// queued here; flushing is left to `poll_flush`. An empty `buf` still queues an empty binary
/// message. Sink errors are converted with `convert_err`.
#[cfg(feature = "futures-03-sink")]
fn poll_write_helper<S>(
    mut s: Pin<&mut ByteWriter<S>>,
    cx: &mut Context<'_>,
    buf: &[u8],
) -> Poll<io::Result<usize>>
where
    S: futures_util::Sink<Message, Error = WsError> + Unpin,
{
    let sink = &mut s.0;

    // Back-pressure: do not touch `buf` until the sink can take another message.
    match Pin::new(&mut *sink).poll_ready(cx) {
        Poll::Pending => return Poll::Pending,
        Poll::Ready(Err(err)) => return Poll::Ready(Err(convert_err(err))),
        Poll::Ready(Ok(())) => {}
    }

    let written = buf.len();
    let outcome = Pin::new(sink).start_send(Message::binary(buf.to_vec()));
    Poll::Ready(outcome.map(|()| written).map_err(convert_err))
}
59
/// `futures_io` write support.
///
/// Each write becomes one binary message via `poll_write_helper`. Flushing and closing are
/// delegated to the wrapped `Sink`, with its errors mapped by `convert_err`.
#[cfg(feature = "futures-03-sink")]
impl<S> futures_io::AsyncWrite for ByteWriter<S>
where
    S: futures_util::Sink<Message, Error = WsError> + Unpin,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        poll_write_helper(self, cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // `ByteWriter<S>` is `Unpin` because `S` is, so the inner sink can be re-pinned freely.
        let sink = Pin::new(&mut self.get_mut().0);
        <S as futures_util::Sink<Message>>::poll_flush(sink, cx).map_err(convert_err)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sink = Pin::new(&mut self.get_mut().0);
        <S as futures_util::Sink<Message>>::poll_close(sink, cx).map_err(convert_err)
    }
}
81
/// `tokio` write support.
///
/// Each write becomes one binary message via `poll_write_helper`. `poll_flush` and
/// `poll_shutdown` are delegated to the wrapped sink's `poll_flush` and `poll_close`.
#[cfg(all(feature = "futures-03-sink", feature = "tokio-runtime"))]
impl<S> tokio::io::AsyncWrite for ByteWriter<S>
where
    S: futures_util::Sink<Message, Error = WsError> + Unpin,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        poll_write_helper(self, cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sink = Pin::new(&mut self.get_mut().0);
        <S as futures_util::Sink<Message>>::poll_flush(sink, cx).map_err(convert_err)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Shutting down the byte stream closes the underlying message sink.
        let sink = Pin::new(&mut self.get_mut().0);
        <S as futures_util::Sink<Message>>::poll_close(sink, cx).map_err(convert_err)
    }
}
104
105/// Treat a `WebSocketStream` as an `AsyncRead` implementation.
106///
107/// This also works with any other `Stream` of `Message`, such as a `SplitStream`.
108///
109/// Each read will only return data from one message. If you want to combine data from multiple
110/// messages into one read, consider wrapping this in a `BufReader`.
111#[derive(Debug)]
112pub struct ByteReader<S> {
113    stream: S,
114    bytes: Option<Bytes>,
115}
116
117impl<S> ByteReader<S> {
118    /// Create a new `ByteReader` from a `Stream` that returns a WebSocket `Message`
119    #[inline(always)]
120    pub fn new(stream: S) -> Self {
121        Self {
122            stream,
123            bytes: None,
124        }
125    }
126}
127
128fn poll_read_helper<S>(
129    mut s: Pin<&mut ByteReader<S>>,
130    cx: &mut Context<'_>,
131    buf_len: usize,
132) -> Poll<io::Result<Option<Bytes>>>
133where
134    S: Stream<Item = Result<Message, WsError>> + Unpin,
135{
136    Poll::Ready(Ok(Some(match s.bytes {
137        None => match Pin::new(&mut s.stream).poll_next(cx) {
138            Poll::Pending => return Poll::Pending,
139            Poll::Ready(None) => return Poll::Ready(Ok(None)),
140            Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))),
141            Poll::Ready(Some(Ok(msg))) => {
142                let bytes = msg.into_data();
143                if bytes.len() > buf_len {
144                    s.bytes.insert(bytes).split_to(buf_len)
145                } else {
146                    bytes
147                }
148            }
149        },
150        Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len),
151        Some(ref mut bytes) => {
152            let bytes = bytes.clone();
153            s.bytes = None;
154            bytes
155        }
156    })))
157}
158
159impl<S> futures_io::AsyncRead for ByteReader<S>
160where
161    S: Stream<Item = Result<Message, WsError>> + Unpin,
162{
163    fn poll_read(
164        self: Pin<&mut Self>,
165        cx: &mut Context<'_>,
166        buf: &mut [u8],
167    ) -> Poll<io::Result<usize>> {
168        poll_read_helper(self, cx, buf.len()).map_ok(|bytes| {
169            bytes.map_or(0, |bytes| {
170                buf[..bytes.len()].copy_from_slice(&bytes);
171                bytes.len()
172            })
173        })
174    }
175}
176
/// `tokio` read support.
///
/// Appends the next chunk produced by `poll_read_helper` to `buf`, sized to its remaining
/// capacity. At end of stream nothing is appended.
#[cfg(feature = "tokio-runtime")]
impl<S> tokio::io::AsyncRead for ByteReader<S>
where
    S: Stream<Item = Result<Message, WsError>> + Unpin,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf,
    ) -> Poll<io::Result<()>> {
        let room = buf.remaining();
        poll_read_helper(self, cx, room).map_ok(|chunk| match chunk {
            Some(data) => buf.put_slice(&data),
            // Leaving `buf` unfilled is how tokio's `AsyncRead` signals end of file.
            None => {}
        })
    }
}
194
195fn convert_err(e: WsError) -> io::Error {
196    match e {
197        WsError::Io(io) => io,
198        _ => io::Error::new(io::ErrorKind::Other, e),
199    }
200}