compio_net/resolve/
mod.rs

1cfg_if::cfg_if! {
2    if #[cfg(windows)] {
3        #[path = "windows.rs"]
4        mod sys;
5    } else if #[cfg(unix)] {
6        #[path = "unix.rs"]
7        mod sys;
8    }
9}
10
11use std::{
12    future::Future,
13    io,
14    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
15};
16
17use compio_buf::{BufResult, buf_try};
18use either::Either;
19pub use sys::resolve_sock_addrs;
20
/// A trait for objects which can be converted or resolved to one or more
/// [`SocketAddr`] values asynchronously.
///
/// This is the async counterpart of [`std::net::ToSocketAddrs`].
///
/// # Cancel safety
///
/// All implementations of [`ToSocketAddrsAsync`] in this crate are safe to
/// cancel by dropping the future. The glibc implementation may, however, leak
/// its control block if the future is dropped before the lookup completes.
// `async fn` in a public trait means callers cannot name or add bounds (such
// as `Send`) to the returned future; that trade-off is accepted here, so the
// lint is silenced deliberately.
#[allow(async_fn_in_trait)]
pub trait ToSocketAddrsAsync {
    /// Iterator over the resolved addresses.
    ///
    /// See [`std::net::ToSocketAddrs::Iter`].
    type Iter: Iterator<Item = SocketAddr>;

    /// Converts or resolves `self` into an iterator of [`SocketAddr`] values.
    ///
    /// See [`std::net::ToSocketAddrs::to_socket_addrs`].
    ///
    /// # Errors
    ///
    /// Returns an error if `self` cannot be parsed or resolved.
    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter>;
}
39
40macro_rules! impl_to_socket_addrs_async {
41    ($($t:ty),* $(,)?) => {
42        $(
43            impl ToSocketAddrsAsync for $t {
44                type Iter = std::iter::Once<SocketAddr>;
45
46                async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
47                    Ok(std::iter::once(SocketAddr::from(*self)))
48                }
49            }
50        )*
51    }
52}
53
54impl_to_socket_addrs_async![
55    SocketAddr,
56    SocketAddrV4,
57    SocketAddrV6,
58    (IpAddr, u16),
59    (Ipv4Addr, u16),
60    (Ipv6Addr, u16),
61];
62
63impl ToSocketAddrsAsync for (&str, u16) {
64    type Iter = Either<std::iter::Once<SocketAddr>, std::vec::IntoIter<SocketAddr>>;
65
66    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
67        let (host, port) = self;
68        if let Ok(addr) = host.parse::<Ipv4Addr>() {
69            return Ok(Either::Left(std::iter::once(SocketAddr::from((
70                addr, *port,
71            )))));
72        }
73        if let Ok(addr) = host.parse::<Ipv6Addr>() {
74            return Ok(Either::Left(std::iter::once(SocketAddr::from((
75                addr, *port,
76            )))));
77        }
78
79        resolve_sock_addrs(host, *port).await.map(Either::Right)
80    }
81}
82
83impl ToSocketAddrsAsync for (String, u16) {
84    type Iter = <(&'static str, u16) as ToSocketAddrsAsync>::Iter;
85
86    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
87        (&*self.0, self.1).to_socket_addrs_async().await
88    }
89}
90
91impl ToSocketAddrsAsync for str {
92    type Iter = <(&'static str, u16) as ToSocketAddrsAsync>::Iter;
93
94    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
95        if let Ok(addr) = self.parse::<SocketAddr>() {
96            return Ok(Either::Left(std::iter::once(addr)));
97        }
98
99        let (host, port_str) = self
100            .rsplit_once(':')
101            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid socket address"))?;
102        let port: u16 = port_str
103            .parse()
104            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid port value"))?;
105        (host, port).to_socket_addrs_async().await
106    }
107}
108
109impl ToSocketAddrsAsync for String {
110    type Iter = <(&'static str, u16) as ToSocketAddrsAsync>::Iter;
111
112    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
113        self.as_str().to_socket_addrs_async().await
114    }
115}
116
117impl<'a> ToSocketAddrsAsync for &'a [SocketAddr] {
118    type Iter = std::iter::Copied<std::slice::Iter<'a, SocketAddr>>;
119
120    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
121        Ok(self.iter().copied())
122    }
123}
124
125impl<T: ToSocketAddrsAsync + ?Sized> ToSocketAddrsAsync for &T {
126    type Iter = T::Iter;
127
128    async fn to_socket_addrs_async(&self) -> io::Result<Self::Iter> {
129        (**self).to_socket_addrs_async().await
130    }
131}
132
133pub async fn each_addr<T, F: Future<Output = io::Result<T>>>(
134    addr: impl ToSocketAddrsAsync,
135    f: impl Fn(SocketAddr) -> F,
136) -> io::Result<T> {
137    let addrs = addr.to_socket_addrs_async().await?;
138    let mut last_err = None;
139    for addr in addrs {
140        match f(addr).await {
141            Ok(l) => return Ok(l),
142            Err(e) => last_err = Some(e),
143        }
144    }
145    Err(last_err.unwrap_or_else(|| {
146        io::Error::new(
147            io::ErrorKind::InvalidInput,
148            "could not resolve to any addresses",
149        )
150    }))
151}
152
153pub async fn first_addr_buf<T, B, F: Future<Output = BufResult<T, B>>>(
154    addr: impl ToSocketAddrsAsync,
155    buffer: B,
156    f: impl FnOnce(SocketAddr, B) -> F,
157) -> BufResult<T, B> {
158    let (mut addrs, buffer) = buf_try!(addr.to_socket_addrs_async().await, buffer);
159    if let Some(addr) = addrs.next() {
160        let (res, buffer) = buf_try!(f(addr, buffer).await);
161        BufResult(Ok(res), buffer)
162    } else {
163        BufResult(
164            Err(io::Error::new(
165                io::ErrorKind::InvalidInput,
166                "could not operate on first address",
167            )),
168            buffer,
169        )
170    }
171}