hickory_proto/tcp/
tcp_client_stream.rs

// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// https://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use std::fmt::{self, Display};
9use std::io;
10use std::net::SocketAddr;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::time::Duration;
14
15#[cfg(feature = "tokio-runtime")]
16use async_trait::async_trait;
17use futures_util::{future::Future, stream::Stream, StreamExt, TryFutureExt};
18use tracing::warn;
19
20use crate::error::ProtoError;
21#[cfg(feature = "tokio-runtime")]
22use crate::iocompat::AsyncIoTokioAsStd;
23use crate::tcp::{Connect, DnsTcpStream, TcpStream};
24use crate::xfer::{DnsClientStream, SerialMessage};
25use crate::BufDnsStreamHandle;
26#[cfg(feature = "tokio-runtime")]
27use crate::TokioTime;
28
29/// Tcp client stream
30///
31/// Use with `hickory_client::client::DnsMultiplexer` impls
32#[must_use = "futures do nothing unless polled"]
33pub struct TcpClientStream<S>
34where
35    S: DnsTcpStream,
36{
37    tcp_stream: TcpStream<S>,
38}
39
40impl<S: Connect> TcpClientStream<S> {
41    /// Constructs a new TcpStream for a client to the specified SocketAddr.
42    ///
43    /// Defaults to a 5 second timeout
44    ///
45    /// # Arguments
46    ///
47    /// * `name_server` - the IP and Port of the DNS server to connect to
48    #[allow(clippy::new_ret_no_self)]
49    pub fn new(name_server: SocketAddr) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
50        Self::with_timeout(name_server, Duration::from_secs(5))
51    }
52
53    /// Constructs a new TcpStream for a client to the specified SocketAddr.
54    ///
55    /// # Arguments
56    ///
57    /// * `name_server` - the IP and Port of the DNS server to connect to
58    /// * `timeout` - connection timeout
59    pub fn with_timeout(
60        name_server: SocketAddr,
61        timeout: Duration,
62    ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
63        Self::with_bind_addr_and_timeout(name_server, None, timeout)
64    }
65
66    /// Constructs a new TcpStream for a client to the specified SocketAddr.
67    ///
68    /// # Arguments
69    ///
70    /// * `name_server` - the IP and Port of the DNS server to connect to
71    /// * `bind_addr` - the IP and port to connect from
72    /// * `timeout` - connection timeout
73    #[allow(clippy::new_ret_no_self)]
74    pub fn with_bind_addr_and_timeout(
75        name_server: SocketAddr,
76        bind_addr: Option<SocketAddr>,
77        timeout: Duration,
78    ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
79        let (stream_future, sender) =
80            TcpStream::<S>::with_bind_addr_and_timeout(name_server, bind_addr, timeout);
81
82        let new_future = Box::pin(
83            stream_future
84                .map_ok(move |tcp_stream| Self { tcp_stream })
85                .map_err(ProtoError::from),
86        );
87
88        (TcpClientConnect(new_future), sender)
89    }
90}
91
92impl<S: DnsTcpStream> TcpClientStream<S> {
93    /// Wraps the TcpStream in TcpClientStream
94    pub fn from_stream(tcp_stream: TcpStream<S>) -> Self {
95        Self { tcp_stream }
96    }
97
98    /// Constructs a new TcpStream for a client to the specified SocketAddr.
99    ///
100    /// # Arguments
101    ///
102    /// * `future` - a future of a connecting tcp
103    /// * `name_server` - the IP and Port of the DNS server to connect to
104    /// * `timeout` - connection timeout
105    #[allow(clippy::new_ret_no_self)]
106    pub fn with_future<F: Future<Output = io::Result<S>> + Send + 'static>(
107        future: F,
108        name_server: SocketAddr,
109        timeout: Duration,
110    ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
111        let (stream_future, sender) = TcpStream::<S>::with_future(future, name_server, timeout);
112
113        let new_future = Box::pin(
114            stream_future
115                .map_ok(move |tcp_stream| Self { tcp_stream })
116                .map_err(ProtoError::from),
117        );
118
119        (TcpClientConnect(new_future), sender)
120    }
121}
122
123impl<S: DnsTcpStream> Display for TcpClientStream<S> {
124    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
125        write!(formatter, "TCP({})", self.tcp_stream.peer_addr())
126    }
127}
128
129impl<S: DnsTcpStream> DnsClientStream for TcpClientStream<S> {
130    type Time = S::Time;
131
132    fn name_server_addr(&self) -> SocketAddr {
133        self.tcp_stream.peer_addr()
134    }
135}
136
137impl<S: DnsTcpStream> Stream for TcpClientStream<S> {
138    type Item = Result<SerialMessage, ProtoError>;
139
140    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
141        let message = try_ready_stream!(self.tcp_stream.poll_next_unpin(cx));
142
143        // this is busted if the tcp connection doesn't have a peer
144        let peer = self.tcp_stream.peer_addr();
145        if message.addr() != peer {
146            // TODO: this should be an error, right?
147            warn!("{} does not match name_server: {}", message.addr(), peer)
148        }
149
150        Poll::Ready(Some(Ok(message)))
151    }
152}
153
154// TODO: create unboxed future for the TCP Stream
155/// A future that resolves to an TcpClientStream
156pub struct TcpClientConnect<S: DnsTcpStream>(
157    Pin<Box<dyn Future<Output = Result<TcpClientStream<S>, ProtoError>> + Send + 'static>>,
158);
159
160impl<S: DnsTcpStream> Future for TcpClientConnect<S> {
161    type Output = Result<TcpClientStream<S>, ProtoError>;
162
163    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164        self.0.as_mut().poll(cx)
165    }
166}
167
168#[cfg(feature = "tokio-runtime")]
169use tokio::net::TcpStream as TokioTcpStream;
170
/// Lets any tokio read/write type wrapped in `AsyncIoTokioAsStd` act as a
/// DNS TCP stream, with `TokioTime` as its timer.
#[cfg(feature = "tokio-runtime")]
impl<T> DnsTcpStream for AsyncIoTokioAsStd<T>
where
    // `Sized` is implied for a generic parameter, so it is not spelled out.
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin + 'static,
{
    type Time = TokioTime;
}
178
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl Connect for AsyncIoTokioAsStd<TokioTcpStream> {
    /// Connects to `addr` via `super::tokio::connect_with_bind`, passing
    /// `bind_addr` through, and wraps the resulting stream in
    /// `AsyncIoTokioAsStd`. I/O errors are returned unchanged.
    async fn connect_with_bind(
        addr: SocketAddr,
        bind_addr: Option<SocketAddr>,
    ) -> io::Result<Self> {
        let stream = super::tokio::connect_with_bind(&addr, &bind_addr).await?;
        Ok(AsyncIoTokioAsStd(stream))
    }
}
191
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use super::AsyncIoTokioAsStd;
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use crate::tests::tcp_client_stream_test;

    /// Runs the shared TCP client stream test against `server_ip` on a fresh
    /// tokio runtime.
    fn run_on_tokio(server_ip: IpAddr) {
        let runtime = Runtime::new().expect("failed to create tokio runtime");
        tcp_client_stream_test::<AsyncIoTokioAsStd<TokioTcpStream>, Runtime>(server_ip, runtime)
    }

    #[test]
    fn test_tcp_stream_ipv4() {
        // 127.0.0.1
        run_on_tokio(IpAddr::V4(Ipv4Addr::LOCALHOST))
    }

    #[test]
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    fn test_tcp_stream_ipv6() {
        // ::1
        run_on_tokio(IpAddr::V6(Ipv6Addr::LOCALHOST))
    }
}
219        )
220    }
221}