// broker_tokio/net/udp/socket.rs

1use crate::future::poll_fn;
2use crate::io::PollEvented;
3use crate::net::udp::split::{split, RecvHalf, SendHalf};
4use crate::net::ToSocketAddrs;
5
6use std::convert::TryFrom;
7use std::fmt;
8use std::io;
9use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr};
10use std::task::{Context, Poll};
11
12cfg_udp! {
13    /// A UDP socket
14    pub struct UdpSocket {
15        io: PollEvented<mio::net::UdpSocket>,
16    }
17}
18
19impl UdpSocket {
20    /// This function will create a new UDP socket and attempt to bind it to
21    /// the `addr` provided.
22    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
23        let addrs = addr.to_socket_addrs().await?;
24        let mut last_err = None;
25
26        for addr in addrs {
27            match UdpSocket::bind_addr(addr) {
28                Ok(socket) => return Ok(socket),
29                Err(e) => last_err = Some(e),
30            }
31        }
32
33        Err(last_err.unwrap_or_else(|| {
34            io::Error::new(
35                io::ErrorKind::InvalidInput,
36                "could not resolve to any addresses",
37            )
38        }))
39    }
40
41    fn bind_addr(addr: SocketAddr) -> io::Result<UdpSocket> {
42        let sys = mio::net::UdpSocket::bind(&addr)?;
43        UdpSocket::new(sys)
44    }
45
46    fn new(socket: mio::net::UdpSocket) -> io::Result<UdpSocket> {
47        let io = PollEvented::new(socket)?;
48        Ok(UdpSocket { io })
49    }
50
51    /// Creates a new `UdpSocket` from the previously bound socket provided.
52    ///
53    /// The socket given will be registered with the event loop that `handle`
54    /// is associated with. This function requires that `socket` has previously
55    /// been bound to an address to work correctly.
56    ///
57    /// This can be used in conjunction with net2's `UdpBuilder` interface to
58    /// configure a socket before it's handed off, such as setting options like
59    /// `reuse_address` or binding to multiple addresses.
60    ///
61    /// # Panics
62    ///
63    /// This function panics if thread-local runtime is not set.
64    ///
65    /// The runtime is usually set implicitly when this function is called
66    /// from a future driven by a tokio runtime, otherwise runtime can be set
67    /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
68    pub fn from_std(socket: net::UdpSocket) -> io::Result<UdpSocket> {
69        let io = mio::net::UdpSocket::from_socket(socket)?;
70        let io = PollEvented::new(io)?;
71        Ok(UdpSocket { io })
72    }
73
74    /// Split the `UdpSocket` into a receive half and a send half. The two parts
75    /// can be used to receive and send datagrams concurrently, even from two
76    /// different tasks.
77    ///
78    /// See the module level documenation of [`split`](super::split) for more
79    /// details.
80    pub fn split(self) -> (RecvHalf, SendHalf) {
81        split(self)
82    }
83
84    /// Returns the local address that this socket is bound to.
85    pub fn local_addr(&self) -> io::Result<SocketAddr> {
86        self.io.get_ref().local_addr()
87    }
88
89    /// Connects the UDP socket setting the default destination for send() and
90    /// limiting packets that are read via recv from the address specified in
91    /// `addr`.
92    pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> {
93        let addrs = addr.to_socket_addrs().await?;
94        let mut last_err = None;
95
96        for addr in addrs {
97            match self.io.get_ref().connect(addr) {
98                Ok(_) => return Ok(()),
99                Err(e) => last_err = Some(e),
100            }
101        }
102
103        Err(last_err.unwrap_or_else(|| {
104            io::Error::new(
105                io::ErrorKind::InvalidInput,
106                "could not resolve to any addresses",
107            )
108        }))
109    }
110
111    /// Returns a future that sends data on the socket to the remote address to which it is connected.
112    /// On success, the future will resolve to the number of bytes written.
113    ///
114    /// The [`connect`] method will connect this socket to a remote address. The future
115    /// will resolve to an error if the socket is not connected.
116    ///
117    /// [`connect`]: #method.connect
118    pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
119        poll_fn(|cx| self.poll_send(cx, buf)).await
120    }
121
122    // Poll IO functions that takes `&self` are provided for the split API.
123    //
124    // They are not public because (taken from the doc of `PollEvented`):
125    //
126    // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the
127    // caller must ensure that there are at most two tasks that use a
128    // `PollEvented` instance concurrently. One for reading and one for writing.
129    // While violating this requirement is "safe" from a Rust memory model point
130    // of view, it will result in unexpected behavior in the form of lost
131    // notifications and tasks hanging.
132    #[doc(hidden)]
133    pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
134        ready!(self.io.poll_write_ready(cx))?;
135
136        match self.io.get_ref().send(buf) {
137            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
138                self.io.clear_write_ready(cx)?;
139                Poll::Pending
140            }
141            x => Poll::Ready(x),
142        }
143    }
144
145    /// Returns a future that receives a single datagram message on the socket from
146    /// the remote address to which it is connected. On success, the future will resolve
147    /// to the number of bytes read.
148    ///
149    /// The function must be called with valid byte array `buf` of sufficient size to
150    /// hold the message bytes. If a message is too long to fit in the supplied buffer,
151    /// excess bytes may be discarded.
152    ///
153    /// The [`connect`] method will connect this socket to a remote address. The future
154    /// will fail if the socket is not connected.
155    ///
156    /// [`connect`]: #method.connect
157    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
158        poll_fn(|cx| self.poll_recv(cx, buf)).await
159    }
160
161    #[doc(hidden)]
162    pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
163        ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
164
165        match self.io.get_ref().recv(buf) {
166            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
167                self.io.clear_read_ready(cx, mio::Ready::readable())?;
168                Poll::Pending
169            }
170            x => Poll::Ready(x),
171        }
172    }
173
174    /// Returns a future that sends data on the socket to the given address.
175    /// On success, the future will resolve to the number of bytes written.
176    ///
177    /// The future will resolve to an error if the IP version of the socket does
178    /// not match that of `target`.
179    pub async fn send_to<A: ToSocketAddrs>(&mut self, buf: &[u8], target: A) -> io::Result<usize> {
180        let mut addrs = target.to_socket_addrs().await?;
181
182        match addrs.next() {
183            Some(target) => poll_fn(|cx| self.poll_send_to(cx, buf, &target)).await,
184            None => Err(io::Error::new(
185                io::ErrorKind::InvalidInput,
186                "no addresses to send data to",
187            )),
188        }
189    }
190
191    // TODO: Public or not?
192    #[doc(hidden)]
193    pub fn poll_send_to(
194        &self,
195        cx: &mut Context<'_>,
196        buf: &[u8],
197        target: &SocketAddr,
198    ) -> Poll<io::Result<usize>> {
199        ready!(self.io.poll_write_ready(cx))?;
200
201        match self.io.get_ref().send_to(buf, target) {
202            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
203                self.io.clear_write_ready(cx)?;
204                Poll::Pending
205            }
206            x => Poll::Ready(x),
207        }
208    }
209
210    /// Returns a future that receives a single datagram on the socket. On success,
211    /// the future resolves to the number of bytes read and the origin.
212    ///
213    /// The function must be called with valid byte array `buf` of sufficient size
214    /// to hold the message bytes. If a message is too long to fit in the supplied
215    /// buffer, excess bytes may be discarded.
216    pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
217        poll_fn(|cx| self.poll_recv_from(cx, buf)).await
218    }
219
220    #[doc(hidden)]
221    pub fn poll_recv_from(
222        &self,
223        cx: &mut Context<'_>,
224        buf: &mut [u8],
225    ) -> Poll<Result<(usize, SocketAddr), io::Error>> {
226        ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
227
228        match self.io.get_ref().recv_from(buf) {
229            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
230                self.io.clear_read_ready(cx, mio::Ready::readable())?;
231                Poll::Pending
232            }
233            x => Poll::Ready(x),
234        }
235    }
236
237    /// Gets the value of the `SO_BROADCAST` option for this socket.
238    ///
239    /// For more information about this option, see [`set_broadcast`].
240    ///
241    /// [`set_broadcast`]: #method.set_broadcast
242    pub fn broadcast(&self) -> io::Result<bool> {
243        self.io.get_ref().broadcast()
244    }
245
246    /// Sets the value of the `SO_BROADCAST` option for this socket.
247    ///
248    /// When enabled, this socket is allowed to send packets to a broadcast
249    /// address.
250    pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
251        self.io.get_ref().set_broadcast(on)
252    }
253
254    /// Gets the value of the `IP_MULTICAST_LOOP` option for this socket.
255    ///
256    /// For more information about this option, see [`set_multicast_loop_v4`].
257    ///
258    /// [`set_multicast_loop_v4`]: #method.set_multicast_loop_v4
259    pub fn multicast_loop_v4(&self) -> io::Result<bool> {
260        self.io.get_ref().multicast_loop_v4()
261    }
262
263    /// Sets the value of the `IP_MULTICAST_LOOP` option for this socket.
264    ///
265    /// If enabled, multicast packets will be looped back to the local socket.
266    ///
267    /// # Note
268    ///
269    /// This may not have any affect on IPv6 sockets.
270    pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
271        self.io.get_ref().set_multicast_loop_v4(on)
272    }
273
274    /// Gets the value of the `IP_MULTICAST_TTL` option for this socket.
275    ///
276    /// For more information about this option, see [`set_multicast_ttl_v4`].
277    ///
278    /// [`set_multicast_ttl_v4`]: #method.set_multicast_ttl_v4
279    pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
280        self.io.get_ref().multicast_ttl_v4()
281    }
282
283    /// Sets the value of the `IP_MULTICAST_TTL` option for this socket.
284    ///
285    /// Indicates the time-to-live value of outgoing multicast packets for
286    /// this socket. The default value is 1 which means that multicast packets
287    /// don't leave the local network unless explicitly requested.
288    ///
289    /// # Note
290    ///
291    /// This may not have any affect on IPv6 sockets.
292    pub fn set_multicast_ttl_v4(&self, ttl: u32) -> io::Result<()> {
293        self.io.get_ref().set_multicast_ttl_v4(ttl)
294    }
295
296    /// Gets the value of the `IPV6_MULTICAST_LOOP` option for this socket.
297    ///
298    /// For more information about this option, see [`set_multicast_loop_v6`].
299    ///
300    /// [`set_multicast_loop_v6`]: #method.set_multicast_loop_v6
301    pub fn multicast_loop_v6(&self) -> io::Result<bool> {
302        self.io.get_ref().multicast_loop_v6()
303    }
304
305    /// Sets the value of the `IPV6_MULTICAST_LOOP` option for this socket.
306    ///
307    /// Controls whether this socket sees the multicast packets it sends itself.
308    ///
309    /// # Note
310    ///
311    /// This may not have any affect on IPv4 sockets.
312    pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
313        self.io.get_ref().set_multicast_loop_v6(on)
314    }
315
316    /// Gets the value of the `IP_TTL` option for this socket.
317    ///
318    /// For more information about this option, see [`set_ttl`].
319    ///
320    /// [`set_ttl`]: #method.set_ttl
321    pub fn ttl(&self) -> io::Result<u32> {
322        self.io.get_ref().ttl()
323    }
324
325    /// Sets the value for the `IP_TTL` option on this socket.
326    ///
327    /// This value sets the time-to-live field that is used in every packet sent
328    /// from this socket.
329    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
330        self.io.get_ref().set_ttl(ttl)
331    }
332
333    /// Executes an operation of the `IP_ADD_MEMBERSHIP` type.
334    ///
335    /// This function specifies a new multicast group for this socket to join.
336    /// The address must be a valid multicast address, and `interface` is the
337    /// address of the local interface with which the system should join the
338    /// multicast group. If it's equal to `INADDR_ANY` then an appropriate
339    /// interface is chosen by the system.
340    pub fn join_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> {
341        self.io.get_ref().join_multicast_v4(&multiaddr, &interface)
342    }
343
344    /// Executes an operation of the `IPV6_ADD_MEMBERSHIP` type.
345    ///
346    /// This function specifies a new multicast group for this socket to join.
347    /// The address must be a valid multicast address, and `interface` is the
348    /// index of the interface to join/leave (or 0 to indicate any interface).
349    pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
350        self.io.get_ref().join_multicast_v6(multiaddr, interface)
351    }
352
353    /// Executes an operation of the `IP_DROP_MEMBERSHIP` type.
354    ///
355    /// For more information about this option, see [`join_multicast_v4`].
356    ///
357    /// [`join_multicast_v4`]: #method.join_multicast_v4
358    pub fn leave_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> {
359        self.io.get_ref().leave_multicast_v4(&multiaddr, &interface)
360    }
361
362    /// Executes an operation of the `IPV6_DROP_MEMBERSHIP` type.
363    ///
364    /// For more information about this option, see [`join_multicast_v6`].
365    ///
366    /// [`join_multicast_v6`]: #method.join_multicast_v6
367    pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
368        self.io.get_ref().leave_multicast_v6(multiaddr, interface)
369    }
370}
371
372impl TryFrom<UdpSocket> for mio::net::UdpSocket {
373    type Error = io::Error;
374
375    /// Consumes value, returning the mio I/O object.
376    ///
377    /// See [`PollEvented::into_inner`] for more details about
378    /// resource deregistration that happens during the call.
379    ///
380    /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner
381    fn try_from(value: UdpSocket) -> Result<Self, Self::Error> {
382        value.io.into_inner()
383    }
384}
385
386impl TryFrom<net::UdpSocket> for UdpSocket {
387    type Error = io::Error;
388
389    /// Consumes stream, returning the tokio I/O object.
390    ///
391    /// This is equivalent to
392    /// [`UdpSocket::from_std(stream)`](UdpSocket::from_std).
393    fn try_from(stream: net::UdpSocket) -> Result<Self, Self::Error> {
394        Self::from_std(stream)
395    }
396}
397
398impl fmt::Debug for UdpSocket {
399    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400        self.io.get_ref().fmt(f)
401    }
402}
403
404#[cfg(all(unix))]
405mod sys {
406    use super::UdpSocket;
407    use std::os::unix::prelude::*;
408
409    impl AsRawFd for UdpSocket {
410        fn as_raw_fd(&self) -> RawFd {
411            self.io.get_ref().as_raw_fd()
412        }
413    }
414}
415
#[cfg(windows)]
mod sys {
    // TODO: land raw-socket access upstream in mio, then expose it here.
    //
    // On Windows, sockets are exposed through `AsRawSocket` / `RawSocket`
    // (not `AsRawHandle`, which is for file-like handles). The eventual
    // implementation should therefore look like:
    //
    // use std::os::windows::prelude::*;
    // use super::UdpSocket;
    //
    // impl AsRawSocket for UdpSocket {
    //     fn as_raw_socket(&self) -> RawSocket {
    //         self.io.get_ref().as_raw_socket()
    //     }
    // }
}