hickory_proto/xfer/retry_dns_handle.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
//! `RetryDnsHandle` allows DNS queries to be reattempted on failure
9
10#[cfg(any(feature = "std", feature = "no-std-rand"))]
11use alloc::boxed::Box;
12use core::pin::Pin;
13use core::task::{Context, Poll};
14
15use futures_util::stream::{Stream, StreamExt};
16
17use crate::DnsHandle;
18use crate::error::{ProtoError, ProtoErrorKind};
19use crate::xfer::{DnsRequest, DnsResponse};
20
21/// Can be used to reattempt queries if they fail
22///
23/// Note: this does not reattempt queries that fail with a negative response.
24/// For example, if a query gets a `NODATA` response from a name server, the
25/// query will not be retried. It only reattempts queries that effectively
26/// failed to get a response, such as queries that resulted in IO or timeout
27/// errors.
28///
29/// Whether an error is retryable by the [`RetryDnsHandle`] is determined by the
30/// [`RetryableError`] trait.
31///
32/// *note* Current value of this is not clear, it may be removed
33#[derive(Clone)]
34#[must_use = "queries can only be sent through a ClientHandle"]
35#[allow(dead_code)]
36pub struct RetryDnsHandle<H>
37where
38    H: DnsHandle + Unpin + Send,
39{
40    handle: H,
41    attempts: usize,
42}
43
44impl<H> RetryDnsHandle<H>
45where
46    H: DnsHandle + Unpin + Send,
47{
48    /// Creates a new Client handler for reattempting requests on failures.
49    ///
50    /// # Arguments
51    ///
52    /// * `handle` - handle to the dns connection
53    /// * `attempts` - number of attempts before failing
54    pub fn new(handle: H, attempts: usize) -> Self {
55        Self { handle, attempts }
56    }
57}
58
#[cfg(any(feature = "std", feature = "no-std-rand"))]
impl<H> DnsHandle for RetryDnsHandle<H>
where
    H: DnsHandle + Send + Unpin + 'static,
{
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin>>;

    /// Sends `request` through the wrapped handle and returns a stream that
    /// re-sends it whenever the in-flight attempt fails with a retryable error.
    fn send<R: Into<DnsRequest>>(&self, request: R) -> Self::Response {
        let request: DnsRequest = request.into();

        // The first attempt consumes a copy so the original stays available for
        // re-sends. Ideally the clone would only happen when a retry is needed.
        let first_attempt = self.handle.send(request.clone());

        let retrying = RetrySendStream {
            stream: first_attempt,
            remaining_attempts: self.attempts,
            handle: self.handle.clone(),
            request,
        };
        Box::pin(retrying)
    }
}
81
82/// A stream for retrying (on failure, for the remaining number of times specified)
83struct RetrySendStream<H>
84where
85    H: DnsHandle,
86{
87    request: DnsRequest,
88    handle: H,
89    stream: <H as DnsHandle>::Response,
90    remaining_attempts: usize,
91}
92
93impl<H: DnsHandle + Unpin> Stream for RetrySendStream<H> {
94    type Item = Result<DnsResponse, ProtoError>;
95
96    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
97        // loop over the stream, on errors, spawn a new stream
98        //  on ready and not ready return.
99        loop {
100            match self.stream.poll_next_unpin(cx) {
101                Poll::Ready(Some(Err(e))) => {
102                    if self.remaining_attempts == 0 || !e.should_retry() {
103                        return Poll::Ready(Some(Err(e)));
104                    }
105
106                    if e.attempted() {
107                        self.remaining_attempts -= 1;
108                    }
109
110                    // TODO: if the "sent" Message is part of the error result,
111                    //  then we can just reuse it... and no clone necessary
112                    let request = self.request.clone();
113                    self.stream = self.handle.send(request);
114                }
115                poll => return poll,
116            }
117        }
118    }
119}
120
/// Classifies errors for [`RetryDnsHandle`]: whether a failed query should be
/// sent again, and whether doing so consumes one of the handle's attempts.
pub trait RetryableError {
    /// Returns `true` when the query that produced this error should be sent again.
    fn should_retry(&self) -> bool;
    /// Returns `true` when this failure counts against the retry budget; returning
    /// `false` lets the query be re-sent without consuming an attempt.
    fn attempted(&self) -> bool;
}
128
129impl RetryableError for ProtoError {
130    fn should_retry(&self) -> bool {
131        !matches!(
132            self.kind(),
133            ProtoErrorKind::NoConnections | ProtoErrorKind::NoRecordsFound { .. }
134        )
135    }
136
137    fn attempted(&self) -> bool {
138        !matches!(self.kind(), ProtoErrorKind::Busy)
139    }
140}
141
#[cfg(all(test, feature = "std"))]
mod test {
    use alloc::sync::Arc;
    use core::sync::atomic::{AtomicU16, Ordering};

    use super::*;
    use crate::error::*;
    use crate::op::*;
    use crate::xfer::FirstAnswer;

    use futures_executor::block_on;
    use futures_util::future::{err, ok};
    use futures_util::stream::*;
    use test_support::subscribe;

    /// Fake handle whose sends fail until `retries` failures have been recorded.
    /// After that, if `last_succeed` is set, it answers with a message whose id is
    /// the number of failures so far; otherwise it keeps failing.
    #[derive(Clone)]
    struct TestClient {
        last_succeed: bool,
        retries: u16,
        /// Number of failed sends, shared by every clone of the client.
        attempts: Arc<AtomicU16>,
    }

    impl DnsHandle for TestClient {
        type Response = Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin>;

        fn send<R: Into<DnsRequest>>(&self, _: R) -> Self::Response {
            let i = self.attempts.load(Ordering::SeqCst);

            if i >= self.retries && self.last_succeed {
                let mut message = Message::new();
                message.set_id(i);
                return Box::new(once(ok(DnsResponse::from_message(message).unwrap())));
            }

            self.attempts.fetch_add(1, Ordering::SeqCst);
            Box::new(once(err(ProtoError::from("last retry set to fail"))))
        }
    }

    #[test]
    fn test_retry() {
        subscribe();
        let attempts = Arc::new(AtomicU16::new(0));
        let handle = RetryDnsHandle::new(
            TestClient {
                last_succeed: true,
                retries: 1,
                attempts: Arc::clone(&attempts),
            },
            2,
        );
        let test1 = Message::new();
        let result = block_on(handle.send(test1).first_answer()).expect("should have succeeded");
        assert_eq!(result.id(), 1); // this is checking the number of iterations the TestClient ran
        // exactly one send failed before the retry succeeded
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_error() {
        subscribe();
        let attempts = Arc::new(AtomicU16::new(0));
        let client = RetryDnsHandle::new(
            TestClient {
                last_succeed: false,
                retries: 1,
                attempts: Arc::clone(&attempts),
            },
            2,
        );
        let test1 = Message::new();
        assert!(block_on(client.send(test1).first_answer()).is_err());
        // the initial send plus the two permitted retries
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }
}
211}