1use std::marker::PhantomData;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use futures_channel::mpsc;
15use futures_util::future::{Future, FutureExt};
16use futures_util::stream::{Peekable, Stream, StreamExt};
17use tracing::{debug, warn};
18
19use crate::error::*;
20use crate::xfer::dns_handle::DnsHandle;
21use crate::xfer::DnsResponseReceiver;
22use crate::xfer::{
23 BufDnsRequestStreamHandle, DnsRequest, DnsRequestSender, DnsResponse, OneshotDnsRequest,
24 CHANNEL_BUFFER_SIZE,
25};
26use crate::Time;
27
28#[must_use = "futures do nothing unless polled"]
32pub struct DnsExchange {
33 sender: BufDnsRequestStreamHandle,
34}
35
36impl DnsExchange {
37 pub fn from_stream<S, TE>(stream: S) -> (Self, DnsExchangeBackground<S, TE>)
45 where
46 S: DnsRequestSender + 'static + Send + Unpin,
47 {
48 let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
49 let message_sender = BufDnsRequestStreamHandle { sender };
50
51 Self::from_stream_with_receiver(stream, outbound_messages, message_sender)
52 }
53
54 pub fn from_stream_with_receiver<S, TE>(
56 stream: S,
57 receiver: mpsc::Receiver<OneshotDnsRequest>,
58 sender: BufDnsRequestStreamHandle,
59 ) -> (Self, DnsExchangeBackground<S, TE>)
60 where
61 S: DnsRequestSender + 'static + Send + Unpin,
62 {
63 let background = DnsExchangeBackground {
64 io_stream: stream,
65 outbound_messages: receiver.peekable(),
66 marker: PhantomData,
67 };
68
69 (Self { sender }, background)
70 }
71
72 pub fn connect<F, S, TE>(connect_future: F) -> DnsExchangeConnect<F, S, TE>
76 where
77 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
78 S: DnsRequestSender + 'static + Send + Unpin,
79 TE: Time + Unpin,
80 {
81 let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
82 let message_sender = BufDnsRequestStreamHandle { sender };
83
84 DnsExchangeConnect::connect(connect_future, outbound_messages, message_sender)
85 }
86
87 pub fn error<F, S, TE>(error: ProtoError) -> DnsExchangeConnect<F, S, TE>
89 where
90 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
91 S: DnsRequestSender + 'static + Send + Unpin,
92 TE: Time + Unpin,
93 {
94 DnsExchangeConnect(DnsExchangeConnectInner::Error(error))
95 }
96}
97
98impl Clone for DnsExchange {
99 fn clone(&self) -> Self {
100 Self {
101 sender: self.sender.clone(),
102 }
103 }
104}
105
106impl DnsHandle for DnsExchange {
107 type Response = DnsExchangeSend;
108 type Error = ProtoError;
109
110 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
111 DnsExchangeSend {
112 result: self.sender.send(request),
113 _sender: self.sender.clone(), }
115 }
116}
117
118#[must_use = "futures do nothing unless polled"]
120pub struct DnsExchangeSend {
121 result: DnsResponseReceiver,
122 _sender: BufDnsRequestStreamHandle,
123}
124
125impl Stream for DnsExchangeSend {
126 type Item = Result<DnsResponse, ProtoError>;
127
128 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129 self.result.poll_next_unpin(cx)
131 }
132}
133
134#[must_use = "futures do nothing unless polled"]
138pub struct DnsExchangeBackground<S, TE>
139where
140 S: DnsRequestSender + 'static + Send + Unpin,
141{
142 io_stream: S,
143 outbound_messages: Peekable<mpsc::Receiver<OneshotDnsRequest>>,
144 marker: PhantomData<TE>,
145}
146
147impl<S, TE> DnsExchangeBackground<S, TE>
148where
149 S: DnsRequestSender + 'static + Send + Unpin,
150{
151 fn pollable_split(&mut self) -> (&mut S, &mut Peekable<mpsc::Receiver<OneshotDnsRequest>>) {
152 (&mut self.io_stream, &mut self.outbound_messages)
153 }
154}
155
156impl<S, TE> Future for DnsExchangeBackground<S, TE>
157where
158 S: DnsRequestSender + 'static + Send + Unpin,
159 TE: Time + Unpin,
160{
161 type Output = Result<(), ProtoError>;
162
163 #[allow(clippy::unused_unit)]
164 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
165 let (io_stream, outbound_messages) = self.pollable_split();
166 let mut io_stream = Pin::new(io_stream);
167 let mut outbound_messages = Pin::new(outbound_messages);
168
169 loop {
172 match io_stream.as_mut().poll_next(cx) {
174 Poll::Ready(Some(Ok(()))) => (),
176 Poll::Pending => {
177 if io_stream.is_shutdown() {
178 return Poll::Pending;
180 }
181
182 ()
184 } Poll::Ready(None) => {
186 debug!("io_stream is done, shutting down");
187 return Poll::Ready(Ok(()));
190 }
191 Poll::Ready(Some(Err(err))) => {
192 debug!(
193 error = err.as_dyn(),
194 "io_stream hit an error, shutting down"
195 );
196
197 return Poll::Ready(Err(err));
198 }
199 }
200
201 match outbound_messages.as_mut().poll_next(cx) {
203 Poll::Ready(Some(dns_request)) => {
205 let (dns_request, serial_response): (DnsRequest, _) = dns_request.into_parts();
207
208 match serial_response.send_response(io_stream.send_message(dns_request)) {
213 Ok(()) => (),
214 Err(_) => {
215 warn!("failed to associate send_message response to the sender");
216 }
217 }
218 }
219 Poll::Pending => return Poll::Pending,
221 Poll::Ready(None) => {
222 io_stream.shutdown();
224
225 }
227 }
228
229 }
231 }
232}
233
234pub struct DnsExchangeConnect<F, S, TE>(DnsExchangeConnectInner<F, S, TE>)
243where
244 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
245 S: DnsRequestSender + 'static,
246 TE: Time + Unpin;
247
248impl<F, S, TE> DnsExchangeConnect<F, S, TE>
249where
250 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
251 S: DnsRequestSender + 'static,
252 TE: Time + Unpin,
253{
254 fn connect(
255 connect_future: F,
256 outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
257 sender: BufDnsRequestStreamHandle,
258 ) -> Self {
259 Self(DnsExchangeConnectInner::Connecting {
260 connect_future,
261 outbound_messages: Some(outbound_messages),
262 sender: Some(sender),
263 })
264 }
265}
266
267#[allow(clippy::type_complexity)]
268impl<F, S, TE> Future for DnsExchangeConnect<F, S, TE>
269where
270 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
271 S: DnsRequestSender + 'static + Send + Unpin,
272 TE: Time + Unpin,
273{
274 type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
275
276 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
277 self.0.poll_unpin(cx)
278 }
279}
280
281enum DnsExchangeConnectInner<F, S, TE>
282where
283 F: Future<Output = Result<S, ProtoError>> + 'static + Send,
284 S: DnsRequestSender + 'static + Send,
285 TE: Time + Unpin,
286{
287 Connecting {
288 connect_future: F,
289 outbound_messages: Option<mpsc::Receiver<OneshotDnsRequest>>,
290 sender: Option<BufDnsRequestStreamHandle>,
291 },
292 Connected {
293 exchange: DnsExchange,
294 background: Option<DnsExchangeBackground<S, TE>>,
295 },
296 FailAll {
297 error: ProtoError,
298 outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
299 },
300 Error(ProtoError),
301}
302
303#[allow(clippy::type_complexity)]
304impl<F, S, TE> Future for DnsExchangeConnectInner<F, S, TE>
305where
306 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
307 S: DnsRequestSender + 'static + Send + Unpin,
308 TE: Time + Unpin,
309{
310 type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
311
312 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
313 loop {
314 let next;
315 match *self {
316 Self::Connecting {
317 ref mut connect_future,
318 ref mut outbound_messages,
319 ref mut sender,
320 } => {
321 let connect_future = Pin::new(connect_future);
322 match connect_future.poll(cx) {
323 Poll::Ready(Ok(stream)) => {
324 let (exchange, background) = DnsExchange::from_stream_with_receiver(
327 stream,
328 outbound_messages
329 .take()
330 .expect("cannot poll after complete"),
331 sender.take().expect("cannot poll after complete"),
332 );
333
334 next = Self::Connected {
335 exchange,
336 background: Some(background),
337 };
338 }
339 Poll::Pending => return Poll::Pending,
340 Poll::Ready(Err(error)) => {
341 debug!(error = error.as_dyn(), "stream errored while connecting");
342 next = Self::FailAll {
343 error,
344 outbound_messages: outbound_messages
345 .take()
346 .expect("cannot poll after complete"),
347 }
348 }
349 };
350 }
351 Self::Connected {
352 ref exchange,
353 ref mut background,
354 } => {
355 let exchange = exchange.clone();
356 let background = background.take().expect("cannot poll after complete");
357
358 return Poll::Ready(Ok((exchange, background)));
359 }
360 Self::FailAll {
361 ref error,
362 ref mut outbound_messages,
363 } => {
364 while let Some(outbound_message) = match outbound_messages.poll_next_unpin(cx) {
365 Poll::Ready(opt) => opt,
366 Poll::Pending => return Poll::Pending,
367 } {
368 outbound_message
370 .into_parts()
371 .1
372 .send_response(error.clone().into())
373 .ok();
374 }
375
376 return Poll::Ready(Err(error.clone()));
377 }
378 Self::Error(ref error) => return Poll::Ready(Err(error.clone())),
379 }
380
381 *self = next;
382 }
383 }
384}