hickory_proto/xfer/
mod.rs

//! DNS high level transit implementations.
//!
//! The primary types of interest in this module are `DnsMultiplexer` and `DnsHandle`. `DnsMultiplexer` can be thought of as the state machine responsible for sending and receiving DNS messages. `DnsHandle` is the type given to API users of the `hickory-proto` library to send messages into the `DnsMultiplexer` for delivery. Finally, there is the `DnsRequest` type, which allows the delivery of messages via a `DnsMultiplexer` to be customized through `DnsRequestOptions`.
//!
//! TODO: this module needs some serious refactoring and normalization.
6
7#[cfg(feature = "std")]
8use core::fmt::Display;
9use core::fmt::{self, Debug};
10use core::future::Future;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13use core::time::Duration;
14#[cfg(feature = "std")]
15use std::net::SocketAddr;
16
17#[cfg(feature = "std")]
18use futures_channel::mpsc;
19#[cfg(feature = "std")]
20use futures_channel::oneshot;
21use futures_util::ready;
22#[cfg(feature = "std")]
23use futures_util::stream::{Fuse, Peekable};
24use futures_util::stream::{Stream, StreamExt};
25#[cfg(feature = "serde")]
26use serde::{Deserialize, Serialize};
27#[cfg(feature = "std")]
28use tracing::{debug, warn};
29
30use crate::error::{ProtoError, ProtoErrorKind};
31#[cfg(feature = "std")]
32use crate::runtime::Time;
33
34#[cfg(feature = "std")]
35mod dns_exchange;
36pub mod dns_handle;
37#[cfg(feature = "std")]
38pub mod dns_multiplexer;
39pub mod dns_request;
40pub mod dns_response;
41pub mod retry_dns_handle;
42mod serial_message;
43
44#[cfg(feature = "std")]
45pub use self::dns_exchange::{
46    DnsExchange, DnsExchangeBackground, DnsExchangeConnect, DnsExchangeSend,
47};
48pub use self::dns_handle::{DnsHandle, DnsStreamHandle};
49#[cfg(feature = "std")]
50pub use self::dns_multiplexer::{DnsMultiplexer, DnsMultiplexerConnect};
51pub use self::dns_request::{DnsRequest, DnsRequestOptions};
52pub use self::dns_response::DnsResponse;
53#[cfg(feature = "std")]
54pub use self::dns_response::DnsResponseStream;
55pub use self::retry_dns_handle::RetryDnsHandle;
56pub use self::serial_message::SerialMessage;
57
/// Discards the result of a channel send, logging any failure.
///
/// A failure caused by the receiving side being disconnected is only logged at debug
/// level. Any other failure is logged as a warning, since the waiting side may never be
/// notified.
#[cfg(feature = "std")]
fn ignore_send<M, T>(result: Result<M, mpsc::TrySendError<T>>) {
    match result {
        Ok(_) => {}
        Err(error) if error.is_disconnected() => {
            debug!("ignoring send error on disconnected stream");
        }
        Err(error) => {
            warn!("error notifying wait, possible future leak: {:?}", error);
        }
    }
}
70
/// A non-multiplexed stream of serialized DNS messages.
///
/// Implementors are streams of `Result<SerialMessage, ProtoError>` that are also
/// `Display` and `Send`.
#[cfg(feature = "std")]
pub trait DnsClientStream:
    Stream<Item = Result<SerialMessage, ProtoError>> + Display + Send
{
    /// The runtime `Time` implementation used by this stream
    type Time: Time;

    /// The remote name server address
    fn name_server_addr(&self) -> SocketAddr;
}
82
/// Receiver half of the channel created by `BufDnsStreamHandle::new`: a fused,
/// peekable stream of `SerialMessage`s.
#[cfg(feature = "std")]
pub type StreamReceiver = Peekable<Fuse<mpsc::Receiver<SerialMessage>>>;
86
// Buffer size passed to `mpsc::channel` when a `BufDnsStreamHandle` is created.
#[cfg(feature = "std")]
const CHANNEL_BUFFER_SIZE: usize = 32;
89
/// A buffering stream handle bound to a `SocketAddr`.
///
/// This stream handle ensures that all messages sent via this handle have `remote_addr`
/// set as the destination for the packet.
#[derive(Clone)]
#[cfg(feature = "std")]
pub struct BufDnsStreamHandle {
    /// Destination address stamped onto every message sent through this handle.
    remote_addr: SocketAddr,
    /// Sending half of the channel whose receiver is returned by `BufDnsStreamHandle::new`.
    sender: mpsc::Sender<SerialMessage>,
}
99
#[cfg(feature = "std")]
impl BufDnsStreamHandle {
    /// Constructs a new buffered stream handle, used for sending data to the DNS peer.
    ///
    /// # Arguments
    ///
    /// * `remote_addr` - the address of the remote DNS system (client or server); every
    ///   message sent through the returned handle is addressed to it
    ///
    /// # Returns
    ///
    /// A tuple of the new handle and the fused, peekable receiving end of the channel it
    /// feeds. The channel is created with a buffer of `CHANNEL_BUFFER_SIZE`.
    pub fn new(remote_addr: SocketAddr) -> (Self, StreamReceiver) {
        let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
        let receiver = receiver.fuse().peekable();

        let this = Self {
            remote_addr,
            sender,
        };

        (this, receiver)
    }

