1#[cfg(feature = "std")]
8use core::fmt::Display;
9use core::fmt::{self, Debug};
10use core::future::Future;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13use core::time::Duration;
14#[cfg(feature = "std")]
15use std::net::SocketAddr;
16
17#[cfg(feature = "std")]
18use futures_channel::mpsc;
19#[cfg(feature = "std")]
20use futures_channel::oneshot;
21use futures_util::ready;
22#[cfg(feature = "std")]
23use futures_util::stream::{Fuse, Peekable};
24use futures_util::stream::{Stream, StreamExt};
25#[cfg(feature = "serde")]
26use serde::{Deserialize, Serialize};
27#[cfg(feature = "std")]
28use tracing::{debug, warn};
29
30use crate::error::{ProtoError, ProtoErrorKind};
31#[cfg(feature = "std")]
32use crate::runtime::Time;
33
34#[cfg(feature = "std")]
35mod dns_exchange;
36pub mod dns_handle;
37#[cfg(feature = "std")]
38pub mod dns_multiplexer;
39pub mod dns_request;
40pub mod dns_response;
41pub mod retry_dns_handle;
42mod serial_message;
43
44#[cfg(feature = "std")]
45pub use self::dns_exchange::{
46 DnsExchange, DnsExchangeBackground, DnsExchangeConnect, DnsExchangeSend,
47};
48pub use self::dns_handle::{DnsHandle, DnsStreamHandle};
49#[cfg(feature = "std")]
50pub use self::dns_multiplexer::{DnsMultiplexer, DnsMultiplexerConnect};
51pub use self::dns_request::{DnsRequest, DnsRequestOptions};
52pub use self::dns_response::DnsResponse;
53#[cfg(feature = "std")]
54pub use self::dns_response::DnsResponseStream;
55pub use self::retry_dns_handle::RetryDnsHandle;
56pub use self::serial_message::SerialMessage;
57
/// Discards the outcome of a channel send, logging any failure instead of
/// propagating it.
///
/// * `Ok(_)` – nothing happens.
/// * `Err` whose `is_disconnected()` is true – logged at `debug` level.
/// * any other `Err` – logged at `warn` level together with the error.
#[cfg(feature = "std")]
fn ignore_send<M, T>(result: Result<M, mpsc::TrySendError<T>>) {
    match result {
        Ok(_) => {}
        Err(error) if error.is_disconnected() => {
            debug!("ignoring send error on disconnected stream");
        }
        Err(error) => {
            warn!("error notifying wait, possible future leak: {:?}", error);
        }
    }
}
70
/// A displayable, sendable [`Stream`] whose items are
/// `Result<SerialMessage, ProtoError>` and which can report a socket address.
#[cfg(feature = "std")]
pub trait DnsClientStream: Stream<Item = Result<SerialMessage, ProtoError>> + Display + Send {
    /// Type implementing the crate's [`Time`] trait that is associated with
    /// this stream.
    type Time: Time;

    /// Returns the socket address tied to this stream.
    ///
    /// NOTE(review): going by the method name this is the address of the remote
    /// name server — confirm against the implementors.
    fn name_server_addr(&self) -> SocketAddr;
}
82
/// Receiving half handed out by [`BufDnsStreamHandle::new`]: an mpsc receiver of
/// [`SerialMessage`]s wrapped in `Fuse` and then `Peekable`.
#[cfg(feature = "std")]
pub type StreamReceiver = Peekable<Fuse<mpsc::Receiver<SerialMessage>>>;
86
/// Buffer size passed to `mpsc::channel` when [`BufDnsStreamHandle::new`]
/// creates its channel.
#[cfg(feature = "std")]
const CHANNEL_BUFFER_SIZE: usize = 32;
89
/// Cloneable sending handle that pairs an mpsc sender of [`SerialMessage`]s with
/// the socket address stamped onto every message it sends.
#[cfg(feature = "std")]
#[derive(Clone)]
pub struct BufDnsStreamHandle {
    /// Address attached to each outgoing message.
    remote_addr: SocketAddr,
    /// Sending half of the channel created in `new`.
    sender: mpsc::Sender<SerialMessage>,
}
99
#[cfg(feature = "std")]
impl BufDnsStreamHandle {
    /// Creates a bounded channel (capacity argument `CHANNEL_BUFFER_SIZE`) and
    /// returns the sending handle together with the fused, peekable receiver.
    ///
    /// # Arguments
    ///
    /// * `remote_addr` - address that will be attached to every message sent
    ///   through the returned handle.
    pub fn new(remote_addr: SocketAddr) -> (Self, StreamReceiver) {
        let (sender, inbound) = mpsc::channel(CHANNEL_BUFFER_SIZE);
        let handle = Self {
            remote_addr,
            sender,
        };

        (handle, inbound.fuse().peekable())
    }

