// hickory_proto/xfer/retry_dns_handle.rs
#[cfg(any(feature = "std", feature = "no-std-rand"))]
11use alloc::boxed::Box;
12use core::pin::Pin;
13use core::task::{Context, Poll};
14
15use futures_util::stream::{Stream, StreamExt};
16
17use crate::DnsHandle;
18use crate::error::{ProtoError, ProtoErrorKind};
19use crate::xfer::{DnsRequest, DnsResponse};
20
21#[derive(Clone)]
34#[must_use = "queries can only be sent through a ClientHandle"]
35#[allow(dead_code)]
36pub struct RetryDnsHandle<H>
37where
38 H: DnsHandle + Unpin + Send,
39{
40 handle: H,
41 attempts: usize,
42}
43
44impl<H> RetryDnsHandle<H>
45where
46 H: DnsHandle + Unpin + Send,
47{
48 pub fn new(handle: H, attempts: usize) -> Self {
55 Self { handle, attempts }
56 }
57}
58
/// Sends requests through the inner handle, re-sending them according to
/// [`RetryableError`] and the configured retry budget.
#[cfg(any(feature = "std", feature = "no-std-rand"))]
impl<H> DnsHandle for RetryDnsHandle<H>
where
    H: DnsHandle + Send + Unpin + 'static,
{
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin>>;

    /// Issues the first attempt right away and returns a stream that re-sends
    /// the request whenever an attempt yields a retryable error.
    fn send<R: Into<DnsRequest>>(&self, request: R) -> Self::Response {
        let request: DnsRequest = request.into();

        // The inner `send` is invoked here, before the returned stream is
        // polled; the request is retained so later attempts can send a clone.
        let first_attempt = self.handle.send(request.clone());

        let retrying = RetrySendStream {
            stream: first_attempt,
            handle: self.handle.clone(),
            remaining_attempts: self.attempts,
            request,
        };
        Box::pin(retrying)
    }
}
81
82struct RetrySendStream<H>
84where
85 H: DnsHandle,
86{
87 request: DnsRequest,
88 handle: H,
89 stream: <H as DnsHandle>::Response,
90 remaining_attempts: usize,
91}
92
93impl<H: DnsHandle + Unpin> Stream for RetrySendStream<H> {
94 type Item = Result<DnsResponse, ProtoError>;
95
96 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
97 loop {
100 match self.stream.poll_next_unpin(cx) {
101 Poll::Ready(Some(Err(e))) => {
102 if self.remaining_attempts == 0 || !e.should_retry() {
103 return Poll::Ready(Some(Err(e)));
104 }
105
106 if e.attempted() {
107 self.remaining_attempts -= 1;
108 }
109
110 let request = self.request.clone();
113 self.stream = self.handle.send(request);
114 }
115 poll => return poll,
116 }
117 }
118 }
119}
120
/// Classifies errors for `RetryDnsHandle`: whether a failed request may be
/// re-sent, and whether the failure consumes one of the allowed retries.
pub trait RetryableError {
    /// Returns `true` if the request that produced this error may be sent again.
    fn should_retry(&self) -> bool;

    /// Returns `true` if this failure counts against the retry budget; `false`
    /// lets the request be re-sent without using up a remaining attempt.
    fn attempted(&self) -> bool;
}
128
129impl RetryableError for ProtoError {
130 fn should_retry(&self) -> bool {
131 !matches!(
132 self.kind(),
133 ProtoErrorKind::NoConnections | ProtoErrorKind::NoRecordsFound { .. }
134 )
135 }
136
137 fn attempted(&self) -> bool {
138 !matches!(self.kind(), ProtoErrorKind::Busy)
139 }
140}
141
#[cfg(all(test, feature = "std"))]
mod test {
    use alloc::sync::Arc;
    use core::sync::atomic::{AtomicU16, Ordering};

    use super::*;
    use crate::error::*;
    use crate::op::*;
    use crate::xfer::FirstAnswer;

    use futures_executor::block_on;
    use futures_util::future::{err, ok};
    use futures_util::stream::*;
    use test_support::subscribe;

    /// Fake handle: the first `retries` sends fail. After that, sends succeed
    /// when `last_succeed` is set and keep failing otherwise.
    #[derive(Clone)]
    struct TestClient {
        last_succeed: bool,
        retries: u16,
        /// Number of failed sends so far, shared with the test body.
        attempts: Arc<AtomicU16>,
    }

    impl DnsHandle for TestClient {
        type Response = Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin>;

        fn send<R: Into<DnsRequest>>(&self, _: R) -> Self::Response {
            let i = self.attempts.load(Ordering::SeqCst);

            // Succeed once the configured number of failures has been produced.
            if i >= self.retries && self.last_succeed {
                let mut message = Message::new();
                message.set_id(i);
                return Box::new(once(ok(DnsResponse::from_message(message).unwrap())));
            }

            self.attempts.fetch_add(1, Ordering::SeqCst);
            Box::new(once(err(ProtoError::from("last retry set to fail"))))
        }
    }

    #[test]
    fn test_retry() {
        subscribe();
        let attempts = Arc::new(AtomicU16::new(0));
        let handle = RetryDnsHandle::new(
            TestClient {
                last_succeed: true,
                retries: 1,
                attempts: Arc::clone(&attempts),
            },
            2,
        );
        let test1 = Message::new();
        let result = block_on(handle.send(test1).first_answer()).expect("should have succeeded");
        assert_eq!(result.id(), 1);
        // Exactly one failure was observed before the retry succeeded.
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_error() {
        subscribe();
        let attempts = Arc::new(AtomicU16::new(0));
        let client = RetryDnsHandle::new(
            TestClient {
                last_succeed: false,
                retries: 1,
                attempts: Arc::clone(&attempts),
            },
            2,
        );
        let test1 = Message::new();
        assert!(block_on(client.send(test1).first_answer()).is_err());
        // The initial send plus both allowed retries failed before giving up.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }
}