// hickory_proto/tcp/tcp_client_stream.rs
use alloc::boxed::Box;
9use core::fmt::{self, Display};
10use core::future::Future;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13use core::time::Duration;
14use std::net::SocketAddr;
15
16use futures_util::{StreamExt, stream::Stream};
17use tracing::warn;
18
19use crate::BufDnsStreamHandle;
20use crate::error::ProtoError;
21use crate::runtime::RuntimeProvider;
22#[cfg(feature = "tokio")]
23use crate::runtime::TokioTime;
24#[cfg(feature = "tokio")]
25use crate::runtime::iocompat::AsyncIoTokioAsStd;
26use crate::tcp::{DnsTcpStream, TcpStream};
27use crate::xfer::{DnsClientStream, SerialMessage};
28
29#[must_use = "futures do nothing unless polled"]
33pub struct TcpClientStream<S>
34where
35 S: DnsTcpStream,
36{
37 tcp_stream: TcpStream<S>,
38}
39
40impl<S: DnsTcpStream> TcpClientStream<S> {
41 #[allow(clippy::type_complexity)]
43 pub fn new<P: RuntimeProvider<Tcp = S>>(
44 peer_addr: SocketAddr,
45 bind_addr: Option<SocketAddr>,
46 timeout: Option<Duration>,
47 provider: P,
48 ) -> (
49 Pin<Box<dyn Future<Output = Result<Self, ProtoError>> + Send + 'static>>,
50 BufDnsStreamHandle,
51 ) {
52 let (sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
53 (
54 Box::pin(async move {
55 let tcp = provider.connect_tcp(peer_addr, bind_addr, timeout).await?;
56 Ok(Self::from_stream(TcpStream::from_stream_with_receiver(
57 tcp,
58 peer_addr,
59 outbound_messages,
60 )))
61 }),
62 sender,
63 )
64 }
65
66 pub fn from_stream(tcp_stream: TcpStream<S>) -> Self {
68 Self { tcp_stream }
69 }
70}
71
72impl<S: DnsTcpStream> Display for TcpClientStream<S> {
73 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
74 write!(formatter, "TCP({})", self.tcp_stream.peer_addr())
75 }
76}
77
78impl<S: DnsTcpStream> DnsClientStream for TcpClientStream<S> {
79 type Time = S::Time;
80
81 fn name_server_addr(&self) -> SocketAddr {
82 self.tcp_stream.peer_addr()
83 }
84}
85
86impl<S: DnsTcpStream> Stream for TcpClientStream<S> {
87 type Item = Result<SerialMessage, ProtoError>;
88
89 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
90 let message = try_ready_stream!(self.tcp_stream.poll_next_unpin(cx));
91
92 let peer = self.tcp_stream.peer_addr();
94 if message.addr() != peer {
95 warn!("{} does not match name_server: {}", message.addr(), peer)
97 }
98
99 Poll::Ready(Some(Ok(message)))
100 }
101}
102
/// Lets any tokio I/O type wrapped in [`AsyncIoTokioAsStd`] serve as a
/// [`DnsTcpStream`], with [`TokioTime`] as its associated timer type.
#[cfg(feature = "tokio")]
impl<T> DnsTcpStream for AsyncIoTokioAsStd<T>
where
    T: tokio::io::AsyncRead,
    T: tokio::io::AsyncWrite,
    T: Unpin + Send + Sync + Sized + 'static,
{
    type Time = TokioTime;
}
110
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::tests::tcp_client_stream_test;

    /// Runs the shared TCP client stream test against the IPv4 loopback address.
    #[tokio::test]
    async fn test_tcp_stream_ipv4() {
        subscribe();
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        tcp_client_stream_test(loopback, TokioRuntimeProvider::new()).await;
    }

    /// Runs the shared TCP client stream test against the IPv6 loopback address.
    #[tokio::test]
    async fn test_tcp_stream_ipv6() {
        subscribe();
        // `Ipv6Addr::LOCALHOST` is `::1`, i.e. `Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)`.
        let loopback = IpAddr::V6(Ipv6Addr::LOCALHOST);
        tcp_client_stream_test(loopback, TokioRuntimeProvider::new()).await;
    }
}