hickory_proto/tcp/
tcp_client_stream.rs1use std::fmt::{self, Display};
9use std::io;
10use std::net::SocketAddr;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use std::time::Duration;
14
15#[cfg(feature = "tokio-runtime")]
16use async_trait::async_trait;
17use futures_util::{future::Future, stream::Stream, StreamExt, TryFutureExt};
18use tracing::warn;
19
20use crate::error::ProtoError;
21#[cfg(feature = "tokio-runtime")]
22use crate::iocompat::AsyncIoTokioAsStd;
23use crate::tcp::{Connect, DnsTcpStream, TcpStream};
24use crate::xfer::{DnsClientStream, SerialMessage};
25use crate::BufDnsStreamHandle;
26#[cfg(feature = "tokio-runtime")]
27use crate::TokioTime;
28
29#[must_use = "futures do nothing unless polled"]
33pub struct TcpClientStream<S>
34where
35 S: DnsTcpStream,
36{
37 tcp_stream: TcpStream<S>,
38}
39
40impl<S: Connect> TcpClientStream<S> {
41 #[allow(clippy::new_ret_no_self)]
49 pub fn new(name_server: SocketAddr) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
50 Self::with_timeout(name_server, Duration::from_secs(5))
51 }
52
53 pub fn with_timeout(
60 name_server: SocketAddr,
61 timeout: Duration,
62 ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
63 Self::with_bind_addr_and_timeout(name_server, None, timeout)
64 }
65
66 #[allow(clippy::new_ret_no_self)]
74 pub fn with_bind_addr_and_timeout(
75 name_server: SocketAddr,
76 bind_addr: Option<SocketAddr>,
77 timeout: Duration,
78 ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
79 let (stream_future, sender) =
80 TcpStream::<S>::with_bind_addr_and_timeout(name_server, bind_addr, timeout);
81
82 let new_future = Box::pin(
83 stream_future
84 .map_ok(move |tcp_stream| Self { tcp_stream })
85 .map_err(ProtoError::from),
86 );
87
88 (TcpClientConnect(new_future), sender)
89 }
90}
91
92impl<S: DnsTcpStream> TcpClientStream<S> {
93 pub fn from_stream(tcp_stream: TcpStream<S>) -> Self {
95 Self { tcp_stream }
96 }
97
98 #[allow(clippy::new_ret_no_self)]
106 pub fn with_future<F: Future<Output = io::Result<S>> + Send + 'static>(
107 future: F,
108 name_server: SocketAddr,
109 timeout: Duration,
110 ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
111 let (stream_future, sender) = TcpStream::<S>::with_future(future, name_server, timeout);
112
113 let new_future = Box::pin(
114 stream_future
115 .map_ok(move |tcp_stream| Self { tcp_stream })
116 .map_err(ProtoError::from),
117 );
118
119 (TcpClientConnect(new_future), sender)
120 }
121}
122
123impl<S: DnsTcpStream> Display for TcpClientStream<S> {
124 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
125 write!(formatter, "TCP({})", self.tcp_stream.peer_addr())
126 }
127}
128
129impl<S: DnsTcpStream> DnsClientStream for TcpClientStream<S> {
130 type Time = S::Time;
131
132 fn name_server_addr(&self) -> SocketAddr {
133 self.tcp_stream.peer_addr()
134 }
135}
136
137impl<S: DnsTcpStream> Stream for TcpClientStream<S> {
138 type Item = Result<SerialMessage, ProtoError>;
139
140 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
141 let message = try_ready_stream!(self.tcp_stream.poll_next_unpin(cx));
142
143 let peer = self.tcp_stream.peer_addr();
145 if message.addr() != peer {
146 warn!("{} does not match name_server: {}", message.addr(), peer)
148 }
149
150 Poll::Ready(Some(Ok(message)))
151 }
152}
153
154pub struct TcpClientConnect<S: DnsTcpStream>(
157 Pin<Box<dyn Future<Output = Result<TcpClientStream<S>, ProtoError>> + Send + 'static>>,
158);
159
160impl<S: DnsTcpStream> Future for TcpClientConnect<S> {
161 type Output = Result<TcpClientStream<S>, ProtoError>;
162
163 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164 self.0.as_mut().poll(cx)
165 }
166}
167
168#[cfg(feature = "tokio-runtime")]
169use tokio::net::TcpStream as TokioTcpStream;
170
/// Lets any tokio I/O object that meets these bounds, wrapped in [`AsyncIoTokioAsStd`],
/// act as a DNS TCP stream driven by [`TokioTime`].
#[cfg(feature = "tokio-runtime")]
impl<T> DnsTcpStream for AsyncIoTokioAsStd<T>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite,
    T: Unpin + Send + Sync + Sized + 'static,
{
    type Time = TokioTime;
}
178
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl Connect for AsyncIoTokioAsStd<TokioTcpStream> {
    /// Connects via `super::tokio::connect_with_bind` and wraps the resulting tokio
    /// stream in [`AsyncIoTokioAsStd`]. Any I/O error is returned unchanged.
    async fn connect_with_bind(
        addr: SocketAddr,
        bind_addr: Option<SocketAddr>,
    ) -> io::Result<Self> {
        let tcp = super::tokio::connect_with_bind(&addr, &bind_addr).await?;
        Ok(AsyncIoTokioAsStd(tcp))
    }
}
191
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use super::AsyncIoTokioAsStd;
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use crate::tests::tcp_client_stream_test;

    /// Runs the shared TCP client-stream test against `addr` on a fresh tokio runtime.
    fn run_on_tokio(addr: IpAddr) {
        let io_loop = Runtime::new().expect("failed to create tokio runtime");
        tcp_client_stream_test::<AsyncIoTokioAsStd<TokioTcpStream>, Runtime>(addr, io_loop)
    }

    #[test]
    fn test_tcp_stream_ipv4() {
        run_on_tokio(IpAddr::V4(Ipv4Addr::LOCALHOST))
    }

    // NOTE(review): the IPv6 case is compiled out on Linux. The reason is not visible
    // here; presumably the CI environment lacks IPv6 loopback, so confirm before enabling.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_tcp_stream_ipv6() {
        run_on_tokio(IpAddr::V6(Ipv6Addr::LOCALHOST))
    }
}