    /// Associates a different remote address for any responses.
    ///
    /// The returned handle sends into the same channel as `self`, but stamps `remote_addr`
    /// onto its messages. This is mainly useful in server use cases where the incoming
    /// address is only known after receiving a packet.
    pub fn with_remote_addr(&self, remote_addr: SocketAddr) -> Self {
        Self {
            remote_addr,
            sender: self.sender.clone(),
        }
    }
}
130
#[cfg(feature = "std")]
impl DnsStreamHandle for BufDnsStreamHandle {
    /// Re-addresses `buffer` to this handle's `remote_addr` and enqueues it without waiting.
    ///
    /// Any failure to place the message on the channel is reported as a `ProtoError`.
    fn send(&mut self, buffer: SerialMessage) -> Result<(), ProtoError> {
        // Keep the payload bytes and discard whatever address the caller supplied.
        let payload = buffer.into_parts().0;
        let readdressed = SerialMessage::new(payload, self.remote_addr);

        self.sender
            .try_send(readdressed)
            .map_err(|e| ProtoError::from(format!("mpsc::SendError {e}")))
    }
}
140
/// Types that implement this are capable of sending a serialized DNS message on a stream
///
/// The underlying Stream implementation should yield `Some(Ok(()))` whenever it is ready to send a message,
///   `Poll::Pending` if it is not ready to send a message, and `Some(Err(_))` or `None` in the case that the stream is
///   done, and should be shutdown.
#[cfg(feature = "std")]
pub trait DnsRequestSender: Stream<Item = Result<(), ProtoError>> + Send + Unpin + 'static {
    /// Send a message, and return a stream of responses
    ///
    /// # Return
    ///
    /// A stream which will resolve to the `DnsResponse`s for `request`
    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream;

    /// Allows the upstream user to inform the underlying stream that it should shutdown.
    ///
    /// After this is called, the next time the stream is polled it would be correct to return `Poll::Ready(None)`. This is not required though: if there are, say, outstanding requests that are not yet complete, then it would be correct to first wait for those results.
    fn shutdown(&mut self);

    /// Returns true if the stream has been shutdown with `shutdown`
    fn is_shutdown(&self) -> bool;
}
163
/// A cloneable `DnsHandle` that forwards each request over a channel as a
/// `OneshotDnsRequest`, to be answered by whatever owns the receiving end.
#[derive(Clone)]
#[cfg(feature = "std")]
pub struct BufDnsRequestStreamHandle {
    /// Sending half of the request channel.
    sender: mpsc::Sender<OneshotDnsRequest>,
}
170
// Evaluates a `Result`, yielding the `Ok` value or returning early from the enclosing
// function with `DnsResponseReceiver::Err(Some(..))` built from the error.
//
// The enclosing function must return `DnsResponseReceiver`, and the error must be
// convertible into `ProtoError`. A trailing comma is accepted.
#[cfg(feature = "std")]
macro_rules! try_oneshot {
    ($expr:expr) => {{
        use core::result::Result;

        match $expr {
            Result::Ok(val) => val,
            Result::Err(err) => return DnsResponseReceiver::Err(Some(ProtoError::from(err))),
        }
    }};
    // Delegate to the primary arm. `?` cannot be used here because the enclosing
    // function returns `DnsResponseReceiver`, not a `Result`.
    ($expr:expr,) => {
        try_oneshot!($expr)
    };
}
185
#[cfg(feature = "std")]
impl DnsHandle for BufDnsRequestStreamHandle {
    type Response = DnsResponseReceiver;

