compio_net/
tcp.rs

1use std::{future::Future, io, net::SocketAddr};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
4use compio_driver::impl_raw_fd;
5use compio_io::{AsyncRead, AsyncWrite};
6use socket2::{Protocol, SockAddr, Socket as Socket2, Type};
7
8use crate::{
9    OwnedReadHalf, OwnedWriteHalf, PollFd, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf,
10};
11
12/// A TCP socket server, listening for connections.
13///
14/// You can accept a new connection by using the
15/// [`accept`](`TcpListener::accept`) method.
16///
17/// # Examples
18///
19/// ```
20/// use std::net::SocketAddr;
21///
22/// use compio_io::{AsyncReadExt, AsyncWriteExt};
23/// use compio_net::{TcpListener, TcpStream};
24/// use socket2::SockAddr;
25///
26/// # compio_runtime::Runtime::new().unwrap().block_on(async move {
27/// let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
28///
29/// let addr = listener.local_addr().unwrap();
30///
31/// let tx_fut = TcpStream::connect(&addr);
32///
33/// let rx_fut = listener.accept();
34///
35/// let (mut tx, (mut rx, _)) = futures_util::try_join!(tx_fut, rx_fut).unwrap();
36///
37/// tx.write_all("test").await.0.unwrap();
38///
39/// let (_, buf) = rx.read_exact(Vec::with_capacity(4)).await.unwrap();
40///
41/// assert_eq!(buf, b"test");
42/// # });
43/// ```
44#[derive(Debug, Clone)]
45pub struct TcpListener {
46    inner: Socket,
47}
48
49impl TcpListener {
50    /// Creates a new `TcpListener`, which will be bound to the specified
51    /// address.
52    ///
53    /// The returned listener is ready for accepting connections.
54    ///
55    /// Binding with a port number of 0 will request that the OS assigns a port
56    /// to this listener.
57    pub async fn bind(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
58        super::each_addr(addr, |addr| async move {
59            let sa = SockAddr::from(addr);
60            let socket = Socket::new(sa.domain(), Type::STREAM, Some(Protocol::TCP)).await?;
61            socket.socket.set_reuse_address(true)?;
62            socket.socket.bind(&sa)?;
63            socket.listen(128)?;
64            Ok(Self { inner: socket })
65        })
66        .await
67    }
68
69    /// Creates new TcpListener from a [`std::net::TcpListener`].
70    pub fn from_std(stream: std::net::TcpListener) -> io::Result<Self> {
71        Ok(Self {
72            inner: Socket::from_socket2(Socket2::from(stream))?,
73        })
74    }
75
76    /// Close the socket. If the returned future is dropped before polling, the
77    /// socket won't be closed.
78    pub fn close(self) -> impl Future<Output = io::Result<()>> {
79        self.inner.close()
80    }
81
82    /// Accepts a new incoming connection from this listener.
83    ///
84    /// This function will yield once a new TCP connection is established. When
85    /// established, the corresponding [`TcpStream`] and the remote peer's
86    /// address will be returned.
87    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
88        let (socket, addr) = self.inner.accept().await?;
89        let stream = TcpStream { inner: socket };
90        Ok((stream, addr.as_socket().expect("should be SocketAddr")))
91    }
92
93    /// Returns the local address that this listener is bound to.
94    ///
95    /// This can be useful, for example, when binding to port 0 to
96    /// figure out which port was actually bound.
97    ///
98    /// # Examples
99    ///
100    /// ```
101    /// use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
102    ///
103    /// use compio_net::TcpListener;
104    /// use socket2::SockAddr;
105    ///
106    /// # compio_runtime::Runtime::new().unwrap().block_on(async {
107    /// let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
108    ///
109    /// let addr = listener.local_addr().expect("Couldn't get local address");
110    /// assert_eq!(
111    ///     addr,
112    ///     SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080))
113    /// );
114    /// # });
115    /// ```
116    pub fn local_addr(&self) -> io::Result<SocketAddr> {
117        self.inner
118            .local_addr()
119            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
120    }
121}
122
// NOTE(review): `impl_raw_fd!` is imported from `compio_driver`; presumably it
// implements the raw fd/handle conversion traits for `TcpListener` by
// forwarding to the `socket` field (a `socket2::Socket`) of `inner` — confirm
// against the macro definition.
impl_raw_fd!(TcpListener, socket2::Socket, inner, socket);
124
125/// A TCP stream between a local and a remote socket.
126///
127/// A TCP stream can either be created by connecting to an endpoint, via the
128/// `connect` method, or by accepting a connection from a listener.
129///
130/// # Examples
131///
132/// ```no_run
133/// use std::net::SocketAddr;
134///
135/// use compio_io::AsyncWrite;
136/// use compio_net::TcpStream;
137///
138/// # compio_runtime::Runtime::new().unwrap().block_on(async {
139/// // Connect to a peer
140/// let mut stream = TcpStream::connect("127.0.0.1:8080").await.unwrap();
141///
142/// // Write some data.
143/// stream.write("hello world!").await.unwrap();
144/// # })
145/// ```
146#[derive(Debug, Clone)]
147pub struct TcpStream {
148    inner: Socket,
149}
150
151impl TcpStream {
152    /// Opens a TCP connection to a remote host.
153    pub async fn connect(addr: impl ToSocketAddrsAsync) -> io::Result<Self> {
154        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
155
156        super::each_addr(addr, |addr| async move {
157            let addr2 = SockAddr::from(addr);
158            let socket = if cfg!(windows) {
159                let bind_addr = if addr.is_ipv4() {
160                    SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
161                } else if addr.is_ipv6() {
162                    SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
163                } else {
164                    return Err(io::Error::new(
165                        io::ErrorKind::AddrNotAvailable,
166                        "Unsupported address domain.",
167                    ));
168                };
169                Socket::bind(&bind_addr, Type::STREAM, Some(Protocol::TCP)).await?
170            } else {
171                Socket::new(addr2.domain(), Type::STREAM, Some(Protocol::TCP)).await?
172            };
173            socket.connect_async(&addr2).await?;
174            Ok(Self { inner: socket })
175        })
176        .await
177    }
178
179    /// Creates new TcpStream from a [`std::net::TcpStream`].
180    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
181        Ok(Self {
182            inner: Socket::from_socket2(Socket2::from(stream))?,
183        })
184    }
185
186    /// Close the socket. If the returned future is dropped before polling, the
187    /// socket won't be closed.
188    pub fn close(self) -> impl Future<Output = io::Result<()>> {
189        self.inner.close()
190    }
191
192    /// Returns the socket address of the remote peer of this TCP connection.
193    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
194        self.inner
195            .peer_addr()
196            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
197    }
198
199    /// Returns the socket address of the local half of this TCP connection.
200    pub fn local_addr(&self) -> io::Result<SocketAddr> {
201        self.inner
202            .local_addr()
203            .map(|addr| addr.as_socket().expect("should be SocketAddr"))
204    }
205
206    /// Splits a [`TcpStream`] into a read half and a write half, which can be
207    /// used to read and write the stream concurrently.
208    ///
209    /// This method is more efficient than
210    /// [`into_split`](TcpStream::into_split), but the halves cannot
211    /// be moved into independently spawned tasks.
212    pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
213        crate::split(self)
214    }
215
216    /// Splits a [`TcpStream`] into a read half and a write half, which can be
217    /// used to read and write the stream concurrently.
218    ///
219    /// Unlike [`split`](TcpStream::split), the owned halves can be moved to
220    /// separate tasks, however this comes at the cost of a heap allocation.
221    pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
222        crate::into_split(self)
223    }
224
225    /// Create [`PollFd`] from inner socket.
226    pub fn to_poll_fd(&self) -> io::Result<PollFd<Socket2>> {
227        self.inner.to_poll_fd()
228    }
229
230    /// Create [`PollFd`] from inner socket.
231    pub fn into_poll_fd(self) -> io::Result<PollFd<Socket2>> {
232        self.inner.into_poll_fd()
233    }
234}
235
236impl AsyncRead for TcpStream {
237    #[inline]
238    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
239        (&*self).read(buf).await
240    }
241
242    #[inline]
243    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
244        (&*self).read_vectored(buf).await
245    }
246}
247
248impl AsyncRead for &TcpStream {
249    #[inline]
250    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
251        self.inner.recv(buf).await
252    }
253
254    #[inline]
255    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
256        self.inner.recv_vectored(buf).await
257    }
258}
259
260impl AsyncWrite for TcpStream {
261    #[inline]
262    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
263        (&*self).write(buf).await
264    }
265
266    #[inline]
267    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
268        (&*self).write_vectored(buf).await
269    }
270
271    #[inline]
272    async fn flush(&mut self) -> io::Result<()> {
273        (&*self).flush().await
274    }
275
276    #[inline]
277    async fn shutdown(&mut self) -> io::Result<()> {
278        (&*self).shutdown().await
279    }
280}
281
282impl AsyncWrite for &TcpStream {
283    #[inline]
284    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
285        self.inner.send(buf).await
286    }
287
288    #[inline]
289    async fn write_vectored<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<usize, T> {
290        self.inner.send_vectored(buf).await
291    }
292
293    #[inline]
294    async fn flush(&mut self) -> io::Result<()> {
295        Ok(())
296    }
297
298    #[inline]
299    async fn shutdown(&mut self) -> io::Result<()> {
300        self.inner.shutdown().await
301    }
302}
303
// NOTE(review): `impl_raw_fd!` is imported from `compio_driver`; presumably it
// implements the raw fd/handle conversion traits for `TcpStream` by
// forwarding to the `socket` field (a `socket2::Socket`) of `inner` — confirm
// against the macro definition.
impl_raw_fd!(TcpStream, socket2::Socket, inner, socket);