hickory_proto/xfer/
dns_exchange.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This module contains all the types for demuxing DNS oriented streams.
9
10use std::marker::PhantomData;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use futures_channel::mpsc;
15use futures_util::future::{Future, FutureExt};
16use futures_util::stream::{Peekable, Stream, StreamExt};
17use tracing::{debug, warn};
18
19use crate::error::*;
20use crate::xfer::dns_handle::DnsHandle;
21use crate::xfer::DnsResponseReceiver;
22use crate::xfer::{
23    BufDnsRequestStreamHandle, DnsRequest, DnsRequestSender, DnsResponse, OneshotDnsRequest,
24    CHANNEL_BUFFER_SIZE,
25};
26use crate::Time;
27
28/// This is a generic Exchange implemented over multiplexed DNS connection providers.
29///
30/// The underlying `DnsRequestSender` is expected to multiplex any I/O connections. DnsExchange assumes that the underlying stream is responsible for this.
31#[must_use = "futures do nothing unless polled"]
32pub struct DnsExchange {
33    sender: BufDnsRequestStreamHandle,
34}
35
36impl DnsExchange {
37    /// Initializes a TcpStream with an existing tcp::TcpStream.
38    ///
39    /// This is intended for use with a TcpListener and Incoming.
40    ///
41    /// # Arguments
42    ///
43    /// * `stream` - the established IO stream for communication
44    pub fn from_stream<S, TE>(stream: S) -> (Self, DnsExchangeBackground<S, TE>)
45    where
46        S: DnsRequestSender + 'static + Send + Unpin,
47    {
48        let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
49        let message_sender = BufDnsRequestStreamHandle { sender };
50
51        Self::from_stream_with_receiver(stream, outbound_messages, message_sender)
52    }
53
54    /// Wraps a stream where a sender and receiver have already been established
55    pub fn from_stream_with_receiver<S, TE>(
56        stream: S,
57        receiver: mpsc::Receiver<OneshotDnsRequest>,
58        sender: BufDnsRequestStreamHandle,
59    ) -> (Self, DnsExchangeBackground<S, TE>)
60    where
61        S: DnsRequestSender + 'static + Send + Unpin,
62    {
63        let background = DnsExchangeBackground {
64            io_stream: stream,
65            outbound_messages: receiver.peekable(),
66            marker: PhantomData,
67        };
68
69        (Self { sender }, background)
70    }
71
72    /// Returns a future, which itself wraps a future which is awaiting connection.
73    ///
74    /// The connect_future should be lazy.
75    pub fn connect<F, S, TE>(connect_future: F) -> DnsExchangeConnect<F, S, TE>
76    where
77        F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
78        S: DnsRequestSender + 'static + Send + Unpin,
79        TE: Time + Unpin,
80    {
81        let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
82        let message_sender = BufDnsRequestStreamHandle { sender };
83
84        DnsExchangeConnect::connect(connect_future, outbound_messages, message_sender)
85    }
86
87    /// Returns a future that returns an error immediately.
88    pub fn error<F, S, TE>(error: ProtoError) -> DnsExchangeConnect<F, S, TE>
89    where
90        F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
91        S: DnsRequestSender + 'static + Send + Unpin,
92        TE: Time + Unpin,
93    {
94        DnsExchangeConnect(DnsExchangeConnectInner::Error(error))
95    }
96}
97
98impl Clone for DnsExchange {
99    fn clone(&self) -> Self {
100        Self {
101            sender: self.sender.clone(),
102        }
103    }
104}
105
106impl DnsHandle for DnsExchange {
107    type Response = DnsExchangeSend;
108    type Error = ProtoError;
109
110    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
111        DnsExchangeSend {
112            result: self.sender.send(request),
113            _sender: self.sender.clone(), // TODO: this shouldn't be necessary, currently the presence of Senders is what allows the background to track current users, it generally is dropped right after send, this makes sure that there is at least one active after send
114        }
115    }
116}
117
118/// A Stream that will resolve to Responses after sending the request
119#[must_use = "futures do nothing unless polled"]
120pub struct DnsExchangeSend {
121    result: DnsResponseReceiver,
122    _sender: BufDnsRequestStreamHandle,
123}
124
125impl Stream for DnsExchangeSend {
126    type Item = Result<DnsResponse, ProtoError>;
127
128    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129        // as long as there is no result, poll the exchange
130        self.result.poll_next_unpin(cx)
131    }
132}
133
134/// This background future is responsible for driving all network operations for the DNS protocol.
135///
136/// It must be spawned before any DNS messages are sent.
137#[must_use = "futures do nothing unless polled"]
138pub struct DnsExchangeBackground<S, TE>
139where
140    S: DnsRequestSender + 'static + Send + Unpin,
141{
142    io_stream: S,
143    outbound_messages: Peekable<mpsc::Receiver<OneshotDnsRequest>>,
144    marker: PhantomData<TE>,
145}
146
147impl<S, TE> DnsExchangeBackground<S, TE>
148where
149    S: DnsRequestSender + 'static + Send + Unpin,
150{
151    fn pollable_split(&mut self) -> (&mut S, &mut Peekable<mpsc::Receiver<OneshotDnsRequest>>) {
152        (&mut self.io_stream, &mut self.outbound_messages)
153    }
154}
155
156impl<S, TE> Future for DnsExchangeBackground<S, TE>
157where
158    S: DnsRequestSender + 'static + Send + Unpin,
159    TE: Time + Unpin,
160{
161    type Output = Result<(), ProtoError>;
162
163    #[allow(clippy::unused_unit)]
164    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
165        let (io_stream, outbound_messages) = self.pollable_split();
166        let mut io_stream = Pin::new(io_stream);
167        let mut outbound_messages = Pin::new(outbound_messages);
168
169        // this will not accept incoming data while there is data to send
170        //  makes this self throttling.
171        loop {
172            // poll the underlying stream, to drive it...
173            match io_stream.as_mut().poll_next(cx) {
174                // The stream is ready
175                Poll::Ready(Some(Ok(()))) => (),
176                Poll::Pending => {
177                    if io_stream.is_shutdown() {
178                        // the io_stream is in a shutdown state, we are only waiting for final results...
179                        return Poll::Pending;
180                    }
181
182                    // NotReady and not shutdown, see if there are more messages to send
183                    ()
184                } // underlying stream is complete.
185                Poll::Ready(None) => {
186                    debug!("io_stream is done, shutting down");
187                    // TODO: return shutdown error to anything in the stream?
188
189                    return Poll::Ready(Ok(()));
190                }
191                Poll::Ready(Some(Err(err))) => {
192                    debug!(
193                        error = err.as_dyn(),
194                        "io_stream hit an error, shutting down"
195                    );
196
197                    return Poll::Ready(Err(err));
198                }
199            }
200
201            // then see if there is more to send
202            match outbound_messages.as_mut().poll_next(cx) {
203                // already handled above, here to make sure the poll() pops the next message
204                Poll::Ready(Some(dns_request)) => {
205                    // if there is no peer, this connection should die...
206                    let (dns_request, serial_response): (DnsRequest, _) = dns_request.into_parts();
207
208                    // Try to forward the `DnsResponseStream` to the requesting task. If we fail,
209                    // it must be because the requesting task has gone away / is no longer
210                    // interested. In that case, we can just log a warning, but there's no need
211                    // to take any more serious measures (such as shutting down this task).
212                    match serial_response.send_response(io_stream.send_message(dns_request)) {
213                        Ok(()) => (),
214                        Err(_) => {
215                            warn!("failed to associate send_message response to the sender");
216                        }
217                    }
218                }
219                // On not ready, this is our time to return...
220                Poll::Pending => return Poll::Pending,
221                Poll::Ready(None) => {
222                    // if there is nothing that can use this connection to send messages, then this is done...
223                    io_stream.shutdown();
224
225                    // now we'll await the stream to shutdown... see io_stream poll above
226                }
227            }
228
229            // else we loop to poll on the outbound_messages
230        }
231    }
232}
233
234/// A wrapper for a future DnsExchange connection.
235///
236/// DnsExchangeConnect is cloneable, making it possible to share this if the connection
237///  will be shared across threads.
238///
239/// The future will return a tuple of the DnsExchange (for sending messages) and a background
240///  for running the background tasks. The background is optional as only one thread should run
241///  the background. If returned, it must be spawned before any dns requests will function.
242pub struct DnsExchangeConnect<F, S, TE>(DnsExchangeConnectInner<F, S, TE>)
243where
244    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
245    S: DnsRequestSender + 'static,
246    TE: Time + Unpin;
247
248impl<F, S, TE> DnsExchangeConnect<F, S, TE>
249where
250    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
251    S: DnsRequestSender + 'static,
252    TE: Time + Unpin,
253{
254    fn connect(
255        connect_future: F,
256        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
257        sender: BufDnsRequestStreamHandle,
258    ) -> Self {
259        Self(DnsExchangeConnectInner::Connecting {
260            connect_future,
261            outbound_messages: Some(outbound_messages),
262            sender: Some(sender),
263        })
264    }
265}
266
267#[allow(clippy::type_complexity)]
268impl<F, S, TE> Future for DnsExchangeConnect<F, S, TE>
269where
270    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
271    S: DnsRequestSender + 'static + Send + Unpin,
272    TE: Time + Unpin,
273{
274    type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
275
276    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
277        self.0.poll_unpin(cx)
278    }
279}
280
281enum DnsExchangeConnectInner<F, S, TE>
282where
283    F: Future<Output = Result<S, ProtoError>> + 'static + Send,
284    S: DnsRequestSender + 'static + Send,
285    TE: Time + Unpin,
286{
287    Connecting {
288        connect_future: F,
289        outbound_messages: Option<mpsc::Receiver<OneshotDnsRequest>>,
290        sender: Option<BufDnsRequestStreamHandle>,
291    },
292    Connected {
293        exchange: DnsExchange,
294        background: Option<DnsExchangeBackground<S, TE>>,
295    },
296    FailAll {
297        error: ProtoError,
298        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
299    },
300    Error(ProtoError),
301}
302
303#[allow(clippy::type_complexity)]
304impl<F, S, TE> Future for DnsExchangeConnectInner<F, S, TE>
305where
306    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
307    S: DnsRequestSender + 'static + Send + Unpin,
308    TE: Time + Unpin,
309{
310    type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
311
312    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
313        loop {
314            let next;
315            match *self {
316                Self::Connecting {
317                    ref mut connect_future,
318                    ref mut outbound_messages,
319                    ref mut sender,
320                } => {
321                    let connect_future = Pin::new(connect_future);
322                    match connect_future.poll(cx) {
323                        Poll::Ready(Ok(stream)) => {
324                            //debug!("connection established: {}", stream);
325
326                            let (exchange, background) = DnsExchange::from_stream_with_receiver(
327                                stream,
328                                outbound_messages
329                                    .take()
330                                    .expect("cannot poll after complete"),
331                                sender.take().expect("cannot poll after complete"),
332                            );
333
334                            next = Self::Connected {
335                                exchange,
336                                background: Some(background),
337                            };
338                        }
339                        Poll::Pending => return Poll::Pending,
340                        Poll::Ready(Err(error)) => {
341                            debug!(error = error.as_dyn(), "stream errored while connecting");
342                            next = Self::FailAll {
343                                error,
344                                outbound_messages: outbound_messages
345                                    .take()
346                                    .expect("cannot poll after complete"),
347                            }
348                        }
349                    };
350                }
351                Self::Connected {
352                    ref exchange,
353                    ref mut background,
354                } => {
355                    let exchange = exchange.clone();
356                    let background = background.take().expect("cannot poll after complete");
357
358                    return Poll::Ready(Ok((exchange, background)));
359                }
360                Self::FailAll {
361                    ref error,
362                    ref mut outbound_messages,
363                } => {
364                    while let Some(outbound_message) = match outbound_messages.poll_next_unpin(cx) {
365                        Poll::Ready(opt) => opt,
366                        Poll::Pending => return Poll::Pending,
367                    } {
368                        // ignoring errors... best effort send...
369                        outbound_message
370                            .into_parts()
371                            .1
372                            .send_response(error.clone().into())
373                            .ok();
374                    }
375
376                    return Poll::Ready(Err(error.clone()));
377                }
378                Self::Error(ref error) => return Poll::Ready(Err(error.clone())),
379            }
380
381            *self = next;
382        }
383    }
384}