sqlx_core/net/socket/mod.rs

1use std::future::Future;
2use std::io;
3use std::path::Path;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use bytes::BufMut;
8use futures_core::ready;
9
10pub use buffered::{BufferedSocket, WriteBuffer};
11
12use crate::io::ReadBuf;
13
14mod buffered;
15
16pub trait Socket: Send + Sync + Unpin + 'static {
17    fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize>;
18
19    fn try_write(&mut self, buf: &[u8]) -> io::Result<usize>;
20
21    fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
22
23    fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
24
25    fn poll_flush(&mut self, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
26        // `flush()` is a no-op for TCP/UDS
27        Poll::Ready(Ok(()))
28    }
29
30    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
31
32    fn read<'a, B: ReadBuf>(&'a mut self, buf: &'a mut B) -> Read<'a, Self, B>
33    where
34        Self: Sized,
35    {
36        Read { socket: self, buf }
37    }
38
39    fn write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, Self>
40    where
41        Self: Sized,
42    {
43        Write { socket: self, buf }
44    }
45
46    fn flush(&mut self) -> Flush<'_, Self>
47    where
48        Self: Sized,
49    {
50        Flush { socket: self }
51    }
52
53    fn shutdown(&mut self) -> Shutdown<'_, Self>
54    where
55        Self: Sized,
56    {
57        Shutdown { socket: self }
58    }
59}
60
61pub struct Read<'a, S: ?Sized, B> {
62    socket: &'a mut S,
63    buf: &'a mut B,
64}
65
66impl<'a, S: ?Sized, B> Future for Read<'a, S, B>
67where
68    S: Socket,
69    B: ReadBuf,
70{
71    type Output = io::Result<usize>;
72
73    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
74        let this = &mut *self;
75
76        while this.buf.has_remaining_mut() {
77            match this.socket.try_read(&mut *this.buf) {
78                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
79                    ready!(this.socket.poll_read_ready(cx))?;
80                }
81                ready => return Poll::Ready(ready),
82            }
83        }
84
85        Poll::Ready(Ok(0))
86    }
87}
88
89pub struct Write<'a, S: ?Sized> {
90    socket: &'a mut S,
91    buf: &'a [u8],
92}
93
94impl<'a, S: ?Sized> Future for Write<'a, S>
95where
96    S: Socket,
97{
98    type Output = io::Result<usize>;
99
100    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
101        let this = &mut *self;
102
103        while !this.buf.is_empty() {
104            match this.socket.try_write(this.buf) {
105                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
106                    ready!(this.socket.poll_write_ready(cx))?;
107                }
108                ready => return Poll::Ready(ready),
109            }
110        }
111
112        Poll::Ready(Ok(0))
113    }
114}
115
116pub struct Flush<'a, S: ?Sized> {
117    socket: &'a mut S,
118}
119
120impl<'a, S: Socket + ?Sized> Future for Flush<'a, S> {
121    type Output = io::Result<()>;
122
123    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
124        self.socket.poll_flush(cx)
125    }
126}
127
128pub struct Shutdown<'a, S: ?Sized> {
129    socket: &'a mut S,
130}
131
132impl<'a, S: ?Sized> Future for Shutdown<'a, S>
133where
134    S: Socket,
135{
136    type Output = io::Result<()>;
137
138    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139        self.socket.poll_shutdown(cx)
140    }
141}
142
143pub trait WithSocket {
144    type Output;
145
146    fn with_socket<S: Socket>(
147        self,
148        socket: S,
149    ) -> impl std::future::Future<Output = Self::Output> + Send;
150}
151
152pub struct SocketIntoBox;
153
154impl WithSocket for SocketIntoBox {
155    type Output = Box<dyn Socket>;
156
157    async fn with_socket<S: Socket>(self, socket: S) -> Self::Output {
158        Box::new(socket)
159    }
160}
161
162impl<S: Socket + ?Sized> Socket for Box<S> {
163    fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize> {
164        (**self).try_read(buf)
165    }
166
167    fn try_write(&mut self, buf: &[u8]) -> io::Result<usize> {
168        (**self).try_write(buf)
169    }
170
171    fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
172        (**self).poll_read_ready(cx)
173    }
174
175    fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
176        (**self).poll_write_ready(cx)
177    }
178
179    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
180        (**self).poll_flush(cx)
181    }
182
183    fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184        (**self).poll_shutdown(cx)
185    }
186}
187
/// Opens a TCP connection to `host`:`port` and hands the connected stream to `with_socket`.
///
/// All leading and trailing `[` / `]` characters are first trimmed from `host`, because
/// IPv6 literals taken from URLs arrive still wrapped in brackets. `TCP_NODELAY` is
/// enabled on the stream before it is passed on.
///
/// Runtime selection:
/// - `_rt-tokio`: used when `crate::rt::rt_tokio::available()` returns `true`.
///   NOTE(review): presumably this means "a Tokio runtime is active"; confirm in `crate::rt`.
/// - `_rt-async-std`: otherwise, each address `host` resolves to is tried in the order
///   returned by resolution, and the first one that connects and accepts `TCP_NODELAY`
///   is used. Failing to set `TCP_NODELAY` on a stream counts as a failure for that
///   address, and the next address is tried.
/// - With neither of the above usable, this defers to `crate::rt::missing_rt`.
///
/// # Errors
///
/// Returns an error if name resolution, connecting, or setting `TCP_NODELAY` fails. On the
/// async-std path this is the error from the last address tried. If `host` resolved to no
/// addresses at all, it is an `io::ErrorKind::AddrNotAvailable` error.
pub async fn connect_tcp<Ws: WithSocket>(
    host: &str,
    port: u16,
    with_socket: Ws,
) -> crate::Result<Ws::Output> {
    // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
    let host = host.trim_matches(&['[', ']'][..]);

    // This `if` is not the function's tail expression, so it must `return` explicitly.
    #[cfg(feature = "_rt-tokio")]
    if crate::rt::rt_tokio::available() {
        use tokio::net::TcpStream;

        let stream = TcpStream::connect((host, port)).await?;
        stream.set_nodelay(true)?;

        return Ok(with_socket.with_socket(stream).await);
    }

    // Exactly one of the two cfg-gated blocks below is compiled in, and whichever one
    // survives is the function's tail expression. Keep them last and in this order.
    #[cfg(feature = "_rt-async-std")]
    {
        use async_io::Async;
        use async_std::net::ToSocketAddrs;
        use std::net::TcpStream;

        let mut last_err = None;

        // Loop through all the Socket Addresses that the hostname resolves to
        for socket_addr in (host, port).to_socket_addrs().await? {
            let stream = Async::<TcpStream>::connect(socket_addr)
                .await
                .and_then(|s| {
                    s.get_ref().set_nodelay(true)?;
                    Ok(s)
                });
            match stream {
                Ok(stream) => return Ok(with_socket.with_socket(stream).await),
                Err(e) => last_err = Some(e),
            }
        }

        // If we reach this point, it means we failed to connect to any of the addresses.
        // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address.
        match last_err {
            Some(err) => Err(err.into()),
            None => Err(io::Error::new(
                io::ErrorKind::AddrNotAvailable,
                "Hostname did not resolve to any addresses",
            )
            .into()),
        }
    }

    #[cfg(not(feature = "_rt-async-std"))]
    {
        crate::rt::missing_rt((host, port, with_socket))
    }
}
245
/// Connects to the Unix domain socket at `path` and hands the stream to `with_socket`.
///
/// On `unix` targets, Tokio's `UnixStream` is used when the `_rt-tokio` feature is enabled
/// and `crate::rt::rt_tokio::available()` returns `true`. Otherwise the `_rt-async-std`
/// path (`async_io::Async<std::os::unix::net::UnixStream>`) is used if that feature is
/// enabled; if it is not, this defers to `crate::rt::missing_rt`.
///
/// # Errors
///
/// Returns an error if connecting fails. On non-`unix` targets it always returns an
/// `io::ErrorKind::Unsupported` error, since Unix domain sockets are not supported there.
pub async fn connect_uds<P: AsRef<Path>, Ws: WithSocket>(
    path: P,
    with_socket: Ws,
) -> crate::Result<Ws::Output> {
    #[cfg(unix)]
    {
        // Not the tail expression of this block, so it must `return` explicitly.
        #[cfg(feature = "_rt-tokio")]
        if crate::rt::rt_tokio::available() {
            use tokio::net::UnixStream;

            let stream = UnixStream::connect(path).await?;

            return Ok(with_socket.with_socket(stream).await);
        }

        // Exactly one of the two cfg-gated blocks below is compiled in; the survivor is
        // the tail expression of this `#[cfg(unix)]` block. Keep them last and in order.
        #[cfg(feature = "_rt-async-std")]
        {
            use async_io::Async;
            use std::os::unix::net::UnixStream;

            let stream = Async::<UnixStream>::connect(path).await?;

            Ok(with_socket.with_socket(stream).await)
        }

        #[cfg(not(feature = "_rt-async-std"))]
        {
            crate::rt::missing_rt((path, with_socket))
        }
    }

    #[cfg(not(unix))]
    {
        // Explicitly consume the arguments, which are otherwise unused on this target.
        drop((path, with_socket));

        Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "Unix domain sockets are not supported on this platform",
        )
        .into())
    }
}