1use std::io;
9use std::marker::PhantomData;
10use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14
15use async_trait::async_trait;
16use futures_util::stream::Stream;
17use futures_util::{future::Future, ready, TryFutureExt};
18use rand;
19use rand::distributions::{uniform::Uniform, Distribution};
20use tracing::{debug, warn};
21
22use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
23use crate::xfer::{BufDnsStreamHandle, SerialMessage, StreamReceiver};
24use crate::Time;
25
/// Factory that produces a UDP socket of type `S`.
///
/// The closure is called with `(local_bind_addr, remote_addr)` and returns a boxed, `Send`
/// future that resolves to the newly created socket or an I/O error. It is shared behind an
/// `Arc` and must itself be `Send + Sync`.
pub(crate) type UdpCreator<S> = Arc<
    dyn Fn(SocketAddr, SocketAddr) -> Pin<Box<dyn Future<Output = io::Result<S>> + Send>>
        + Send
        + Sync,
>;
34
35#[async_trait]
37pub trait DnsUdpSocket
38where
39 Self: Send + Sync + Sized + Unpin,
40{
41 type Time: Time;
43
44 fn poll_recv_from(
47 &self,
48 cx: &mut Context<'_>,
49 buf: &mut [u8],
50 ) -> Poll<io::Result<(usize, SocketAddr)>>;
51
52 async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
55 futures_util::future::poll_fn(|cx| self.poll_recv_from(cx, buf)).await
56 }
57
58 fn poll_send_to(
60 &self,
61 cx: &mut Context<'_>,
62 buf: &[u8],
63 target: SocketAddr,
64 ) -> Poll<io::Result<usize>>;
65
66 async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
68 futures_util::future::poll_fn(|cx| self.poll_send_to(cx, buf, target)).await
69 }
70}
71
72#[async_trait]
74pub trait UdpSocket: DnsUdpSocket {
75 async fn connect(addr: SocketAddr) -> io::Result<Self>;
77
78 async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self>;
80
81 async fn bind(addr: SocketAddr) -> io::Result<Self>;
83}
84
85#[must_use = "futures do nothing unless polled"]
87pub struct UdpStream<S: Send> {
88 socket: S,
89 outbound_messages: StreamReceiver,
90}
91
/// Types that can report the local socket address they are bound to.
///
/// NOTE(review): the name suggests this exists for the QUIC transport; confirm against callers.
pub trait QuicLocalAddr {
    /// Returns the local address of this socket.
    fn local_addr(&self) -> io::Result<SocketAddr>;
}
97
98#[cfg(feature = "tokio-runtime")]
99use tokio::net::UdpSocket as TokioUdpSocket;
100
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
#[allow(unreachable_pub)]
impl QuicLocalAddr for TokioUdpSocket {
    /// Forwards to tokio's own `UdpSocket::local_addr`.
    fn local_addr(&self) -> std::io::Result<SocketAddr> {
        // Path lookup prefers inherent associated functions over trait ones, so this calls
        // tokio's method rather than recursing into this impl.
        TokioUdpSocket::local_addr(self)
    }
}
109
110impl<S: UdpSocket + Send + 'static> UdpStream<S> {
111 #[allow(clippy::type_complexity)]
125 pub fn new(
126 remote_addr: SocketAddr,
127 bind_addr: Option<SocketAddr>,
128 ) -> (
129 Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
130 BufDnsStreamHandle,
131 ) {
132 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
133
134 let next_socket = NextRandomUdpSocket::new(&remote_addr, &bind_addr);
137
138 let stream = Box::new(next_socket.map_ok(move |socket| Self {
141 socket,
142 outbound_messages,
143 }));
144
145 (stream, message_sender)
146 }
147}
148
149impl<S: DnsUdpSocket + Send + 'static> UdpStream<S> {
150 pub fn with_bound(socket: S, remote_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
165 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
166 let stream = Self {
167 socket,
168 outbound_messages,
169 };
170
171 (stream, message_sender)
172 }
173
174 #[allow(unused)]
175 pub(crate) fn from_parts(socket: S, outbound_messages: StreamReceiver) -> Self {
176 Self {
177 socket,
178 outbound_messages,
179 }
180 }
181}
182
183impl<S: Send> UdpStream<S> {
184 #[allow(clippy::type_complexity)]
185 fn pollable_split(&mut self) -> (&mut S, &mut StreamReceiver) {
186 (&mut self.socket, &mut self.outbound_messages)
187 }
188}
189
190impl<S: DnsUdpSocket + Send + 'static> Stream for UdpStream<S> {
191 type Item = Result<SerialMessage, io::Error>;
192
193 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
194 let (socket, outbound_messages) = self.pollable_split();
195 let socket = Pin::new(socket);
196 let mut outbound_messages = Pin::new(outbound_messages);
197
198 while let Poll::Ready(Some(message)) = outbound_messages.as_mut().poll_peek(cx) {
201 let addr = message.addr();
203
204 if let Err(e) = ready!(socket.poll_send_to(cx, message.bytes(), addr)) {
209 warn!(
211 "error sending message to {} on udp_socket, dropping response: {}",
212 addr, e
213 );
214 }
215
216 assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
218 }
219
220 let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE];
225 let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
226
227 let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
228 Poll::Ready(Some(Ok(serial_message)))
229 }
230}
231
232#[must_use = "futures do nothing unless polled"]
233pub(crate) struct NextRandomUdpSocket<S> {
234 name_server: SocketAddr,
235 bind_address: SocketAddr,
236 closure: UdpCreator<S>,
237 marker: PhantomData<S>,
238}
239
240impl<S: UdpSocket + 'static> NextRandomUdpSocket<S> {
241 pub(crate) fn new(name_server: &SocketAddr, bind_addr: &Option<SocketAddr>) -> Self {
246 let bind_address = match bind_addr {
247 Some(ba) => *ba,
248 None => match *name_server {
249 SocketAddr::V4(..) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
250 SocketAddr::V6(..) => {
251 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
252 }
253 },
254 };
255
256 Self {
257 name_server: *name_server,
258 bind_address,
259 closure: Arc::new(|local_addr: _, _server_addr: _| S::bind(local_addr)),
260 marker: PhantomData,
261 }
262 }
263}
264
265impl<S: DnsUdpSocket> NextRandomUdpSocket<S> {
266 pub(crate) fn new_with_closure(name_server: &SocketAddr, func: UdpCreator<S>) -> Self {
268 let bind_address = match *name_server {
269 SocketAddr::V4(..) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
270 SocketAddr::V6(..) => {
271 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
272 }
273 };
274 Self {
275 name_server: *name_server,
276 bind_address,
277 closure: func,
278 marker: PhantomData,
279 }
280 }
281}
282
283impl<S: DnsUdpSocket + Send> Future for NextRandomUdpSocket<S> {
284 type Output = Result<S, io::Error>;
285
286 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
291 if self.bind_address.port() == 0 {
292 let rand_port_range = Uniform::new_inclusive(49152_u16, u16::MAX);
297 let mut rand = rand::thread_rng();
298
299 for attempt in 0..10 {
300 let port = rand_port_range.sample(&mut rand);
301 let bind_addr = SocketAddr::new(self.bind_address.ip(), port);
302
303 match (*self.closure)(bind_addr, self.name_server)
306 .as_mut()
307 .poll(cx)
308 {
309 Poll::Ready(Ok(socket)) => {
310 debug!("created socket successfully");
311 return Poll::Ready(Ok(socket));
312 }
313 Poll::Ready(Err(err)) => match err.kind() {
314 io::ErrorKind::AddrInUse => {
315 debug!("unable to bind port, attempt: {}: {}", attempt, err);
316 }
317 _ => {
318 debug!("failed to bind port: {}", err);
319 return Poll::Ready(Err(err));
320 }
321 },
322 Poll::Pending => debug!("unable to bind port, attempt: {}", attempt),
323 }
324 }
325
326 debug!("could not get next random port, delaying");
327
328 cx.waker().wake_by_ref();
330
331 Poll::Pending
333 } else {
334 (*self.closure)(self.bind_address, self.name_server)
336 .as_mut()
337 .poll(cx)
338 }
339 }
340}
341
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    /// Binds to the unspecified address of `addr`'s IP family on an OS-assigned port (port 0).
    async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let any_ip = if addr.is_ipv4() {
            IpAddr::V4(Ipv4Addr::UNSPECIFIED)
        } else {
            IpAddr::V6(Ipv6Addr::UNSPECIFIED)
        };

        Self::connect_with_bind(addr, SocketAddr::new(any_ip, 0)).await
    }

    /// Binds to `bind_addr`.
    ///
    /// The remote address is currently ignored: the socket is only bound, not OS-level
    /// connected.
    async fn connect_with_bind(_addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
        Self::bind(bind_addr).await
    }

    /// Binds to `addr`, using tokio's inherent `UdpSocket::bind`.
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        // Path lookup prefers the inherent tokio function over this trait method.
        Self::bind(addr).await
    }
}
371
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl DnsUdpSocket for tokio::net::UdpSocket {
    type Time = crate::TokioTime;

    /// Adapts tokio's `ReadBuf`-based receive to the slice-based trait signature.
    ///
    /// The byte count reported is the filled length of the `ReadBuf` after the call.
    fn poll_recv_from(
        &self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        Self::poll_recv_from(self, cx, &mut read_buf)
            .map_ok(|sender| (read_buf.filled().len(), sender))
    }

    /// Forwards directly to tokio's inherent `UdpSocket::poll_send_to`.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>> {
        tokio::net::UdpSocket::poll_send_to(self, cx, buf, target)
    }
}
398
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::{net::UdpSocket as TokioUdpSocket, runtime::Runtime};

    use crate::tests::{next_random_socket_test, udp_stream_test};

    /// Builds a fresh tokio runtime for a single test.
    fn new_runtime() -> Runtime {
        Runtime::new().expect("failed to create tokio runtime")
    }

    #[test]
    fn test_next_random_socket() {
        next_random_socket_test::<TokioUdpSocket, Runtime>(new_runtime())
    }

    #[test]
    fn test_udp_stream_ipv4() {
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        new_runtime().block_on(udp_stream_test::<TokioUdpSocket>(loopback));
    }

    /// IPv6 loopback variant; compiled out on Linux.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_udp_stream_ipv6() {
        let loopback = IpAddr::V6(Ipv6Addr::LOCALHOST);
        new_runtime().block_on(udp_stream_test::<TokioUdpSocket>(loopback));
    }
}