sqlx_core/net/socket/
mod.rs1use std::future::Future;
2use std::io;
3use std::path::Path;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use bytes::BufMut;
8use futures_core::ready;
9
10pub use buffered::{BufferedSocket, WriteBuffer};
11
12use crate::io::ReadBuf;
13
14mod buffered;
15
16pub trait Socket: Send + Sync + Unpin + 'static {
17 fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize>;
18
19 fn try_write(&mut self, buf: &[u8]) -> io::Result<usize>;
20
21 fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
22
23 fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
24
25 fn poll_flush(&mut self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
26 Poll::Ready(Ok(()))
28 }
29
30 fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
31
32 fn read<'a, B: ReadBuf>(&'a mut self, buf: &'a mut B) -> Read<'a, Self, B>
33 where
34 Self: Sized,
35 {
36 Read { socket: self, buf }
37 }
38
39 fn write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, Self>
40 where
41 Self: Sized,
42 {
43 Write { socket: self, buf }
44 }
45
46 fn flush(&mut self) -> Flush<'_, Self>
47 where
48 Self: Sized,
49 {
50 Flush { socket: self }
51 }
52
53 fn shutdown(&mut self) -> Shutdown<'_, Self>
54 where
55 Self: Sized,
56 {
57 Shutdown { socket: self }
58 }
59}
60
61pub struct Read<'a, S: ?Sized, B> {
62 socket: &'a mut S,
63 buf: &'a mut B,
64}
65
66impl<'a, S: ?Sized, B> Future for Read<'a, S, B>
67where
68 S: Socket,
69 B: ReadBuf,
70{
71 type Output = io::Result<usize>;
72
73 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
74 let this = &mut *self;
75
76 while this.buf.has_remaining_mut() {
77 match this.socket.try_read(&mut *this.buf) {
78 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
79 ready!(this.socket.poll_read_ready(cx))?;
80 }
81 ready => return Poll::Ready(ready),
82 }
83 }
84
85 Poll::Ready(Ok(0))
86 }
87}
88
89pub struct Write<'a, S: ?Sized> {
90 socket: &'a mut S,
91 buf: &'a [u8],
92}
93
94impl<'a, S: ?Sized> Future for Write<'a, S>
95where
96 S: Socket,
97{
98 type Output = io::Result<usize>;
99
100 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
101 let this = &mut *self;
102
103 while !this.buf.is_empty() {
104 match this.socket.try_write(this.buf) {
105 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
106 ready!(this.socket.poll_write_ready(cx))?;
107 }
108 ready => return Poll::Ready(ready),
109 }
110 }
111
112 Poll::Ready(Ok(0))
113 }
114}
115
116pub struct Flush<'a, S: ?Sized> {
117 socket: &'a mut S,
118}
119
120impl<'a, S: Socket + ?Sized> Future for Flush<'a, S> {
121 type Output = io::Result<()>;
122
123 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
124 self.socket.poll_flush(cx)
125 }
126}
127
128pub struct Shutdown<'a, S: ?Sized> {
129 socket: &'a mut S,
130}
131
132impl<'a, S: ?Sized> Future for Shutdown<'a, S>
133where
134 S: Socket,
135{
136 type Output = io::Result<()>;
137
138 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139 self.socket.poll_shutdown(cx)
140 }
141}
142
143pub trait WithSocket {
144 type Output;
145
146 fn with_socket<S: Socket>(
147 self,
148 socket: S,
149 ) -> impl std::future::Future<Output = Self::Output> + Send;
150}
151
152pub struct SocketIntoBox;
153
154impl WithSocket for SocketIntoBox {
155 type Output = Box<dyn Socket>;
156
157 async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
158 Box::new(socket)
159 }
160}
161
162impl<S: Socket + ?Sized> Socket for Box<S> {
163 fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize> {
164 (**self).try_read(buf)
165 }
166
167 fn try_write(&mut self, buf: &[u8]) -> io::Result<usize> {
168 (**self).try_write(buf)
169 }
170
171 fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
172 (**self).poll_read_ready(cx)
173 }
174
175 fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
176 (**self).poll_write_ready(cx)
177 }
178
179 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
180 (**self).poll_flush(cx)
181 }
182
183 fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184 (**self).poll_shutdown(cx)
185 }
186}
187
188pub async fn connect_tcp<Ws: WithSocket>(
189 host: &str,
190 port: u16,
191 with_socket: Ws,
192) -> crate::Result<Ws::Output> {
193 let host = host.trim_matches(&['[', ']'][..]);
195
196 #[cfg(feature = "_rt-tokio")]
197 if crate::rt::rt_tokio::available() {
198 use tokio::net::TcpStream;
199
200 let stream = TcpStream::connect((host, port)).await?;
201 stream.set_nodelay(true)?;
202
203 return Ok(with_socket.with_socket(stream).await);
204 }
205
206 #[cfg(feature = "_rt-async-std")]
207 {
208 use async_io::Async;
209 use async_std::net::ToSocketAddrs;
210 use std::net::TcpStream;
211
212 let mut last_err = None;
213
214 for socket_addr in (host, port).to_socket_addrs().await? {
216 let stream = Async::<TcpStream>::connect(socket_addr)
217 .await
218 .and_then(|s| {
219 s.get_ref().set_nodelay(true)?;
220 Ok(s)
221 });
222 match stream {
223 Ok(stream) => return Ok(with_socket.with_socket(stream).await),
224 Err(e) => last_err = Some(e),
225 }
226 }
227
228 match last_err {
231 Some(err) => Err(err.into()),
232 None => Err(io::Error::new(
233 io::ErrorKind::AddrNotAvailable,
234 "Hostname did not resolve to any addresses",
235 )
236 .into()),
237 }
238 }
239
240 #[cfg(not(feature = "_rt-async-std"))]
241 {
242 crate::rt::missing_rt((host, port, with_socket))
243 }
244}
245
246pub async fn connect_uds<P: AsRef<Path>, Ws: WithSocket>(
250 path: P,
251 with_socket: Ws,
252) -> crate::Result<Ws::Output> {
253 #[cfg(unix)]
254 {
255 #[cfg(feature = "_rt-tokio")]
256 if crate::rt::rt_tokio::available() {
257 use tokio::net::UnixStream;
258
259 let stream = UnixStream::connect(path).await?;
260
261 return Ok(with_socket.with_socket(stream).await);
262 }
263
264 #[cfg(feature = "_rt-async-std")]
265 {
266 use async_io::Async;
267 use std::os::unix::net::UnixStream;
268
269 let stream = Async::<UnixStream>::connect(path).await?;
270
271 Ok(with_socket.with_socket(stream).await)
272 }
273
274 #[cfg(not(feature = "_rt-async-std"))]
275 {
276 crate::rt::missing_rt((path, with_socket))
277 }
278 }
279
280 #[cfg(not(unix))]
281 {
282 drop((path, with_socket));
283
284 Err(io::Error::new(
285 io::ErrorKind::Unsupported,
286 "Unix domain sockets are not supported on this platform",
287 )
288 .into())
289 }
290}