    /// Enqueues `request` on the channel without waiting.
    ///
    /// On success, the returned receiver resolves once the response stream is delivered
    /// over the request's oneshot channel. If the request cannot be enqueued, the returned
    /// receiver yields a single `ProtoErrorKind::Busy` error and then ends.
    fn send<R: Into<DnsRequest>>(&self, request: R) -> Self::Response {
        let dns_request: DnsRequest = request.into();
        debug!(
            "enqueueing message:{}:{:?}",
            dns_request.op_code(),
            dns_request.queries()
        );

        let (outgoing, response_rx) = OneshotDnsRequest::oneshot(dns_request);

        // `try_send` needs `&mut`, so send through a clone of the shared sender.
        let mut sender = self.sender.clone();
        let outcome = sender.try_send(outgoing);
        match outcome {
            Ok(()) => DnsResponseReceiver::Receiver(response_rx),
            Err(_) => {
                debug!("unable to enqueue message");
                DnsResponseReceiver::Err(Some(ProtoError::from(ProtoErrorKind::Busy)))
            }
        }
    }
}
209
// TODO: this future should return the origin message in the response on errors
/// A `DnsRequest` paired with the oneshot sender on which its response stream is delivered.
#[cfg(feature = "std")]
pub struct OneshotDnsRequest {
    /// The request to be sent.
    dns_request: DnsRequest,
    /// Delivers the `DnsResponseStream` for `dns_request` back to the requester.
    sender_for_response: oneshot::Sender<DnsResponseStream>,
}
217
#[cfg(feature = "std")]
impl OneshotDnsRequest {
    /// Wraps `dns_request` together with a fresh oneshot channel for its response.
    ///
    /// Returns the request, ready to be enqueued, and the receiver on which the
    /// `DnsResponseStream` for it will be delivered.
    // The whole impl is already gated on `std`, so no extra `cfg` is needed here.
    fn oneshot(dns_request: DnsRequest) -> (Self, oneshot::Receiver<DnsResponseStream>) {
        let (sender_for_response, receiver) = oneshot::channel();

        (
            Self {
                dns_request,
                sender_for_response,
            },
            receiver,
        )
    }

    /// Splits this into the request itself and the handle used to deliver its response.
    fn into_parts(self) -> (DnsRequest, OneshotDnsResponse) {
        (
            self.dns_request,
            OneshotDnsResponse(self.sender_for_response),
        )
    }
}
240
// The response-delivery half of a `OneshotDnsRequest`, produced by `OneshotDnsRequest::into_parts`.
#[cfg(feature = "std")]
struct OneshotDnsResponse(oneshot::Sender<DnsResponseStream>);
243
#[cfg(feature = "std")]
impl OneshotDnsResponse {
    /// Delivers `serial_response` to the waiting receiver.
    ///
    /// This forwards to `oneshot::Sender::send`: if delivery fails, the stream is handed
    /// back in `Err`.
    fn send_response(self, serial_response: DnsResponseStream) -> Result<(), DnsResponseStream> {
        let Self(response_tx) = self;
        response_tx.send(serial_response)
    }
}
250
/// A `Stream` that first waits on a `oneshot::Receiver` for a `DnsResponseStream`, then
/// yields the items of that inner stream.
#[cfg(feature = "std")]
pub enum DnsResponseReceiver {
    /// Waiting for the response stream to arrive on the oneshot channel
    Receiver(oneshot::Receiver<DnsResponseStream>),
    /// The response stream has arrived; its items are forwarded directly
    Received(DnsResponseStream),
    /// Error during the send operation; a pending error is yielded once, after which the
    /// stream ends
    Err(Option<ProtoError>),
}
261
#[cfg(feature = "std")]
impl Stream for DnsResponseReceiver {
    type Item = Result<DnsResponse, ProtoError>;