    /// Returns a new handle that feeds the same channel as `self` (the sender is
    /// cloned) but stamps messages with `remote_addr` instead.
    pub fn with_remote_addr(&self, remote_addr: SocketAddr) -> Self {
        let sender = self.sender.clone();
        Self {
            remote_addr,
            sender,
        }
    }
}
130
#[cfg(feature = "std")]
impl DnsStreamHandle for BufDnsStreamHandle {
    /// Re-addresses `buffer` to this handle's `remote_addr` and queues it on the
    /// channel without blocking.
    ///
    /// Only the first element of `buffer.into_parts()` is kept; the rest is
    /// discarded and replaced by `remote_addr`.
    ///
    /// # Errors
    ///
    /// Returns a `ProtoError` built from the formatted send error when
    /// `try_send` fails.
    fn send(&mut self, buffer: SerialMessage) -> Result<(), ProtoError> {
        let payload = buffer.into_parts().0;
        let outgoing = SerialMessage::new(payload, self.remote_addr);

        self.sender
            .try_send(outgoing)
            .map_err(|e| ProtoError::from(format!("mpsc::SendError {e}")))
    }
}
140
/// Something that accepts [`DnsRequest`]s and is itself a
/// `Stream<Item = Result<(), ProtoError>>` that must be `Send + Unpin + 'static`.
#[cfg(feature = "std")]
pub trait DnsRequestSender: Stream<Item = Result<(), ProtoError>> + Send + Unpin + 'static {
    /// Submits `request` and returns the [`DnsResponseStream`] associated with it.
    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream;

    /// Asks the sender to shut down.
    ///
    /// NOTE(review): the exact shutdown semantics are defined by the
    /// implementors, not visible here — confirm before relying on them.
    fn shutdown(&mut self);

    /// Reports whether the sender is shut down.
    ///
    /// NOTE(review): presumably `true` after `shutdown` has taken effect —
    /// confirm against the implementors.
    fn is_shutdown(&self) -> bool;
}
163
/// Cloneable handle that enqueues [`OneshotDnsRequest`]s on an mpsc channel.
#[cfg(feature = "std")]
#[derive(Clone)]
pub struct BufDnsRequestStreamHandle {
    /// Sending half of the request channel.
    sender: mpsc::Sender<OneshotDnsRequest>,
}
170
/// Unwraps a `Result` inside a function that returns `DnsResponseReceiver`.
///
/// * `try_oneshot!(expr)` evaluates to the `Ok` value; on `Err` it makes the
///   *enclosing function* return `DnsResponseReceiver::Err(Some(..))` with the
///   error converted through `ProtoError::from`.
/// * `try_oneshot!(expr,)` (trailing comma) expands to `expr?`.
#[cfg(feature = "std")]
macro_rules! try_oneshot {
    ($expr:expr) => {{
        match $expr {
            ::core::result::Result::Ok(value) => value,
            ::core::result::Result::Err(error) => {
                return DnsResponseReceiver::Err(Some(ProtoError::from(error)));
            }
        }
    }};
    ($expr:expr,) => {
        $expr?
    };
}
185
#[cfg(feature = "std")]
impl DnsHandle for BufDnsRequestStreamHandle {
    type Response = DnsResponseReceiver;

