broker_tokio/net/udp/socket.rs
1use crate::future::poll_fn;
2use crate::io::PollEvented;
3use crate::net::udp::split::{split, RecvHalf, SendHalf};
4use crate::net::ToSocketAddrs;
5
6use std::convert::TryFrom;
7use std::fmt;
8use std::io;
9use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr};
10use std::task::{Context, Poll};
11
12cfg_udp! {
13 /// A UDP socket
14 pub struct UdpSocket {
15 io: PollEvented<mio::net::UdpSocket>,
16 }
17}
18
19impl UdpSocket {
20 /// This function will create a new UDP socket and attempt to bind it to
21 /// the `addr` provided.
22 pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
23 let addrs = addr.to_socket_addrs().await?;
24 let mut last_err = None;
25
26 for addr in addrs {
27 match UdpSocket::bind_addr(addr) {
28 Ok(socket) => return Ok(socket),
29 Err(e) => last_err = Some(e),
30 }
31 }
32
33 Err(last_err.unwrap_or_else(|| {
34 io::Error::new(
35 io::ErrorKind::InvalidInput,
36 "could not resolve to any addresses",
37 )
38 }))
39 }
40
41 fn bind_addr(addr: SocketAddr) -> io::Result<UdpSocket> {
42 let sys = mio::net::UdpSocket::bind(&addr)?;
43 UdpSocket::new(sys)
44 }
45
46 fn new(socket: mio::net::UdpSocket) -> io::Result<UdpSocket> {
47 let io = PollEvented::new(socket)?;
48 Ok(UdpSocket { io })
49 }
50
51 /// Creates a new `UdpSocket` from the previously bound socket provided.
52 ///
53 /// The socket given will be registered with the event loop that `handle`
54 /// is associated with. This function requires that `socket` has previously
55 /// been bound to an address to work correctly.
56 ///
57 /// This can be used in conjunction with net2's `UdpBuilder` interface to
58 /// configure a socket before it's handed off, such as setting options like
59 /// `reuse_address` or binding to multiple addresses.
60 ///
61 /// # Panics
62 ///
63 /// This function panics if thread-local runtime is not set.
64 ///
65 /// The runtime is usually set implicitly when this function is called
66 /// from a future driven by a tokio runtime, otherwise runtime can be set
67 /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
68 pub fn from_std(socket: net::UdpSocket) -> io::Result<UdpSocket> {
69 let io = mio::net::UdpSocket::from_socket(socket)?;
70 let io = PollEvented::new(io)?;
71 Ok(UdpSocket { io })
72 }
73
74 /// Split the `UdpSocket` into a receive half and a send half. The two parts
75 /// can be used to receive and send datagrams concurrently, even from two
76 /// different tasks.
77 ///
78 /// See the module level documenation of [`split`](super::split) for more
79 /// details.
80 pub fn split(self) -> (RecvHalf, SendHalf) {
81 split(self)
82 }
83
84 /// Returns the local address that this socket is bound to.
85 pub fn local_addr(&self) -> io::Result<SocketAddr> {
86 self.io.get_ref().local_addr()
87 }
88
89 /// Connects the UDP socket setting the default destination for send() and
90 /// limiting packets that are read via recv from the address specified in
91 /// `addr`.
92 pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> {
93 let addrs = addr.to_socket_addrs().await?;
94 let mut last_err = None;
95
96 for addr in addrs {
97 match self.io.get_ref().connect(addr) {
98 Ok(_) => return Ok(()),
99 Err(e) => last_err = Some(e),
100 }
101 }
102
103 Err(last_err.unwrap_or_else(|| {
104 io::Error::new(
105 io::ErrorKind::InvalidInput,
106 "could not resolve to any addresses",
107 )
108 }))
109 }
110
111 /// Returns a future that sends data on the socket to the remote address to which it is connected.
112 /// On success, the future will resolve to the number of bytes written.
113 ///
114 /// The [`connect`] method will connect this socket to a remote address. The future
115 /// will resolve to an error if the socket is not connected.
116 ///
117 /// [`connect`]: #method.connect
118 pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
119 poll_fn(|cx| self.poll_send(cx, buf)).await
120 }
121
122 // Poll IO functions that takes `&self` are provided for the split API.
123 //
124 // They are not public because (taken from the doc of `PollEvented`):
125 //
126 // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the
127 // caller must ensure that there are at most two tasks that use a
128 // `PollEvented` instance concurrently. One for reading and one for writing.
129 // While violating this requirement is "safe" from a Rust memory model point
130 // of view, it will result in unexpected behavior in the form of lost
131 // notifications and tasks hanging.
132 #[doc(hidden)]
133 pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
134 ready!(self.io.poll_write_ready(cx))?;
135
136 match self.io.get_ref().send(buf) {
137 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
138 self.io.clear_write_ready(cx)?;
139 Poll::Pending
140 }
141 x => Poll::Ready(x),
142 }
143 }
144
145 /// Returns a future that receives a single datagram message on the socket from
146 /// the remote address to which it is connected. On success, the future will resolve
147 /// to the number of bytes read.
148 ///
149 /// The function must be called with valid byte array `buf` of sufficient size to
150 /// hold the message bytes. If a message is too long to fit in the supplied buffer,
151 /// excess bytes may be discarded.
152 ///
153 /// The [`connect`] method will connect this socket to a remote address. The future
154 /// will fail if the socket is not connected.
155 ///
156 /// [`connect`]: #method.connect
157 pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
158 poll_fn(|cx| self.poll_recv(cx, buf)).await
159 }
160
161 #[doc(hidden)]
162 pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
163 ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
164
165 match self.io.get_ref().recv(buf) {
166 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
167 self.io.clear_read_ready(cx, mio::Ready::readable())?;
168 Poll::Pending
169 }
170 x => Poll::Ready(x),
171 }
172 }
173
174 /// Returns a future that sends data on the socket to the given address.
175 /// On success, the future will resolve to the number of bytes written.
176 ///
177 /// The future will resolve to an error if the IP version of the socket does
178 /// not match that of `target`.
179 pub async fn send_to<A: ToSocketAddrs>(&mut self, buf: &[u8], target: A) -> io::Result<usize> {
180 let mut addrs = target.to_socket_addrs().await?;
181
182 match addrs.next() {
183 Some(target) => poll_fn(|cx| self.poll_send_to(cx, buf, &target)).await,
184 None => Err(io::Error::new(
185 io::ErrorKind::InvalidInput,
186 "no addresses to send data to",
187 )),
188 }
189 }
190
191 // TODO: Public or not?
192 #[doc(hidden)]
193 pub fn poll_send_to(
194 &self,
195 cx: &mut Context<'_>,
196 buf: &[u8],
197 target: &SocketAddr,
198 ) -> Poll<io::Result<usize>> {
199 ready!(self.io.poll_write_ready(cx))?;
200
201 match self.io.get_ref().send_to(buf, target) {
202 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
203 self.io.clear_write_ready(cx)?;
204 Poll::Pending
205 }
206 x => Poll::Ready(x),
207 }
208 }
209
210 /// Returns a future that receives a single datagram on the socket. On success,
211 /// the future resolves to the number of bytes read and the origin.
212 ///
213 /// The function must be called with valid byte array `buf` of sufficient size
214 /// to hold the message bytes. If a message is too long to fit in the supplied
215 /// buffer, excess bytes may be discarded.
216 pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
217 poll_fn(|cx| self.poll_recv_from(cx, buf)).await
218 }
219
220 #[doc(hidden)]
221 pub fn poll_recv_from(
222 &self,
223 cx: &mut Context<'_>,
224 buf: &mut [u8],
225 ) -> Poll<Result<(usize, SocketAddr), io::Error>> {
226 ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
227
228 match self.io.get_ref().recv_from(buf) {
229 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
230 self.io.clear_read_ready(cx, mio::Ready::readable())?;
231 Poll::Pending
232 }
233 x => Poll::Ready(x),
234 }
235 }
236
237 /// Gets the value of the `SO_BROADCAST` option for this socket.
238 ///
239 /// For more information about this option, see [`set_broadcast`].
240 ///
241 /// [`set_broadcast`]: #method.set_broadcast
242 pub fn broadcast(&self) -> io::Result<bool> {
243 self.io.get_ref().broadcast()
244 }
245
246 /// Sets the value of the `SO_BROADCAST` option for this socket.
247 ///
248 /// When enabled, this socket is allowed to send packets to a broadcast
249 /// address.
250 pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
251 self.io.get_ref().set_broadcast(on)
252 }
253
254 /// Gets the value of the `IP_MULTICAST_LOOP` option for this socket.
255 ///
256 /// For more information about this option, see [`set_multicast_loop_v4`].
257 ///
258 /// [`set_multicast_loop_v4`]: #method.set_multicast_loop_v4
259 pub fn multicast_loop_v4(&self) -> io::Result<bool> {
260 self.io.get_ref().multicast_loop_v4()
261 }
262
263 /// Sets the value of the `IP_MULTICAST_LOOP` option for this socket.
264 ///
265 /// If enabled, multicast packets will be looped back to the local socket.
266 ///
267 /// # Note
268 ///
269 /// This may not have any affect on IPv6 sockets.
270 pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
271 self.io.get_ref().set_multicast_loop_v4(on)
272 }
273
274 /// Gets the value of the `IP_MULTICAST_TTL` option for this socket.
275 ///
276 /// For more information about this option, see [`set_multicast_ttl_v4`].
277 ///
278 /// [`set_multicast_ttl_v4`]: #method.set_multicast_ttl_v4
279 pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
280 self.io.get_ref().multicast_ttl_v4()
281 }
282
283 /// Sets the value of the `IP_MULTICAST_TTL` option for this socket.
284 ///
285 /// Indicates the time-to-live value of outgoing multicast packets for
286 /// this socket. The default value is 1 which means that multicast packets
287 /// don't leave the local network unless explicitly requested.
288 ///
289 /// # Note
290 ///
291 /// This may not have any affect on IPv6 sockets.
292 pub fn set_multicast_ttl_v4(&self, ttl: u32) -> io::Result<()> {
293 self.io.get_ref().set_multicast_ttl_v4(ttl)
294 }
295
296 /// Gets the value of the `IPV6_MULTICAST_LOOP` option for this socket.
297 ///
298 /// For more information about this option, see [`set_multicast_loop_v6`].
299 ///
300 /// [`set_multicast_loop_v6`]: #method.set_multicast_loop_v6
301 pub fn multicast_loop_v6(&self) -> io::Result<bool> {
302 self.io.get_ref().multicast_loop_v6()
303 }
304
305 /// Sets the value of the `IPV6_MULTICAST_LOOP` option for this socket.
306 ///
307 /// Controls whether this socket sees the multicast packets it sends itself.
308 ///
309 /// # Note
310 ///
311 /// This may not have any affect on IPv4 sockets.
312 pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
313 self.io.get_ref().set_multicast_loop_v6(on)
314 }
315
316 /// Gets the value of the `IP_TTL` option for this socket.
317 ///
318 /// For more information about this option, see [`set_ttl`].
319 ///
320 /// [`set_ttl`]: #method.set_ttl
321 pub fn ttl(&self) -> io::Result<u32> {
322 self.io.get_ref().ttl()
323 }
324
325 /// Sets the value for the `IP_TTL` option on this socket.
326 ///
327 /// This value sets the time-to-live field that is used in every packet sent
328 /// from this socket.
329 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
330 self.io.get_ref().set_ttl(ttl)
331 }
332
333 /// Executes an operation of the `IP_ADD_MEMBERSHIP` type.
334 ///
335 /// This function specifies a new multicast group for this socket to join.
336 /// The address must be a valid multicast address, and `interface` is the
337 /// address of the local interface with which the system should join the
338 /// multicast group. If it's equal to `INADDR_ANY` then an appropriate
339 /// interface is chosen by the system.
340 pub fn join_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> {
341 self.io.get_ref().join_multicast_v4(&multiaddr, &interface)
342 }
343
344 /// Executes an operation of the `IPV6_ADD_MEMBERSHIP` type.
345 ///
346 /// This function specifies a new multicast group for this socket to join.
347 /// The address must be a valid multicast address, and `interface` is the
348 /// index of the interface to join/leave (or 0 to indicate any interface).
349 pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
350 self.io.get_ref().join_multicast_v6(multiaddr, interface)
351 }
352
353 /// Executes an operation of the `IP_DROP_MEMBERSHIP` type.
354 ///
355 /// For more information about this option, see [`join_multicast_v4`].
356 ///
357 /// [`join_multicast_v4`]: #method.join_multicast_v4
358 pub fn leave_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> {
359 self.io.get_ref().leave_multicast_v4(&multiaddr, &interface)
360 }
361
362 /// Executes an operation of the `IPV6_DROP_MEMBERSHIP` type.
363 ///
364 /// For more information about this option, see [`join_multicast_v6`].
365 ///
366 /// [`join_multicast_v6`]: #method.join_multicast_v6
367 pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
368 self.io.get_ref().leave_multicast_v6(multiaddr, interface)
369 }
370}
371
372impl TryFrom<UdpSocket> for mio::net::UdpSocket {
373 type Error = io::Error;
374
375 /// Consumes value, returning the mio I/O object.
376 ///
377 /// See [`PollEvented::into_inner`] for more details about
378 /// resource deregistration that happens during the call.
379 ///
380 /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner
381 fn try_from(value: UdpSocket) -> Result<Self, Self::Error> {
382 value.io.into_inner()
383 }
384}
385
386impl TryFrom<net::UdpSocket> for UdpSocket {
387 type Error = io::Error;
388
389 /// Consumes stream, returning the tokio I/O object.
390 ///
391 /// This is equivalent to
392 /// [`UdpSocket::from_std(stream)`](UdpSocket::from_std).
393 fn try_from(stream: net::UdpSocket) -> Result<Self, Self::Error> {
394 Self::from_std(stream)
395 }
396}
397
398impl fmt::Debug for UdpSocket {
399 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400 self.io.get_ref().fmt(f)
401 }
402}
403
404#[cfg(all(unix))]
405mod sys {
406 use super::UdpSocket;
407 use std::os::unix::prelude::*;
408
409 impl AsRawFd for UdpSocket {
410 fn as_raw_fd(&self) -> RawFd {
411 self.io.get_ref().as_raw_fd()
412 }
413 }
414}
415
#[cfg(windows)]
mod sys {
    // TODO: expose the raw socket on Windows once the underlying mio socket
    // provides it upstream.
    //
    // Note that std models Windows sockets with
    // `std::os::windows::io::{AsRawSocket, RawSocket}`, not
    // `AsRawHandle`/`RawHandle` (those are for file-like handles). The impl
    // should therefore look like:
    //
    //     impl AsRawSocket for UdpSocket {
    //         fn as_raw_socket(&self) -> RawSocket {
    //             self.io.get_ref().as_raw_socket()
    //         }
    //     }
}