hickory_proto/xfer/dns_exchange.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This module contains all the types for demuxing DNS oriented streams.
9
10use core::future::Future;
11use core::marker::PhantomData;
12use core::pin::Pin;
13use core::task::{Context, Poll};
14
15use futures_channel::mpsc;
16use futures_util::future::FutureExt;
17use futures_util::stream::{Peekable, Stream, StreamExt};
18use tracing::debug;
19
20use crate::error::*;
21#[cfg(feature = "std")]
22use crate::runtime::Time;
23use crate::xfer::DnsResponseReceiver;
24#[cfg(any(feature = "std", feature = "no-std-rand"))]
25use crate::xfer::dns_handle::DnsHandle;
26use crate::xfer::{
27    BufDnsRequestStreamHandle, CHANNEL_BUFFER_SIZE, DnsRequest, DnsRequestSender, DnsResponse,
28    OneshotDnsRequest,
29};
30
31/// This is a generic Exchange implemented over multiplexed DNS connection providers.
32///
33/// The underlying `DnsRequestSender` is expected to multiplex any I/O connections. DnsExchange assumes that the underlying stream is responsible for this.
34#[must_use = "futures do nothing unless polled"]
35pub struct DnsExchange {
36    sender: BufDnsRequestStreamHandle,
37}
38
39impl DnsExchange {
40    /// Initializes a TcpStream with an existing tcp::TcpStream.
41    ///
42    /// This is intended for use with a TcpListener and Incoming.
43    ///
44    /// # Arguments
45    ///
46    /// * `stream` - the established IO stream for communication
47    pub fn from_stream<S, TE>(stream: S) -> (Self, DnsExchangeBackground<S, TE>)
48    where
49        S: DnsRequestSender + 'static + Send + Unpin,
50    {
51        let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
52        let message_sender = BufDnsRequestStreamHandle { sender };
53
54        Self::from_stream_with_receiver(stream, outbound_messages, message_sender)
55    }
56
57    /// Wraps a stream where a sender and receiver have already been established
58    pub fn from_stream_with_receiver<S, TE>(
59        stream: S,
60        receiver: mpsc::Receiver<OneshotDnsRequest>,
61        sender: BufDnsRequestStreamHandle,
62    ) -> (Self, DnsExchangeBackground<S, TE>)
63    where
64        S: DnsRequestSender + 'static + Send + Unpin,
65    {
66        let background = DnsExchangeBackground {
67            io_stream: stream,
68            outbound_messages: receiver.peekable(),
69            marker: PhantomData,
70        };
71
72        (Self { sender }, background)
73    }
74
75    /// Returns a future, which itself wraps a future which is awaiting connection.
76    ///
77    /// The connect_future should be lazy.
78    #[cfg(feature = "std")]
79    pub fn connect<F, S, TE>(connect_future: F) -> DnsExchangeConnect<F, S, TE>
80    where
81        F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
82        S: DnsRequestSender + 'static + Send + Unpin,
83        TE: Time + Unpin,
84    {
85        let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
86        let message_sender = BufDnsRequestStreamHandle { sender };
87
88        DnsExchangeConnect::connect(connect_future, outbound_messages, message_sender)
89    }
90
91    /// Returns a future that returns an error immediately.
92    pub fn error<F, S, TE>(error: ProtoError) -> DnsExchangeConnect<F, S, TE>
93    where
94        F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
95        S: DnsRequestSender + 'static + Send + Unpin,
96        TE: Time + Unpin,
97    {
98        DnsExchangeConnect(DnsExchangeConnectInner::Error(error))
99    }
100}
101
102impl Clone for DnsExchange {
103    fn clone(&self) -> Self {
104        Self {
105            sender: self.sender.clone(),
106        }
107    }
108}
109
#[cfg(any(feature = "std", feature = "no-std-rand"))]
impl DnsHandle for DnsExchange {
    type Response = DnsExchangeSend;