    /// Converts `request` into a [`DnsRequest`], pairs it with a oneshot channel
    /// for the reply and tries to enqueue it without blocking.
    ///
    /// Returns `DnsResponseReceiver::Receiver` holding the oneshot receiver on
    /// success. If the request cannot be enqueued the failure is logged and
    /// `DnsResponseReceiver::Err` carrying `ProtoErrorKind::Busy` is returned.
    fn send<R: Into<DnsRequest>>(&self, request: R) -> Self::Response {
        let request: DnsRequest = request.into();
        debug!(
            "enqueueing message:{}:{:?}",
            request.op_code(),
            request.queries()
        );

        let (envelope, response_rx) = OneshotDnsRequest::oneshot(request);

        // `try_send` needs `&mut`, while this method only has `&self`, so a clone
        // of the sender is used.
        let mut sender = self.sender.clone();
        let enqueued = sender.try_send(envelope).map_err(|_| {
            debug!("unable to enqueue message");
            ProtoError::from(ProtoErrorKind::Busy)
        });
        try_oneshot!(enqueued);

        DnsResponseReceiver::Receiver(response_rx)
    }
}
209
/// A [`DnsRequest`] bundled with the oneshot sender on which its
/// [`DnsResponseStream`] is to be delivered.
#[cfg(feature = "std")]
pub struct OneshotDnsRequest {
    /// The request itself.
    dns_request: DnsRequest,
    /// Channel end used to hand the response stream back to the requester.
    sender_for_response: oneshot::Sender<DnsResponseStream>,
}
217
#[cfg(feature = "std")]
impl OneshotDnsRequest {
    /// Wraps `dns_request` together with the sending half of a fresh oneshot
    /// channel and returns it alongside the matching receiving half.
    // The surrounding `impl` already requires the `std` feature.
    #[cfg(any(feature = "std", feature = "no-std-rand"))]
    fn oneshot(dns_request: DnsRequest) -> (Self, oneshot::Receiver<DnsResponseStream>) {
        let (sender_for_response, receiver) = oneshot::channel();
        let wrapped = Self {
            dns_request,
            sender_for_response,
        };

        (wrapped, receiver)
    }

    /// Splits the value into the request and an [`OneshotDnsResponse`] wrapping
    /// the response sender.
    fn into_parts(self) -> (DnsRequest, OneshotDnsResponse) {
        let Self {
            dns_request,
            sender_for_response,
        } = self;

        (dns_request, OneshotDnsResponse(sender_for_response))
    }
}
240
/// Newtype around the oneshot sender through which a [`DnsResponseStream`] is
/// delivered.
#[cfg(feature = "std")]
struct OneshotDnsResponse(oneshot::Sender<DnsResponseStream>);
243
#[cfg(feature = "std")]
impl OneshotDnsResponse {
    /// Consumes the wrapper and sends `serial_response` over the oneshot channel.
    ///
    /// # Errors
    ///
    /// Forwards the result of the underlying oneshot `send`; on failure the
    /// stream that could not be delivered is handed back as the `Err` value.
    fn send_response(self, serial_response: DnsResponseStream) -> Result<(), DnsResponseStream> {
        let Self(sender) = self;
        sender.send(serial_response)
    }
}
250
/// State of a pending response, consumed as a stream of
/// `Result<DnsResponse, ProtoError>`.
#[cfg(feature = "std")]
pub enum DnsResponseReceiver {
    /// Still waiting for the [`DnsResponseStream`] to arrive on the oneshot
    /// channel.
    Receiver(oneshot::Receiver<DnsResponseStream>),
    /// The response stream has arrived; items are forwarded from it.
    Received(DnsResponseStream),
    /// An error to report. The option is emptied when the error is yielded,
    /// after which the stream yields `None`.
    Err(Option<ProtoError>),
}
261
#[cfg(feature = "std")]
impl Stream for DnsResponseReceiver {
    type Item = Result<DnsResponse, ProtoError>;

