1use core::future::Future;
11use core::marker::PhantomData;
12use core::pin::Pin;
13use core::task::{Context, Poll};
14
15use futures_channel::mpsc;
16use futures_util::future::FutureExt;
17use futures_util::stream::{Peekable, Stream, StreamExt};
18use tracing::debug;
19
20use crate::error::*;
21#[cfg(feature = "std")]
22use crate::runtime::Time;
23use crate::xfer::DnsResponseReceiver;
24#[cfg(any(feature = "std", feature = "no-std-rand"))]
25use crate::xfer::dns_handle::DnsHandle;
26use crate::xfer::{
27 BufDnsRequestStreamHandle, CHANNEL_BUFFER_SIZE, DnsRequest, DnsRequestSender, DnsResponse,
28 OneshotDnsRequest,
29};
30
/// Handle for submitting DNS requests to a [`DnsExchangeBackground`].
///
/// Requests are pushed into a request channel through the wrapped
/// `BufDnsRequestStreamHandle`. The paired background future receives them
/// from that channel and passes them to its `DnsRequestSender`.
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchange {
    // Sending half of the request channel; cloned whenever the exchange is cloned.
    sender: BufDnsRequestStreamHandle,
}
38
39impl DnsExchange {
40 pub fn from_stream<S, TE>(stream: S) -> (Self, DnsExchangeBackground<S, TE>)
48 where
49 S: DnsRequestSender + 'static + Send + Unpin,
50 {
51 let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
52 let message_sender = BufDnsRequestStreamHandle { sender };
53
54 Self::from_stream_with_receiver(stream, outbound_messages, message_sender)
55 }
56
57 pub fn from_stream_with_receiver<S, TE>(
59 stream: S,
60 receiver: mpsc::Receiver<OneshotDnsRequest>,
61 sender: BufDnsRequestStreamHandle,
62 ) -> (Self, DnsExchangeBackground<S, TE>)
63 where
64 S: DnsRequestSender + 'static + Send + Unpin,
65 {
66 let background = DnsExchangeBackground {
67 io_stream: stream,
68 outbound_messages: receiver.peekable(),
69 marker: PhantomData,
70 };
71
72 (Self { sender }, background)
73 }
74
75 #[cfg(feature = "std")]
79 pub fn connect<F, S, TE>(connect_future: F) -> DnsExchangeConnect<F, S, TE>
80 where
81 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
82 S: DnsRequestSender + 'static + Send + Unpin,
83 TE: Time + Unpin,
84 {
85 let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
86 let message_sender = BufDnsRequestStreamHandle { sender };
87
88 DnsExchangeConnect::connect(connect_future, outbound_messages, message_sender)
89 }
90
91 pub fn error<F, S, TE>(error: ProtoError) -> DnsExchangeConnect<F, S, TE>
93 where
94 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
95 S: DnsRequestSender + 'static + Send + Unpin,
96 TE: Time + Unpin,
97 {
98 DnsExchangeConnect(DnsExchangeConnectInner::Error(error))
99 }
100}
101
102impl Clone for DnsExchange {
103 fn clone(&self) -> Self {
104 Self {
105 sender: self.sender.clone(),
106 }
107 }
108}
109
#[cfg(any(feature = "std", feature = "no-std-rand"))]
impl DnsHandle for DnsExchange {
    type Response = DnsExchangeSend;

    /// Passes `request` to the inner request sender.
    ///
    /// Returns a `DnsExchangeSend` that wraps the resulting response receiver
    /// together with a clone of the sender.
    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
        let result = self.sender.send(request);
        let _sender = self.sender.clone();