    /// Queues `request` for the background task and returns the stream of its responses.
    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
        let result = self.sender.send(request);
        // TODO: this extra handle shouldn't be necessary. The background uses the presence of
        // request senders to track whether anyone is still using it, and callers generally drop
        // the exchange right after `send`. Holding a handle in the response stream keeps at least
        // one sender alive while the response is outstanding.
        let _sender = self.sender.clone();
        DnsExchangeSend { result, _sender }
    }
}
121
122/// A Stream that will resolve to Responses after sending the request
123#[must_use = "futures do nothing unless polled"]
124pub struct DnsExchangeSend {
125    result: DnsResponseReceiver,
126    _sender: BufDnsRequestStreamHandle,
127}
128
129impl Stream for DnsExchangeSend {
130    type Item = Result<DnsResponse, ProtoError>;
131
132    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133        // as long as there is no result, poll the exchange
134        self.result.poll_next_unpin(cx)
135    }
136}
137
138/// This background future is responsible for driving all network operations for the DNS protocol.
139///
140/// It must be spawned before any DNS messages are sent.
141#[must_use = "futures do nothing unless polled"]
142pub struct DnsExchangeBackground<S, TE>
143where
144    S: DnsRequestSender + 'static + Send + Unpin,
145{
146    io_stream: S,
147    outbound_messages: Peekable<mpsc::Receiver<OneshotDnsRequest>>,
148    marker: PhantomData<TE>,
149}
150
151impl<S, TE> DnsExchangeBackground<S, TE>
152where
153    S: DnsRequestSender + 'static + Send + Unpin,
154{
155    fn pollable_split(&mut self) -> (&mut S, &mut Peekable<mpsc::Receiver<OneshotDnsRequest>>) {
156        (&mut self.io_stream, &mut self.outbound_messages)
157    }
158}
159
160impl<S, TE> Future for DnsExchangeBackground<S, TE>
161where
162    S: DnsRequestSender + 'static + Send + Unpin,
163    TE: Time + Unpin,
164{
165    type Output = Result<(), ProtoError>;
166
167    #[allow(clippy::unused_unit)]
168    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
169        let (io_stream, outbound_messages) = self.pollable_split();
170        let mut io_stream = Pin::new(io_stream);
171        let mut outbound_messages = Pin::new(outbound_messages);
172
173        // this will not accept incoming data while there is data to send
174        //  makes this self throttling.
175        loop {
176            // poll the underlying stream, to drive it...
177            match io_stream.as_mut().poll_next(cx) {
178                // The stream is ready
179                Poll::Ready(Some(Ok(()))) => (),
180                Poll::Pending => {
181                    if io_stream.is_shutdown() {
182                        // the io_stream is in a shutdown state, we are only waiting for final results...
183                        return Poll::Pending;
184                    }
185
186                    // NotReady and not shutdown, see if there are more messages to send
187                    ()
188                } // underlying stream is complete.
189                Poll::Ready(None) => {
190                    debug!("io_stream is done, shutting down");
191                    // TODO: return shutdown error to anything in the stream?
192
193                    return Poll::Ready(Ok(()));
194                }
195                Poll::Ready(Some(Err(err))) => {
196                    debug!(
197                        error = err.as_dyn(),
198                        "io_stream hit an error, shutting down"
199                    );
200
201                    return Poll::Ready(Err(err));
202                }
203            }
204
205            // then see if there is more to send
206            match outbound_messages.as_mut().poll_next(cx) {
207                // already handled above, here to make sure the poll() pops the next message
208                Poll::Ready(Some(dns_request)) => {
209                    // if there is no peer, this connection should die...
210                    let (dns_request, serial_response): (DnsRequest, _) = dns_request.into_parts();
211
212                    // Try to forward the `DnsResponseStream` to the requesting task. If we fail,
213                    // it must be because the requesting task has gone away / is no longer
214                    // interested. In that case, we can just log a warning, but there's no need
215                    // to take any more serious measures (such as shutting down this task).
216                    match serial_response.send_response(io_stream.send_message(dns_request)) {
217                        Ok(()) => (),
218                        Err(_) => {
219                            debug!("failed to associate send_message response to the sender");
220                        }
221                    }
222                }
223                // On not ready, this is our time to return...
224                Poll::Pending => return Poll::Pending,
225                Poll::Ready(None) => {
226                    // if there is nothing that can use this connection to send messages, then this is done...
227                    io_stream.shutdown();
228
229                    // now we'll await the stream to shutdown... see io_stream poll above
230                }
231            }
232
233            // else we loop to poll on the outbound_messages
234        }
235    }
236}
237
238/// A wrapper for a future DnsExchange connection.
239///
240/// DnsExchangeConnect is cloneable, making it possible to share this if the connection
241///  will be shared across threads.
242///
243/// The future will return a tuple of the DnsExchange (for sending messages) and a background
244///  for running the background tasks. The background is optional as only one thread should run
245///  the background. If returned, it must be spawned before any dns requests will function.
246pub struct DnsExchangeConnect<F, S, TE>(DnsExchangeConnectInner<F, S, TE>)
247where
248    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
249    S: DnsRequestSender + 'static,
250    TE: Time + Unpin;
251
252impl<F, S, TE> DnsExchangeConnect<F, S, TE>
253where
254    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
255    S: DnsRequestSender + 'static,
256    TE: Time + Unpin,
257{
258    fn connect(
259        connect_future: F,
260        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
261        sender: BufDnsRequestStreamHandle,
262    ) -> Self {
263        Self(DnsExchangeConnectInner::Connecting {
264            connect_future,
265            outbound_messages: Some(outbound_messages),
266            sender: Some(sender),
267        })
268    }
269}
270
271#[allow(clippy::type_complexity)]
272impl<F, S, TE> Future for DnsExchangeConnect<F, S, TE>
273where
274    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
275    S: DnsRequestSender + 'static + Send + Unpin,
276    TE: Time + Unpin,
277{
278    type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
279
280    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
281        self.0.poll_unpin(cx)
282    }
283}
284
285enum DnsExchangeConnectInner<F, S, TE>
286where
287    F: Future<Output = Result<S, ProtoError>> + 'static + Send,
288    S: DnsRequestSender + 'static + Send,
289    TE: Time + Unpin,
290{
291    Connecting {
292        connect_future: F,
293        outbound_messages: Option<mpsc::Receiver<OneshotDnsRequest>>,
294        sender: Option<BufDnsRequestStreamHandle>,
295    },
296    Connected {
297        exchange: DnsExchange,
298        background: Option<DnsExchangeBackground<S, TE>>,
299    },
300    FailAll {
301        error: ProtoError,
302        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
303    },
304    Error(ProtoError),
305}
306
307#[allow(clippy::type_complexity)]
308impl<F, S, TE> Future for DnsExchangeConnectInner<F, S, TE>
309where
310    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
311    S: DnsRequestSender + 'static + Send + Unpin,
312    TE: Time + Unpin,
313{
314    type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
315
316    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
317        loop {
318            let next;
319            match &mut *self {
320                Self::Connecting {
321                    connect_future,
322                    outbound_messages,
323                    sender,
324                } => {
325                    let connect_future = Pin::new(connect_future);
326                    match connect_future.poll(cx) {
327                        Poll::Ready(Ok(stream)) => {
328                            //debug!("connection established: {}", stream);
329
330                            let (exchange, background) = DnsExchange::from_stream_with_receiver(
331                                stream,
332                                outbound_messages
333                                    .take()
334                                    .expect("cannot poll after complete"),
335                                sender.take().expect("cannot poll after complete"),
336                            );
337
338                            next = Self::Connected {
339                                exchange,
340                                background: Some(background),
341                            };
342                        }
343                        Poll::Pending => return Poll::Pending,
344                        Poll::Ready(Err(error)) => {
345                            debug!(error = error.as_dyn(), "stream errored while connecting");
346                            next = Self::FailAll {
347                                error,
348                                outbound_messages: outbound_messages
349                                    .take()
350                                    .expect("cannot poll after complete"),
351                            }
352                        }
353                    };
354                }
355                Self::Connected {
356                    exchange,
357                    background,
358                } => {
359                    let exchange = exchange.clone();
360                    let background = background.take().expect("cannot poll after complete");
361
362                    return Poll::Ready(Ok((exchange, background)));
363                }
364                Self::FailAll {
365                    error,
366                    outbound_messages,
367                } => {
368                    while let Some(outbound_message) = match outbound_messages.poll_next_unpin(cx) {
369                        Poll::Ready(opt) => opt,
370                        Poll::Pending => return Poll::Pending,
371                    } {
372                        // ignoring errors... best effort send...
373                        outbound_message
374                            .into_parts()
375                            .1
376                            .send_response(error.clone().into())
377                            .ok();
378                    }
379
380                    return Poll::Ready(Err(error.clone()));
381                }
382                Self::Error(error) => return Poll::Ready(Err(error.clone())),
383            }
384
385            *self = next;
386        }
387    }
388}