    /// Drives the receiver through its states.
    ///
    /// * `Receiver` – polls the oneshot channel. Once the response stream
    ///   arrives the state becomes `Received`. If the channel was canceled the
    ///   state becomes `Err` holding a "receiver was canceled" error.
    /// * `Received` – forwards whatever the inner response stream yields.
    /// * `Err` – yields the stored error once, then `None` on every later poll.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            *self = match &mut *self {
                Self::Receiver(receiver) => match ready!(Pin::new(receiver).poll(cx)) {
                    Ok(stream) => Self::Received(stream),
                    // Move to the terminal `Err` state rather than returning
                    // directly: the error is still reported by this poll, but
                    // the finished oneshot receiver is never polled again and
                    // the stream ends afterwards instead of repeating the
                    // cancellation error on later polls.
                    Err(_) => Self::Err(Some(ProtoError::from("receiver was canceled"))),
                },
                Self::Received(stream) => return stream.poll_next_unpin(cx),
                Self::Err(err) => return Poll::Ready(err.take().map(Err)),
            };
        }
    }
}
286
287pub trait FirstAnswer<T, E: From<ProtoError>>: Stream<Item = Result<T, E>> + Unpin + Sized {
289 fn first_answer(self) -> FirstAnswerFuture<Self> {
292 FirstAnswerFuture { stream: Some(self) }
293 }
294}
295
296impl<E, S, T> FirstAnswer<T, E> for S
297where
298 S: Stream<Item = Result<T, E>> + Unpin + Sized,
299 E: From<ProtoError>,
300{
301}
302
/// Future produced by `FirstAnswer::first_answer`; it holds the wrapped stream
/// in an `Option`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct FirstAnswerFuture<S> {
    /// The wrapped stream; `Some` when the future is created.
    stream: Option<S>,
}
309
310impl<E, S: Stream<Item = Result<T, E>> + Unpin, T> Future for FirstAnswerFuture<S>
311where
312 S: Stream<Item = Result<T, E>> + Unpin + Sized,
313 E: From<ProtoError>,
314{
315 type Output = S::Item;
316
317 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
318 let s = self
319 .stream
320 .as_mut()
321 .expect("polling FirstAnswerFuture twice");
322 let item = match ready!(s.poll_next_unpin(cx)) {
323 Some(r) => r,
324 None => Err(ProtoError::from(ProtoErrorKind::Timeout).into()),
325 };
326 self.stream.take();
327 Poll::Ready(item)
328 }
329}
330
/// Transport protocols known to this module.
///
/// The `Tls`, `Https`, `Quic` and `H3` variants only exist when the matching
/// `__tls`, `__https`, `__quic` or `__h3` feature is enabled. With the `serde`
/// feature the variants (de)serialize under their lowercase names.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "lowercase")
)]
#[non_exhaustive]
pub enum Protocol {
    /// Displayed as `udp`; datagram based, not encrypted. This is the default.
    Udp,
    /// Displayed as `tcp`; stream based, not encrypted.
    Tcp,
    /// Displayed as `tls`; stream based, encrypted.
    #[cfg(feature = "__tls")]
    Tls,
    /// Displayed as `https`; stream based, encrypted.
    #[cfg(feature = "__https")]
    Https,
    /// Displayed as `quic`; reported as datagram based, encrypted.
    #[cfg(feature = "__quic")]
    Quic,
    /// Displayed as `h3`; reported as datagram based, encrypted.
    #[cfg(feature = "__h3")]
    H3,
}

impl fmt::Display for Protocol {
    /// Writes the lowercase name of the protocol, e.g. `udp` or `tcp`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Udp => "udp",
            Self::Tcp => "tcp",
            #[cfg(feature = "__tls")]
            Self::Tls => "tls",
            #[cfg(feature = "__https")]
            Self::Https => "https",
            #[cfg(feature = "__quic")]
            Self::Quic => "quic",
            #[cfg(feature = "__h3")]
            Self::H3 => "h3",
        })
    }
}

impl Protocol {
    /// Returns `true` for the protocols classified as datagram oriented
    /// (`Udp`, and `Quic`/`H3` when enabled), `false` for all others.
    pub fn is_datagram(self) -> bool {
        // Every variant is listed explicitly so that adding one forces a
        // decision here.
        match self {
            Self::Udp => true,
            #[cfg(feature = "__quic")]
            Self::Quic => true,
            #[cfg(feature = "__h3")]
            Self::H3 => true,
            Self::Tcp => false,
            #[cfg(feature = "__tls")]
            Self::Tls => false,
            #[cfg(feature = "__https")]
            Self::Https => false,
        }
    }

    /// Returns `true` exactly when [`Protocol::is_datagram`] returns `false`.
    pub fn is_stream(self) -> bool {
        let datagram = self.is_datagram();
        !datagram
    }

    /// Returns `true` for every protocol except `Udp` and `Tcp`.
    pub fn is_encrypted(self) -> bool {
        // Every variant is listed explicitly so that adding one forces a
        // decision here.
        match self {
            Self::Udp => false,
            Self::Tcp => false,
            #[cfg(feature = "__tls")]
            Self::Tls => true,
            #[cfg(feature = "__https")]
            Self::Https => true,
            #[cfg(feature = "__quic")]
            Self::Quic => true,
            #[cfg(feature = "__h3")]
            Self::H3 => true,
        }
    }
}

impl Default for Protocol {
    /// The default protocol is `Udp`.
    fn default() -> Self {
        Protocol::Udp
    }
}
423
/// Crate-wide connect timeout of five seconds.
// NOTE(review): `allow(unused)` presumably covers feature combinations in which
// nothing references this constant — confirm.
#[allow(unused)]
pub(crate) const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);