libp2p_tcp/provider/tokio.rs

// Copyright 2020 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    io, net,
23    pin::Pin,
24    task::{Context, Poll},
25};
26
27use futures::{
28    future::{BoxFuture, FutureExt},
29    prelude::*,
30};
31
32use super::{Incoming, Provider};
33
34/// A TCP [`Transport`](libp2p_core::Transport) that works with the `tokio` ecosystem.
35///
36/// # Example
37///
38/// ```rust
39/// # use libp2p_tcp as tcp;
40/// # use libp2p_core::{Transport, transport::ListenerId};
41/// # use futures::future;
42/// # use std::pin::Pin;
43/// #
44/// # #[tokio::main]
45/// # async fn main() {
46/// let mut transport = tcp::tokio::Transport::new(tcp::Config::default());
47/// let id = transport
48///     .listen_on(ListenerId::next(), "/ip4/127.0.0.1/tcp/0".parse().unwrap())
49///     .unwrap();
50///
51/// let addr = future::poll_fn(|cx| Pin::new(&mut transport).poll(cx))
52///     .await
53///     .into_new_address()
54///     .unwrap();
55///
56/// println!("Listening on {addr}");
57/// # }
58/// ```
59pub type Transport = crate::Transport<Tcp>;
60
/// Uninhabited marker type that selects the `tokio`-based stream, listener and
/// interface-watcher types through its [`Provider`] implementation.
///
/// It has no values; it exists only to parameterise the generic TCP transport.
#[derive(Copy, Clone, Debug)]
#[doc(hidden)]
pub enum Tcp {}
64
65impl Provider for Tcp {
66    type Stream = TcpStream;
67    type Listener = tokio::net::TcpListener;
68    type IfWatcher = if_watch::tokio::IfWatcher;
69
70    fn new_if_watcher() -> io::Result<Self::IfWatcher> {
71        Self::IfWatcher::new()
72    }
73
74    fn addrs(if_watcher: &Self::IfWatcher) -> Vec<if_watch::IpNet> {
75        if_watcher.iter().copied().collect()
76    }
77
78    fn new_listener(l: net::TcpListener) -> io::Result<Self::Listener> {
79        tokio::net::TcpListener::try_from(l)
80    }
81
82    fn new_stream(s: net::TcpStream) -> BoxFuture<'static, io::Result<Self::Stream>> {
83        async move {
84            // Taken from [`tokio::net::TcpStream::connect_mio`].
85
86            let stream = tokio::net::TcpStream::try_from(s)?;
87
88            // Once we've connected, wait for the stream to be writable as
89            // that's when the actual connection has been initiated. Once we're
90            // writable we check for `take_socket_error` to see if the connect
91            // actually hit an error or not.
92            //
93            // If all that succeeded then we ship everything on up.
94            stream.writable().await?;
95
96            if let Some(e) = stream.take_error()? {
97                return Err(e);
98            }
99
100            Ok(TcpStream(stream))
101        }
102        .boxed()
103    }
104
105    fn poll_accept(
106        l: &mut Self::Listener,
107        cx: &mut Context<'_>,
108    ) -> Poll<io::Result<Incoming<Self::Stream>>> {
109        let (stream, remote_addr) = match l.poll_accept(cx) {
110            Poll::Pending => return Poll::Pending,
111            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
112            Poll::Ready(Ok((stream, remote_addr))) => (stream, remote_addr),
113        };
114
115        let local_addr = stream.local_addr()?;
116        let stream = TcpStream(stream);
117
118        Poll::Ready(Ok(Incoming {
119            stream,
120            local_addr,
121            remote_addr,
122        }))
123    }
124}
125
126/// A [`tokio::net::TcpStream`] that implements [`AsyncRead`] and [`AsyncWrite`].
127#[derive(Debug)]
128pub struct TcpStream(pub tokio::net::TcpStream);
129
130impl From<TcpStream> for tokio::net::TcpStream {
131    fn from(t: TcpStream) -> tokio::net::TcpStream {
132        t.0
133    }
134}
135
136impl AsyncRead for TcpStream {
137    fn poll_read(
138        mut self: Pin<&mut Self>,
139        cx: &mut Context,
140        buf: &mut [u8],
141    ) -> Poll<Result<usize, io::Error>> {
142        let mut read_buf = tokio::io::ReadBuf::new(buf);
143        futures::ready!(tokio::io::AsyncRead::poll_read(
144            Pin::new(&mut self.0),
145            cx,
146            &mut read_buf
147        ))?;
148        Poll::Ready(Ok(read_buf.filled().len()))
149    }
150}
151
152impl AsyncWrite for TcpStream {
153    fn poll_write(
154        mut self: Pin<&mut Self>,
155        cx: &mut Context,
156        buf: &[u8],
157    ) -> Poll<Result<usize, io::Error>> {
158        tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
159    }
160
161    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
162        tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
163    }
164
165    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
166        tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
167    }
168
169    fn poll_write_vectored(
170        mut self: Pin<&mut Self>,
171        cx: &mut Context<'_>,
172        bufs: &[io::IoSlice<'_>],
173    ) -> Poll<io::Result<usize>> {
174        tokio::io::AsyncWrite::poll_write_vectored(Pin::new(&mut self.0), cx, bufs)
175    }
176}