quinn_udp/
windows.rs

1use std::{
2    io::{self, IoSliceMut},
3    mem,
4    net::{IpAddr, Ipv4Addr},
5    os::windows::io::AsRawSocket,
6    ptr,
7    sync::Mutex,
8    time::Instant,
9};
10
11use libc::{c_int, c_uint};
12use once_cell::sync::Lazy;
13use windows_sys::Win32::Networking::WinSock;
14
15use crate::{
16    cmsg::{self, CMsgHdr},
17    log::debug,
18    log_sendmsg_error, EcnCodepoint, RecvMeta, Transmit, UdpSockRef, IO_ERROR_LOG_INTERVAL,
19};
20
/// QUIC-friendly UDP socket for Windows
///
/// Unlike a standard Windows UDP socket, this allows ECN bits to be read and written.
///
/// The state does not own a socket: every method takes the socket to operate on as a
/// `UdpSockRef` argument.
#[derive(Debug)]
pub struct UdpSocketState {
    /// Timestamp handed to `log_sendmsg_error` whenever `send` swallows an error.
    ///
    /// Initialised in `new` to `2 * IO_ERROR_LOG_INTERVAL` before the construction time
    /// (or to the construction time itself if that subtraction is not representable).
    // NOTE(review): presumably used to rate-limit error logging — confirm in
    // `log_sendmsg_error`.
    last_send_error: Mutex<Instant>,
}
28
29impl UdpSocketState {
30    pub fn new(socket: UdpSockRef<'_>) -> io::Result<Self> {
31        assert!(
32            CMSG_LEN
33                >= WinSock::CMSGHDR::cmsg_space(mem::size_of::<WinSock::IN6_PKTINFO>())
34                    + WinSock::CMSGHDR::cmsg_space(mem::size_of::<c_int>())
35                    + WinSock::CMSGHDR::cmsg_space(mem::size_of::<u32>())
36        );
37        assert!(
38            mem::align_of::<WinSock::CMSGHDR>() <= mem::align_of::<cmsg::Aligned<[u8; 0]>>(),
39            "control message buffers will be misaligned"
40        );
41
42        socket.0.set_nonblocking(true)?;
43        let addr = socket.0.local_addr()?;
44        let is_ipv6 = addr.as_socket_ipv6().is_some();
45        let v6only = unsafe {
46            let mut result: u32 = 0;
47            let mut len = mem::size_of_val(&result) as i32;
48            let rc = WinSock::getsockopt(
49                socket.0.as_raw_socket() as _,
50                WinSock::IPPROTO_IPV6,
51                WinSock::IPV6_V6ONLY as _,
52                &mut result as *mut _ as _,
53                &mut len,
54            );
55            if rc == -1 {
56                return Err(io::Error::last_os_error());
57            }
58            result != 0
59        };
60        let is_ipv4 = addr.as_socket_ipv4().is_some() || !v6only;
61
62        // We don't support old versions of Windows that do not enable access to `WSARecvMsg()`
63        if WSARECVMSG_PTR.is_none() {
64            return Err(io::Error::new(
65                io::ErrorKind::Unsupported,
66                "network stack does not support WSARecvMsg function",
67            ));
68        }
69
70        if is_ipv4 {
71            set_socket_option(
72                &*socket.0,
73                WinSock::IPPROTO_IP,
74                WinSock::IP_DONTFRAGMENT,
75                OPTION_ON,
76            )?;
77
78            set_socket_option(
79                &*socket.0,
80                WinSock::IPPROTO_IP,
81                WinSock::IP_PKTINFO,
82                OPTION_ON,
83            )?;
84            set_socket_option(
85                &*socket.0,
86                WinSock::IPPROTO_IP,
87                WinSock::IP_RECVECN,
88                OPTION_ON,
89            )?;
90        }
91
92        if is_ipv6 {
93            set_socket_option(
94                &*socket.0,
95                WinSock::IPPROTO_IPV6,
96                WinSock::IPV6_DONTFRAG,
97                OPTION_ON,
98            )?;
99
100            set_socket_option(
101                &*socket.0,
102                WinSock::IPPROTO_IPV6,
103                WinSock::IPV6_PKTINFO,
104                OPTION_ON,
105            )?;
106
107            set_socket_option(
108                &*socket.0,
109                WinSock::IPPROTO_IPV6,
110                WinSock::IPV6_RECVECN,
111                OPTION_ON,
112            )?;
113        }
114
115        let now = Instant::now();
116        Ok(Self {
117            last_send_error: Mutex::new(now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now)),
118        })
119    }
120
121    /// Enable or disable receive offloading.
122    ///
123    /// Also referred to as UDP Receive Segment Coalescing Offload (URO) on Windows.
124    ///
125    /// <https://learn.microsoft.com/en-us/windows-hardware/drivers/network/udp-rsc-offload>
126    ///
127    /// Disabled by default on Windows due to <https://github.com/quinn-rs/quinn/issues/2041>.
128    pub fn set_gro(&self, socket: UdpSockRef<'_>, enable: bool) -> io::Result<()> {
129        set_socket_option(
130            &*socket.0,
131            WinSock::IPPROTO_UDP,
132            WinSock::UDP_RECV_MAX_COALESCED_SIZE,
133            match enable {
134                // u32 per
135                // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-udp-socket-options.
136                // Choice of 2^16 - 1 inspired by msquic.
137                true => u16::MAX as u32,
138                false => 0,
139            },
140        )
141    }
142
143    /// Sends a [`Transmit`] on the given socket.
144    ///
145    /// This function will only ever return errors of kind [`io::ErrorKind::WouldBlock`].
146    /// All other errors will be logged and converted to `Ok`.
147    ///
148    /// UDP transmission errors are considered non-fatal because higher-level protocols must
149    /// employ retransmits and timeouts anyway in order to deal with UDP's unreliable nature.
150    /// Thus, logging is most likely the only thing you can do with these errors.
151    ///
152    /// If you would like to handle these errors yourself, use [`UdpSocketState::try_send`]
153    /// instead.
154    pub fn send(&self, socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
155        match send(socket, transmit) {
156            Ok(()) => Ok(()),
157            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Err(e),
158            Err(e) => {
159                log_sendmsg_error(&self.last_send_error, e, transmit);
160
161                Ok(())
162            }
163        }
164    }
165
166    /// Sends a [`Transmit`] on the given socket without any additional error handling.
167    pub fn try_send(&self, socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
168        send(socket, transmit)
169    }
170
171    pub fn recv(
172        &self,
173        socket: UdpSockRef<'_>,
174        bufs: &mut [IoSliceMut<'_>],
175        meta: &mut [RecvMeta],
176    ) -> io::Result<usize> {
177        let wsa_recvmsg_ptr = WSARECVMSG_PTR.expect("valid function pointer for WSARecvMsg");
178
179        // we cannot use [`socket2::MsgHdrMut`] as we do not have access to inner field which holds the WSAMSG
180        let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]);
181        let mut source: WinSock::SOCKADDR_INET = unsafe { mem::zeroed() };
182        let mut data = WinSock::WSABUF {
183            buf: bufs[0].as_mut_ptr(),
184            len: bufs[0].len() as _,
185        };
186
187        let ctrl = WinSock::WSABUF {
188            buf: ctrl_buf.0.as_mut_ptr(),
189            len: ctrl_buf.0.len() as _,
190        };
191
192        let mut wsa_msg = WinSock::WSAMSG {
193            name: &mut source as *mut _ as *mut _,
194            namelen: mem::size_of_val(&source) as _,
195            lpBuffers: &mut data,
196            Control: ctrl,
197            dwBufferCount: 1,
198            dwFlags: 0,
199        };
200
201        let mut len = 0;
202        unsafe {
203            let rc = (wsa_recvmsg_ptr)(
204                socket.0.as_raw_socket() as usize,
205                &mut wsa_msg,
206                &mut len,
207                ptr::null_mut(),
208                None,
209            );
210            if rc == -1 {
211                return Err(io::Error::last_os_error());
212            }
213        }
214
215        let addr = unsafe {
216            let (_, addr) = socket2::SockAddr::try_init(|addr_storage, len| {
217                *len = mem::size_of_val(&source) as _;
218                ptr::copy_nonoverlapping(&source, addr_storage as _, 1);
219                Ok(())
220            })?;
221            addr.as_socket()
222        };
223
224        // Decode control messages (PKTINFO and ECN)
225        let mut ecn_bits = 0;
226        let mut dst_ip = None;
227        let mut stride = len;
228
229        let cmsg_iter = unsafe { cmsg::Iter::new(&wsa_msg) };
230        for cmsg in cmsg_iter {
231            const UDP_COALESCED_INFO: i32 = WinSock::UDP_COALESCED_INFO as i32;
232            // [header (len)][data][padding(len + sizeof(data))] -> [header][data][padding]
233            match (cmsg.cmsg_level, cmsg.cmsg_type) {
234                (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => {
235                    let pktinfo =
236                        unsafe { cmsg::decode::<WinSock::IN_PKTINFO, WinSock::CMSGHDR>(cmsg) };
237                    // Addr is stored in big endian format
238                    let ip4 = Ipv4Addr::from(u32::from_be(unsafe { pktinfo.ipi_addr.S_un.S_addr }));
239                    dst_ip = Some(ip4.into());
240                }
241                (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => {
242                    let pktinfo =
243                        unsafe { cmsg::decode::<WinSock::IN6_PKTINFO, WinSock::CMSGHDR>(cmsg) };
244                    // Addr is stored in big endian format
245                    dst_ip = Some(IpAddr::from(unsafe { pktinfo.ipi6_addr.u.Byte }));
246                }
247                (WinSock::IPPROTO_IP, WinSock::IP_ECN) => {
248                    // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn
249                    ecn_bits = unsafe { cmsg::decode::<c_int, WinSock::CMSGHDR>(cmsg) };
250                }
251                (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => {
252                    // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn
253                    ecn_bits = unsafe { cmsg::decode::<c_int, WinSock::CMSGHDR>(cmsg) };
254                }
255                (WinSock::IPPROTO_UDP, UDP_COALESCED_INFO) => {
256                    // Has type u32 (aka DWORD) per
257                    // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-udp-socket-options
258                    stride = unsafe { cmsg::decode::<u32, WinSock::CMSGHDR>(cmsg) };
259                }
260                _ => {}
261            }
262        }
263
264        meta[0] = RecvMeta {
265            len: len as usize,
266            stride: stride as usize,
267            addr: addr.unwrap(),
268            ecn: EcnCodepoint::from_bits(ecn_bits as u8),
269            dst_ip,
270        };
271        Ok(1)
272    }
273
274    /// The maximum amount of segments which can be transmitted if a platform
275    /// supports Generic Send Offload (GSO).
276    ///
277    /// This is 1 if the platform doesn't support GSO. Subject to change if errors are detected
278    /// while using GSO.
279    #[inline]
280    pub fn max_gso_segments(&self) -> usize {
281        *MAX_GSO_SEGMENTS
282    }
283
284    /// The number of segments to read when GRO is enabled. Used as a factor to
285    /// compute the receive buffer size.
286    ///
287    /// Returns 1 if the platform doesn't support GRO.
288    #[inline]
289    pub fn gro_segments(&self) -> usize {
290        // Arbitrary reasonable value inspired by Linux and msquic
291        64
292    }
293
294    #[inline]
295    pub fn may_fragment(&self) -> bool {
296        false
297    }
298}
299
300fn send(socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
301    // we cannot use [`socket2::sendmsg()`] and [`socket2::MsgHdr`] as we do not have access
302    // to the inner field which holds the WSAMSG
303    let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]);
304    let daddr = socket2::SockAddr::from(transmit.destination);
305
306    let mut data = WinSock::WSABUF {
307        buf: transmit.contents.as_ptr() as *mut _,
308        len: transmit.contents.len() as _,
309    };
310
311    let ctrl = WinSock::WSABUF {
312        buf: ctrl_buf.0.as_mut_ptr(),
313        len: ctrl_buf.0.len() as _,
314    };
315
316    let mut wsa_msg = WinSock::WSAMSG {
317        name: daddr.as_ptr() as *mut _,
318        namelen: daddr.len(),
319        lpBuffers: &mut data,
320        Control: ctrl,
321        dwBufferCount: 1,
322        dwFlags: 0,
323    };
324
325    // Add control messages (ECN and PKTINFO)
326    let mut encoder = unsafe { cmsg::Encoder::new(&mut wsa_msg) };
327
328    if let Some(ip) = transmit.src_ip {
329        let ip = std::net::SocketAddr::new(ip, 0);
330        let ip = socket2::SockAddr::from(ip);
331        match ip.family() {
332            WinSock::AF_INET => {
333                let src_ip = unsafe { ptr::read(ip.as_ptr() as *const WinSock::SOCKADDR_IN) };
334                let pktinfo = WinSock::IN_PKTINFO {
335                    ipi_addr: src_ip.sin_addr,
336                    ipi_ifindex: 0,
337                };
338                encoder.push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, pktinfo);
339            }
340            WinSock::AF_INET6 => {
341                let src_ip = unsafe { ptr::read(ip.as_ptr() as *const WinSock::SOCKADDR_IN6) };
342                let pktinfo = WinSock::IN6_PKTINFO {
343                    ipi6_addr: src_ip.sin6_addr,
344                    ipi6_ifindex: unsafe { src_ip.Anonymous.sin6_scope_id },
345                };
346                encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, pktinfo);
347            }
348            _ => {
349                return Err(io::Error::from(io::ErrorKind::InvalidInput));
350            }
351        }
352    }
353
354    // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn
355    let ecn = transmit.ecn.map_or(0, |x| x as c_int);
356    // True for IPv4 or IPv4-Mapped IPv6
357    let is_ipv4 = transmit.destination.is_ipv4()
358        || matches!(transmit.destination.ip(), IpAddr::V6(addr) if addr.to_ipv4_mapped().is_some());
359    if is_ipv4 {
360        encoder.push(WinSock::IPPROTO_IP, WinSock::IP_ECN, ecn);
361    } else {
362        encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, ecn);
363    }
364
365    // Segment size is a u32 https://learn.microsoft.com/en-us/windows/win32/api/ws2tcpip/nf-ws2tcpip-wsasetudpsendmessagesize
366    if let Some(segment_size) = transmit.segment_size {
367        encoder.push(
368            WinSock::IPPROTO_UDP,
369            WinSock::UDP_SEND_MSG_SIZE,
370            segment_size as u32,
371        );
372    }
373
374    encoder.finish();
375
376    let mut len = 0;
377    let rc = unsafe {
378        WinSock::WSASendMsg(
379            socket.0.as_raw_socket() as usize,
380            &wsa_msg,
381            0,
382            &mut len,
383            ptr::null_mut(),
384            None,
385        )
386    };
387
388    match rc {
389        0 => Ok(()),
390        _ => Err(io::Error::last_os_error()),
391    }
392}
393
394fn set_socket_option(
395    socket: &impl AsRawSocket,
396    level: i32,
397    name: i32,
398    value: u32,
399) -> io::Result<()> {
400    let rc = unsafe {
401        WinSock::setsockopt(
402            socket.as_raw_socket() as usize,
403            level,
404            name,
405            &value as *const _ as _,
406            mem::size_of_val(&value) as _,
407        )
408    };
409
410    match rc == 0 {
411        true => Ok(()),
412        false => Err(io::Error::last_os_error()),
413    }
414}
415
// NOTE(review): presumably the number of messages handled per receive call — `recv` in
// this file always fills exactly one entry; confirm against the callers in the crate.
pub(crate) const BATCH_SIZE: usize = 1;
// Size in bytes of the control message buffers used by `send` and `recv`.
// Enough to store max(IP_PKTINFO + IP_ECN, IPV6_PKTINFO + IPV6_ECN) + max(UDP_SEND_MSG_SIZE, UDP_COALESCED_INFO) bytes (header + data) and some extra margin
const CMSG_LEN: usize = 128;
// Value passed to `set_socket_option` to switch a socket option on.
const OPTION_ON: u32 = 1;
420
421// FIXME this could use [`std::sync::OnceLock`] once the MSRV is bumped to 1.70 and upper
422static WSARECVMSG_PTR: Lazy<WinSock::LPFN_WSARECVMSG> = Lazy::new(|| {
423    let s = unsafe { WinSock::socket(WinSock::AF_INET as _, WinSock::SOCK_DGRAM as _, 0) };
424    if s == WinSock::INVALID_SOCKET {
425        debug!(
426            "ignoring WSARecvMsg function pointer due to socket creation error: {}",
427            io::Error::last_os_error()
428        );
429        return None;
430    }
431
432    // Detect if OS expose WSARecvMsg API based on
433    // https://github.com/Azure/mio-uds-windows/blob/a3c97df82018086add96d8821edb4aa85ec1b42b/src/stdnet/ext.rs#L601
434    let guid = WinSock::WSAID_WSARECVMSG;
435    let mut wsa_recvmsg_ptr = None;
436    let mut len = 0;
437
438    // Safety: Option handles the NULL pointer with a None value
439    let rc = unsafe {
440        WinSock::WSAIoctl(
441            s as _,
442            WinSock::SIO_GET_EXTENSION_FUNCTION_POINTER,
443            &guid as *const _ as *const _,
444            mem::size_of_val(&guid) as u32,
445            &mut wsa_recvmsg_ptr as *mut _ as *mut _,
446            mem::size_of_val(&wsa_recvmsg_ptr) as u32,
447            &mut len,
448            ptr::null_mut(),
449            None,
450        )
451    };
452
453    if rc == -1 {
454        debug!(
455            "ignoring WSARecvMsg function pointer due to ioctl error: {}",
456            io::Error::last_os_error()
457        );
458    } else if len as usize != mem::size_of::<WinSock::LPFN_WSARECVMSG>() {
459        debug!("ignoring WSARecvMsg function pointer due to pointer size mismatch");
460        wsa_recvmsg_ptr = None;
461    }
462
463    unsafe {
464        WinSock::closesocket(s);
465    }
466
467    wsa_recvmsg_ptr
468});
469
470static MAX_GSO_SEGMENTS: Lazy<usize> = Lazy::new(|| {
471    let socket = match std::net::UdpSocket::bind("[::]:0")
472        .or_else(|_| std::net::UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)))
473    {
474        Ok(socket) => socket,
475        Err(_) => return 1,
476    };
477    const GSO_SIZE: c_uint = 1500;
478    match set_socket_option(
479        &socket,
480        WinSock::IPPROTO_UDP,
481        WinSock::UDP_SEND_MSG_SIZE,
482        GSO_SIZE,
483    ) {
484        // Empirically found on Windows 11 x64
485        Ok(()) => 512,
486        Err(_) => 1,
487    }
488});