    /// Drives the receiver through its states.
    ///
    /// While waiting, the oneshot channel is polled. Once the response stream arrives, its
    /// items are forwarded. If the sending side is dropped before delivering a stream, a
    /// single "receiver was canceled" error is yielded and the stream then ends. This
    /// avoids re-polling the already-completed oneshot receiver on every later call.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            *self = match &mut *self {
                Self::Receiver(receiver) => match ready!(Pin::new(receiver).poll(cx)) {
                    Ok(stream) => Self::Received(stream),
                    // Move to the `Err` state so the error is reported exactly once and
                    // the next poll terminates the stream with `None`.
                    Err(_) => Self::Err(Some(ProtoError::from("receiver was canceled"))),
                },
                Self::Received(stream) => {
                    return stream.poll_next_unpin(cx);
                }
                Self::Err(err) => return Poll::Ready(err.take().map(Err)),
            };
        }
    }
}
286
287/// Helper trait to convert a Stream of dns response into a Future
288pub trait FirstAnswer<T, E: From<ProtoError>>: Stream<Item = Result<T, E>> + Unpin + Sized {
289    /// Convert a Stream of dns response into a Future yielding the first answer,
290    /// discarding others if any.
291    fn first_answer(self) -> FirstAnswerFuture<Self> {
292        FirstAnswerFuture { stream: Some(self) }
293    }
294}
295
296impl<E, S, T> FirstAnswer<T, E> for S
297where
298    S: Stream<Item = Result<T, E>> + Unpin + Sized,
299    E: From<ProtoError>,
300{
301}
302
/// Future returned by [`FirstAnswer::first_answer`], resolving to the first item of a stream.
///
/// The wrapped stream is dropped as soon as that item is produced.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct FirstAnswerFuture<S> {
    /// The stream being polled; `None` once the future has completed.
    stream: Option<S>,
}
309
310impl<E, S: Stream<Item = Result<T, E>> + Unpin, T> Future for FirstAnswerFuture<S>
311where
312    S: Stream<Item = Result<T, E>> + Unpin + Sized,
313    E: From<ProtoError>,
314{
315    type Output = S::Item;
316
317    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
318        let s = self
319            .stream
320            .as_mut()
321            .expect("polling FirstAnswerFuture twice");
322        let item = match ready!(s.poll_next_unpin(cx)) {
323            Some(r) => r,
324            None => Err(ProtoError::from(ProtoErrorKind::Timeout).into()),
325        };
326        self.stream.take();
327        Poll::Ready(item)
328    }
329}
330
/// The transport protocol used to communicate with a name server.
///
/// With the `serde` feature enabled, variants are (de)serialized by their lowercase names
/// (e.g. `"udp"`, `"h3"`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "lowercase")
)]
#[non_exhaustive]
pub enum Protocol {
    /// UDP, the traditional DNS transport; this is generally the correct choice
    Udp,
    /// TCP can be used for large queries, but not all name servers support it
    Tcp,
    /// TLS for DNS over TLS
    #[cfg(feature = "__tls")]
    Tls,
    /// HTTPS for DNS over HTTPS
    #[cfg(feature = "__https")]
    Https,
    /// QUIC for DNS over QUIC
    #[cfg(feature = "__quic")]
    Quic,
    /// HTTP/3 for DNS over HTTP/3
    #[cfg(feature = "__h3")]
    H3,
}
357
358impl fmt::Display for Protocol {
359    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360        let protocol = match self {
361            Self::Udp => "udp",
362            Self::Tcp => "tcp",
363            #[cfg(feature = "__tls")]
364            Self::Tls => "tls",
365            #[cfg(feature = "__https")]
366            Self::Https => "https",
367            #[cfg(feature = "__quic")]
368            Self::Quic => "quic",
369            #[cfg(feature = "__h3")]
370            Self::H3 => "h3",
371        };
372
373        f.write_str(protocol)
374    }
375}
376
377impl Protocol {
378    /// Returns true if this is a datagram oriented protocol, e.g. UDP
379    pub fn is_datagram(self) -> bool {
380        match self {
381            Self::Udp => true,
382            Self::Tcp => false,
383            #[cfg(feature = "__tls")]
384            Self::Tls => false,
385            #[cfg(feature = "__https")]
386            Self::Https => false,
387            // TODO: if you squint, this is true...
388            #[cfg(feature = "__quic")]
389            Self::Quic => true,
390            #[cfg(feature = "__h3")]
391            Self::H3 => true,
392        }
393    }
394
395    /// Returns true if this is a stream oriented protocol, e.g. TCP
396    pub fn is_stream(self) -> bool {
397        !self.is_datagram()
398    }
399
400    /// Is this an encrypted protocol, i.e. TLS or HTTPS
401    pub fn is_encrypted(self) -> bool {
402        match self {
403            Self::Udp => false,
404            Self::Tcp => false,
405            #[cfg(feature = "__tls")]
406            Self::Tls => true,
407            #[cfg(feature = "__https")]
408            Self::Https => true,
409            #[cfg(feature = "__quic")]
410            Self::Quic => true,
411            #[cfg(feature = "__h3")]
412            Self::H3 => true,
413        }
414    }
415}
416
417impl Default for Protocol {
418    /// Default protocol should be UDP, which is supported by all DNS servers
419    fn default() -> Self {
420        Self::Udp
421    }
422}
423
/// A five-second timeout, by its name intended for establishing connections.
// Whether this is referenced depends on which transport features are enabled.
#[allow(unused)]
pub(crate) const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);