async_tungstenite/
bytes.rs1use std::{
4 io,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use futures_core::stream::Stream;
10
11use crate::{tungstenite::Bytes, Message, WsError};
12
#[cfg(feature = "futures-03-sink")]
/// Adapter that lets a websocket message sink be used as a byte-oriented writer.
///
/// The wrapped `S` is expected to be a `futures_util::Sink<Message, Error = WsError>`
/// (that is the bound required by the `AsyncWrite` impls). Each completed write sends
/// the whole input buffer as one binary websocket message. Flush and close/shutdown are
/// forwarded to the sink, and websocket errors are converted to `io::Error`.
///
/// `futures_io::AsyncWrite` is implemented when the `futures-03-sink` feature is
/// enabled; `tokio::io::AsyncWrite` additionally requires the `tokio-runtime` feature.
#[derive(Debug)]
pub struct ByteWriter<S>(S);
20
#[cfg(feature = "futures-03-sink")]
impl<S> ByteWriter<S> {
    /// Wraps the sink `s` so it can be driven through the byte-oriented `AsyncWrite` API.
    #[inline(always)]
    pub fn new(s: S) -> Self {
        ByteWriter(s)
    }

    /// Unwraps the adapter and returns the sink it was built from.
    #[inline(always)]
    pub fn into_inner(self) -> S {
        let ByteWriter(sink) = self;
        sink
    }
}
35
#[cfg(feature = "futures-03-sink")]
/// Shared body of the `poll_write` implementations for `ByteWriter`.
///
/// Waits until the wrapped sink reports readiness, then enqueues a copy of `buf` as a
/// single binary websocket message via `start_send`. On success, reports that all of
/// `buf` was written. Sink errors are converted with `convert_err`.
fn poll_write_helper<S>(
    s: Pin<&mut ByteWriter<S>>,
    cx: &mut Context<'_>,
    buf: &[u8],
) -> Poll<io::Result<usize>>
where
    S: futures_util::Sink<Message, Error = WsError> + Unpin,
{
    // `ByteWriter<S>` is `Unpin` because `S: Unpin`, so the wrapper can be unpinned.
    let mut sink = Pin::new(&mut s.get_mut().0);

    match sink.as_mut().poll_ready(cx) {
        Poll::Pending => Poll::Pending,
        Poll::Ready(Err(err)) => Poll::Ready(Err(convert_err(err))),
        Poll::Ready(Ok(())) => {
            let written = buf.len();
            // The payload must be owned, because the message outlives the borrowed `buf`.
            let message = Message::binary(buf.to_owned());
            Poll::Ready(sink.start_send(message).map(|()| written).map_err(convert_err))
        }
    }
}
59
#[cfg(feature = "futures-03-sink")]
impl<S> futures_io::AsyncWrite for ByteWriter<S>
where
    S: futures_util::Sink<Message, Error = WsError> + Unpin,
{
    /// Sends `buf` as one binary websocket message once the sink is ready.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        poll_write_helper(self, cx, buf)
    }

    /// Forwards to the wrapped sink's `poll_flush`, converting any error.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sink = Pin::new(&mut self.get_mut().0);
        sink.poll_flush(cx).map_err(convert_err)
    }

    /// Forwards to the wrapped sink's `poll_close`, converting any error.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sink = Pin::new(&mut self.get_mut().0);
        sink.poll_close(cx).map_err(convert_err)
    }
}
81
#[cfg(feature = "futures-03-sink")]
#[cfg(feature = "tokio-runtime")]
impl<S> tokio::io::AsyncWrite for ByteWriter<S>
where
    S: futures_util::Sink<Message, Error = WsError> + Unpin,
{
    /// Sends `buf` as one binary websocket message once the sink is ready.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        poll_write_helper(self, cx, buf)
    }

    /// Forwards to the wrapped sink's `poll_flush`, converting any error.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sink = Pin::new(&mut self.get_mut().0);
        sink.poll_flush(cx).map_err(convert_err)
    }

    /// Shutting down the writer closes the wrapped sink (`poll_close`).
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let sink = Pin::new(&mut self.get_mut().0);
        sink.poll_close(cx).map_err(convert_err)
    }
}
104
/// Adapter that exposes a stream of websocket messages as a byte-oriented reader.
///
/// Implements `futures_io::AsyncRead`, plus `tokio::io::AsyncRead` when the
/// `tokio-runtime` feature is enabled. Each message is converted to bytes with
/// `Message::into_data` and handed out in chunks no larger than the caller's buffer.
/// A single read never combines data from two different messages.
///
/// NOTE(review): every message is passed through `into_data`, regardless of type.
/// Text and control-frame payloads presumably end up in the byte stream too; confirm
/// against tungstenite's `Message::into_data`.
#[derive(Debug)]
pub struct ByteReader<S> {
    /// Source of websocket messages.
    stream: S,
    /// Unread tail of the most recent message, kept when it did not fit into the
    /// caller's buffer. `None` when nothing is pending.
    bytes: Option<Bytes>,
}
116
117impl<S> ByteReader<S> {
118 #[inline(always)]
120 pub fn new(stream: S) -> Self {
121 Self {
122 stream,
123 bytes: None,
124 }
125 }
126}
127
128fn poll_read_helper<S>(
129 mut s: Pin<&mut ByteReader<S>>,
130 cx: &mut Context<'_>,
131 buf_len: usize,
132) -> Poll<io::Result<Option<Bytes>>>
133where
134 S: Stream<Item = Result<Message, WsError>> + Unpin,
135{
136 Poll::Ready(Ok(Some(match s.bytes {
137 None => match Pin::new(&mut s.stream).poll_next(cx) {
138 Poll::Pending => return Poll::Pending,
139 Poll::Ready(None) => return Poll::Ready(Ok(None)),
140 Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(convert_err(e))),
141 Poll::Ready(Some(Ok(msg))) => {
142 let bytes = msg.into_data();
143 if bytes.len() > buf_len {
144 s.bytes.insert(bytes).split_to(buf_len)
145 } else {
146 bytes
147 }
148 }
149 },
150 Some(ref mut bytes) if bytes.len() > buf_len => bytes.split_to(buf_len),
151 Some(ref mut bytes) => {
152 let bytes = bytes.clone();
153 s.bytes = None;
154 bytes
155 }
156 })))
157}
158
159impl<S> futures_io::AsyncRead for ByteReader<S>
160where
161 S: Stream<Item = Result<Message, WsError>> + Unpin,
162{
163 fn poll_read(
164 self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 buf: &mut [u8],
167 ) -> Poll<io::Result<usize>> {
168 poll_read_helper(self, cx, buf.len()).map_ok(|bytes| {
169 bytes.map_or(0, |bytes| {
170 buf[..bytes.len()].copy_from_slice(&bytes);
171 bytes.len()
172 })
173 })
174 }
175}
176
#[cfg(feature = "tokio-runtime")]
impl<S> tokio::io::AsyncRead for ByteReader<S>
where
    S: Stream<Item = Result<Message, WsError>> + Unpin,
{
    /// Appends the next chunk of message data to the filled part of `buf`.
    ///
    /// The chunk is limited to `buf.remaining()` bytes. When the helper reports the end
    /// of the stream, nothing is appended.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf,
    ) -> Poll<io::Result<()>> {
        let capacity = buf.remaining();
        match poll_read_helper(self, cx, capacity) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok(chunk)) => {
                if let Some(data) = chunk {
                    buf.put_slice(&data);
                }
                Poll::Ready(Ok(()))
            }
        }
    }
}
194
195fn convert_err(e: WsError) -> io::Error {
196 match e {
197 WsError::Io(io) => io,
198 _ => io::Error::new(io::ErrorKind::Other, e),
199 }
200}