        DnsExchangeSend { result, _sender }
    }
}
121
/// Stream of responses for one request sent through a [`DnsExchange`].
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchangeSend {
    // Receiver that `poll_next` delegates to for the request's response(s).
    result: DnsResponseReceiver,
    // Clone of the request sender, stored but never read. NOTE(review):
    // presumably held so the request channel stays open while this stream is
    // alive — confirm.
    _sender: BufDnsRequestStreamHandle,
}
128
129impl Stream for DnsExchangeSend {
130 type Item = Result<DnsResponse, ProtoError>;
131
132 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133 self.result.poll_next_unpin(cx)
135 }
136}
137
/// Background half of a [`DnsExchange`].
///
/// It receives requests from the request channel and sends them on the
/// underlying `DnsRequestSender`. Requests only reach the connection while
/// this future is being polled. It resolves to `Ok(())` when the underlying
/// stream ends, and to `Err` if that stream reports an error.
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchangeBackground<S, TE>
where
    S: DnsRequestSender + 'static + Send + Unpin,
{
    // Connection that requests are sent on and that `poll` drives.
    io_stream: S,
    // Receiving half of the request channel. NOTE(review): the `Peekable`
    // wrapper is never peeked in the code visible here — confirm it is needed.
    outbound_messages: Peekable<mpsc::Receiver<OneshotDnsRequest>>,
    // Carries the `TE` type parameter without storing a value.
    marker: PhantomData<TE>,
}
150
151impl<S, TE> DnsExchangeBackground<S, TE>
152where
153 S: DnsRequestSender + 'static + Send + Unpin,
154{
155 fn pollable_split(&mut self) -> (&mut S, &mut Peekable<mpsc::Receiver<OneshotDnsRequest>>) {
156 (&mut self.io_stream, &mut self.outbound_messages)
157 }
158}
159
160impl<S, TE> Future for DnsExchangeBackground<S, TE>
161where
162 S: DnsRequestSender + 'static + Send + Unpin,
163 TE: Time + Unpin,
164{
165 type Output = Result<(), ProtoError>;
166
167 #[allow(clippy::unused_unit)]
168 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
169 let (io_stream, outbound_messages) = self.pollable_split();
170 let mut io_stream = Pin::new(io_stream);
171 let mut outbound_messages = Pin::new(outbound_messages);
172
173 loop {
176 match io_stream.as_mut().poll_next(cx) {
178 Poll::Ready(Some(Ok(()))) => (),
180 Poll::Pending => {
181 if io_stream.is_shutdown() {
182 return Poll::Pending;
184 }
185
186 ()
188 } Poll::Ready(None) => {
190 debug!("io_stream is done, shutting down");
191 return Poll::Ready(Ok(()));
194 }
195 Poll::Ready(Some(Err(err))) => {
196 debug!(
197 error = err.as_dyn(),
198 "io_stream hit an error, shutting down"
199 );
200
201 return Poll::Ready(Err(err));
202 }
203 }
204
205 match outbound_messages.as_mut().poll_next(cx) {
207 Poll::Ready(Some(dns_request)) => {
209 let (dns_request, serial_response): (DnsRequest, _) = dns_request.into_parts();
211
212 match serial_response.send_response(io_stream.send_message(dns_request)) {
217 Ok(()) => (),
218 Err(_) => {
219 debug!("failed to associate send_message response to the sender");
220 }
221 }
222 }
223 Poll::Pending => return Poll::Pending,
225 Poll::Ready(None) => {
226 io_stream.shutdown();
228
229 }
231 }
232
233 }
235 }
236}
237
238pub struct DnsExchangeConnect<F, S, TE>(DnsExchangeConnectInner<F, S, TE>)
247where
248 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
249 S: DnsRequestSender + 'static,
250 TE: Time + Unpin;
251
252impl<F, S, TE> DnsExchangeConnect<F, S, TE>
253where
254 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
255 S: DnsRequestSender + 'static,
256 TE: Time + Unpin,
257{
258 fn connect(
259 connect_future: F,
260 outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
261 sender: BufDnsRequestStreamHandle,
262 ) -> Self {
263 Self(DnsExchangeConnectInner::Connecting {
264 connect_future,
265 outbound_messages: Some(outbound_messages),
266 sender: Some(sender),
267 })
268 }
269}
270
271#[allow(clippy::type_complexity)]
272impl<F, S, TE> Future for DnsExchangeConnect<F, S, TE>
273where
274 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
275 S: DnsRequestSender + 'static + Send + Unpin,
276 TE: Time + Unpin,
277{
278 type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
279
280 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
281 self.0.poll_unpin(cx)
282 }
283}
284
/// State machine behind `DnsExchangeConnect`.
enum DnsExchangeConnectInner<F, S, TE>
where
    F: Future<Output = Result<S, ProtoError>> + 'static + Send,
    S: DnsRequestSender + 'static + Send,
    TE: Time + Unpin,
{
    /// Waiting on `connect_future`.
    ///
    /// The channel halves are `Option`s so they can be moved out once the
    /// future resolves.
    Connecting {
        connect_future: F,
        outbound_messages: Option<mpsc::Receiver<OneshotDnsRequest>>,
        sender: Option<BufDnsRequestStreamHandle>,
    },
    /// Connected.
    ///
    /// The next poll returns a clone of `exchange` and takes `background`
    /// out of its `Option`.
    Connected {
        exchange: DnsExchange,
        background: Option<DnsExchangeBackground<S, TE>>,
    },
    /// Connecting failed.
    ///
    /// Requests received on `outbound_messages` are answered with `error`
    /// before the future resolves to that error.
    FailAll {
        error: ProtoError,
        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
    },
    /// Resolves immediately to the stored error.
    Error(ProtoError),
}
306
307#[allow(clippy::type_complexity)]
308impl<F, S, TE> Future for DnsExchangeConnectInner<F, S, TE>
309where
310 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
311 S: DnsRequestSender + 'static + Send + Unpin,
312 TE: Time + Unpin,
313{
314 type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
315
316 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
317 loop {
318 let next;
319 match &mut *self {
320 Self::Connecting {
321 connect_future,
322 outbound_messages,
323 sender,
324 } => {
325 let connect_future = Pin::new(connect_future);
326 match connect_future.poll(cx) {
327 Poll::Ready(Ok(stream)) => {
328 let (exchange, background) = DnsExchange::from_stream_with_receiver(
331 stream,
332 outbound_messages
333 .take()
334 .expect("cannot poll after complete"),
335 sender.take().expect("cannot poll after complete"),
336 );
337
338 next = Self::Connected {
339 exchange,
340 background: Some(background),
341 };
342 }
343 Poll::Pending => return Poll::Pending,
344 Poll::Ready(Err(error)) => {
345 debug!(error = error.as_dyn(), "stream errored while connecting");
346 next = Self::FailAll {
347 error,
348 outbound_messages: outbound_messages
349 .take()
350 .expect("cannot poll after complete"),
351 }
352 }
353 };
354 }
355 Self::Connected {
356 exchange,
357 background,
358 } => {
359 let exchange = exchange.clone();
360 let background = background.take().expect("cannot poll after complete");
361
362 return Poll::Ready(Ok((exchange, background)));
363 }
364 Self::FailAll {
365 error,
366 outbound_messages,
367 } => {
368 while let Some(outbound_message) = match outbound_messages.poll_next_unpin(cx) {
369 Poll::Ready(opt) => opt,
370 Poll::Pending => return Poll::Pending,
371 } {
372 outbound_message
374 .into_parts()
375 .1
376 .send_response(error.clone().into())
377 .ok();
378 }
379
380 return Poll::Ready(Err(error.clone()));
381 }
382 Self::Error(error) => return Poll::Ready(Err(error.clone())),
383 }
384
385 *self = next;
386 }
387 }
388}