hickory_proto/tcp/tcp_client_stream.rs

// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// https://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use alloc::boxed::Box;
9use core::fmt::{self, Display};
10use core::future::Future;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13use core::time::Duration;
14use std::net::SocketAddr;
15
16use futures_util::{StreamExt, stream::Stream};
17use tracing::warn;
18
19use crate::BufDnsStreamHandle;
20use crate::error::ProtoError;
21use crate::runtime::RuntimeProvider;
22#[cfg(feature = "tokio")]
23use crate::runtime::TokioTime;
24#[cfg(feature = "tokio")]
25use crate::runtime::iocompat::AsyncIoTokioAsStd;
26use crate::tcp::{DnsTcpStream, TcpStream};
27use crate::xfer::{DnsClientStream, SerialMessage};
28
29/// Tcp client stream
30///
31/// Use with `hickory_client::client::DnsMultiplexer` impls
32#[must_use = "futures do nothing unless polled"]
33pub struct TcpClientStream<S>
34where
35    S: DnsTcpStream,
36{
37    tcp_stream: TcpStream<S>,
38}
39
40impl<S: DnsTcpStream> TcpClientStream<S> {
41    /// Create a new TcpClientStream
42    #[allow(clippy::type_complexity)]
43    pub fn new<P: RuntimeProvider<Tcp = S>>(
44        peer_addr: SocketAddr,
45        bind_addr: Option<SocketAddr>,
46        timeout: Option<Duration>,
47        provider: P,
48    ) -> (
49        Pin<Box<dyn Future<Output = Result<Self, ProtoError>> + Send + 'static>>,
50        BufDnsStreamHandle,
51    ) {
52        let (sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
53        (
54            Box::pin(async move {
55                let tcp = provider.connect_tcp(peer_addr, bind_addr, timeout).await?;
56                Ok(Self::from_stream(TcpStream::from_stream_with_receiver(
57                    tcp,
58                    peer_addr,
59                    outbound_messages,
60                )))
61            }),
62            sender,
63        )
64    }
65
66    /// Wraps the TcpStream in TcpClientStream
67    pub fn from_stream(tcp_stream: TcpStream<S>) -> Self {
68        Self { tcp_stream }
69    }
70}
71
72impl<S: DnsTcpStream> Display for TcpClientStream<S> {
73    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
74        write!(formatter, "TCP({})", self.tcp_stream.peer_addr())
75    }
76}
77
78impl<S: DnsTcpStream> DnsClientStream for TcpClientStream<S> {
79    type Time = S::Time;
80
81    fn name_server_addr(&self) -> SocketAddr {
82        self.tcp_stream.peer_addr()
83    }
84}
85
86impl<S: DnsTcpStream> Stream for TcpClientStream<S> {
87    type Item = Result<SerialMessage, ProtoError>;
88
89    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
90        let message = try_ready_stream!(self.tcp_stream.poll_next_unpin(cx));
91
92        // this is busted if the tcp connection doesn't have a peer
93        let peer = self.tcp_stream.peer_addr();
94        if message.addr() != peer {
95            // TODO: this should be an error, right?
96            warn!("{} does not match name_server: {}", message.addr(), peer)
97        }
98
99        Poll::Ready(Some(Ok(message)))
100    }
101}
102
/// Any tokio byte stream wrapped in `AsyncIoTokioAsStd` can serve as a DNS TCP
/// transport, using `TokioTime` as its timer.
#[cfg(feature = "tokio")]
impl<T> DnsTcpStream for AsyncIoTokioAsStd<T>
where
    T: Unpin + Send + Sync + tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static,
{
    type Time = TokioTime;
}
110
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::tests::tcp_client_stream_test;

    /// Installs the test subscriber, then runs the shared TCP client stream
    /// test against `addr` on the tokio runtime provider.
    async fn run_on_tokio(addr: IpAddr) {
        subscribe();
        tcp_client_stream_test(addr, TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv4() {
        run_on_tokio(IpAddr::V4(Ipv4Addr::LOCALHOST)).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv6() {
        // `Ipv6Addr::LOCALHOST` is `::1`.
        run_on_tokio(IpAddr::V6(Ipv6Addr::LOCALHOST)).await;
    }
}