1use std::{
2 io::{self, IoSliceMut},
3 mem,
4 net::{IpAddr, Ipv4Addr},
5 os::windows::io::AsRawSocket,
6 ptr,
7 sync::Mutex,
8 time::Instant,
9};
10
11use libc::{c_int, c_uint};
12use once_cell::sync::Lazy;
13use windows_sys::Win32::Networking::WinSock;
14
15use crate::{
16 cmsg::{self, CMsgHdr},
17 log::debug,
18 log_sendmsg_error, EcnCodepoint, RecvMeta, Transmit, UdpSockRef, IO_ERROR_LOG_INTERVAL,
19};
20
/// Per-socket state for UDP I/O through the Windows socket API.
///
/// Create one with [`UdpSocketState::new`], which also configures the socket options that the
/// send and receive methods rely on.
#[derive(Debug)]
pub struct UdpSocketState {
    /// Timestamp handed to `log_sendmsg_error` whenever a send fails.
    ///
    /// NOTE(review): presumably the time of the most recently logged send error, used to
    /// throttle logging — confirm against `log_sendmsg_error`.
    last_send_error: Mutex<Instant>,
}
28
29impl UdpSocketState {
30 pub fn new(socket: UdpSockRef<'_>) -> io::Result<Self> {
31 assert!(
32 CMSG_LEN
33 >= WinSock::CMSGHDR::cmsg_space(mem::size_of::<WinSock::IN6_PKTINFO>())
34 + WinSock::CMSGHDR::cmsg_space(mem::size_of::<c_int>())
35 + WinSock::CMSGHDR::cmsg_space(mem::size_of::<u32>())
36 );
37 assert!(
38 mem::align_of::<WinSock::CMSGHDR>() <= mem::align_of::<cmsg::Aligned<[u8; 0]>>(),
39 "control message buffers will be misaligned"
40 );
41
42 socket.0.set_nonblocking(true)?;
43 let addr = socket.0.local_addr()?;
44 let is_ipv6 = addr.as_socket_ipv6().is_some();
45 let v6only = unsafe {
46 let mut result: u32 = 0;
47 let mut len = mem::size_of_val(&result) as i32;
48 let rc = WinSock::getsockopt(
49 socket.0.as_raw_socket() as _,
50 WinSock::IPPROTO_IPV6,
51 WinSock::IPV6_V6ONLY as _,
52 &mut result as *mut _ as _,
53 &mut len,
54 );
55 if rc == -1 {
56 return Err(io::Error::last_os_error());
57 }
58 result != 0
59 };
60 let is_ipv4 = addr.as_socket_ipv4().is_some() || !v6only;
61
62 if WSARECVMSG_PTR.is_none() {
64 return Err(io::Error::new(
65 io::ErrorKind::Unsupported,
66 "network stack does not support WSARecvMsg function",
67 ));
68 }
69
70 if is_ipv4 {
71 set_socket_option(
72 &*socket.0,
73 WinSock::IPPROTO_IP,
74 WinSock::IP_DONTFRAGMENT,
75 OPTION_ON,
76 )?;
77
78 set_socket_option(
79 &*socket.0,
80 WinSock::IPPROTO_IP,
81 WinSock::IP_PKTINFO,
82 OPTION_ON,
83 )?;
84 set_socket_option(
85 &*socket.0,
86 WinSock::IPPROTO_IP,
87 WinSock::IP_RECVECN,
88 OPTION_ON,
89 )?;
90 }
91
92 if is_ipv6 {
93 set_socket_option(
94 &*socket.0,
95 WinSock::IPPROTO_IPV6,
96 WinSock::IPV6_DONTFRAG,
97 OPTION_ON,
98 )?;
99
100 set_socket_option(
101 &*socket.0,
102 WinSock::IPPROTO_IPV6,
103 WinSock::IPV6_PKTINFO,
104 OPTION_ON,
105 )?;
106
107 set_socket_option(
108 &*socket.0,
109 WinSock::IPPROTO_IPV6,
110 WinSock::IPV6_RECVECN,
111 OPTION_ON,
112 )?;
113 }
114
115 let now = Instant::now();
116 Ok(Self {
117 last_send_error: Mutex::new(now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now)),
118 })
119 }
120
121 pub fn set_gro(&self, socket: UdpSockRef<'_>, enable: bool) -> io::Result<()> {
129 set_socket_option(
130 &*socket.0,
131 WinSock::IPPROTO_UDP,
132 WinSock::UDP_RECV_MAX_COALESCED_SIZE,
133 match enable {
134 true => u16::MAX as u32,
138 false => 0,
139 },
140 )
141 }
142
143 pub fn send(&self, socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
155 match send(socket, transmit) {
156 Ok(()) => Ok(()),
157 Err(e) if e.kind() == io::ErrorKind::WouldBlock => Err(e),
158 Err(e) => {
159 log_sendmsg_error(&self.last_send_error, e, transmit);
160
161 Ok(())
162 }
163 }
164 }
165
    /// Sends `transmit` on `socket` and returns the outcome unchanged.
    ///
    /// Unlike [`UdpSocketState::send`], no error is logged or suppressed here: whatever the
    /// underlying `send` function returns is handed straight back to the caller.
    pub fn try_send(&self, socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
        send(socket, transmit)
    }
170
171 pub fn recv(
172 &self,
173 socket: UdpSockRef<'_>,
174 bufs: &mut [IoSliceMut<'_>],
175 meta: &mut [RecvMeta],
176 ) -> io::Result<usize> {
177 let wsa_recvmsg_ptr = WSARECVMSG_PTR.expect("valid function pointer for WSARecvMsg");
178
179 let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]);
181 let mut source: WinSock::SOCKADDR_INET = unsafe { mem::zeroed() };
182 let mut data = WinSock::WSABUF {
183 buf: bufs[0].as_mut_ptr(),
184 len: bufs[0].len() as _,
185 };
186
187 let ctrl = WinSock::WSABUF {
188 buf: ctrl_buf.0.as_mut_ptr(),
189 len: ctrl_buf.0.len() as _,
190 };
191
192 let mut wsa_msg = WinSock::WSAMSG {
193 name: &mut source as *mut _ as *mut _,
194 namelen: mem::size_of_val(&source) as _,
195 lpBuffers: &mut data,
196 Control: ctrl,
197 dwBufferCount: 1,
198 dwFlags: 0,
199 };
200
201 let mut len = 0;
202 unsafe {
203 let rc = (wsa_recvmsg_ptr)(
204 socket.0.as_raw_socket() as usize,
205 &mut wsa_msg,
206 &mut len,
207 ptr::null_mut(),
208 None,
209 );
210 if rc == -1 {
211 return Err(io::Error::last_os_error());
212 }
213 }
214
215 let addr = unsafe {
216 let (_, addr) = socket2::SockAddr::try_init(|addr_storage, len| {
217 *len = mem::size_of_val(&source) as _;
218 ptr::copy_nonoverlapping(&source, addr_storage as _, 1);
219 Ok(())
220 })?;
221 addr.as_socket()
222 };
223
224 let mut ecn_bits = 0;
226 let mut dst_ip = None;
227 let mut stride = len;
228
229 let cmsg_iter = unsafe { cmsg::Iter::new(&wsa_msg) };
230 for cmsg in cmsg_iter {
231 const UDP_COALESCED_INFO: i32 = WinSock::UDP_COALESCED_INFO as i32;
232 match (cmsg.cmsg_level, cmsg.cmsg_type) {
234 (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => {
235 let pktinfo =
236 unsafe { cmsg::decode::<WinSock::IN_PKTINFO, WinSock::CMSGHDR>(cmsg) };
237 let ip4 = Ipv4Addr::from(u32::from_be(unsafe { pktinfo.ipi_addr.S_un.S_addr }));
239 dst_ip = Some(ip4.into());
240 }
241 (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => {
242 let pktinfo =
243 unsafe { cmsg::decode::<WinSock::IN6_PKTINFO, WinSock::CMSGHDR>(cmsg) };
244 dst_ip = Some(IpAddr::from(unsafe { pktinfo.ipi6_addr.u.Byte }));
246 }
247 (WinSock::IPPROTO_IP, WinSock::IP_ECN) => {
248 ecn_bits = unsafe { cmsg::decode::<c_int, WinSock::CMSGHDR>(cmsg) };
250 }
251 (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => {
252 ecn_bits = unsafe { cmsg::decode::<c_int, WinSock::CMSGHDR>(cmsg) };
254 }
255 (WinSock::IPPROTO_UDP, UDP_COALESCED_INFO) => {
256 stride = unsafe { cmsg::decode::<u32, WinSock::CMSGHDR>(cmsg) };
259 }
260 _ => {}
261 }
262 }
263
264 meta[0] = RecvMeta {
265 len: len as usize,
266 stride: stride as usize,
267 addr: addr.unwrap(),
268 ecn: EcnCodepoint::from_bits(ecn_bits as u8),
269 dst_ip,
270 };
271 Ok(1)
272 }
273
    /// Returns the process-wide `MAX_GSO_SEGMENTS` value.
    ///
    /// That value is computed once, on first use: `512` if a probe socket accepted the
    /// `UDP_SEND_MSG_SIZE` option, `1` otherwise.
    #[inline]
    pub fn max_gso_segments(&self) -> usize {
        *MAX_GSO_SEGMENTS
    }
283
    /// Returns the fixed segment count `64`.
    ///
    /// NOTE(review): presumably callers multiply this by the datagram size to dimension
    /// receive buffers when receive coalescing is enabled — confirm against callers.
    #[inline]
    pub fn gro_segments(&self) -> usize {
        64
    }
293
    /// Always returns `false`.
    ///
    /// NOTE(review): this matches [`UdpSocketState::new`] enabling the `IP_DONTFRAGMENT` /
    /// `IPV6_DONTFRAG` socket options, so sent datagrams presumably are never fragmented —
    /// confirm the intended contract.
    #[inline]
    pub fn may_fragment(&self) -> bool {
        false
    }
298}
299
300fn send(socket: UdpSockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> {
301 let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]);
304 let daddr = socket2::SockAddr::from(transmit.destination);
305
306 let mut data = WinSock::WSABUF {
307 buf: transmit.contents.as_ptr() as *mut _,
308 len: transmit.contents.len() as _,
309 };
310
311 let ctrl = WinSock::WSABUF {
312 buf: ctrl_buf.0.as_mut_ptr(),
313 len: ctrl_buf.0.len() as _,
314 };
315
316 let mut wsa_msg = WinSock::WSAMSG {
317 name: daddr.as_ptr() as *mut _,
318 namelen: daddr.len(),
319 lpBuffers: &mut data,
320 Control: ctrl,
321 dwBufferCount: 1,
322 dwFlags: 0,
323 };
324
325 let mut encoder = unsafe { cmsg::Encoder::new(&mut wsa_msg) };
327
328 if let Some(ip) = transmit.src_ip {
329 let ip = std::net::SocketAddr::new(ip, 0);
330 let ip = socket2::SockAddr::from(ip);
331 match ip.family() {
332 WinSock::AF_INET => {
333 let src_ip = unsafe { ptr::read(ip.as_ptr() as *const WinSock::SOCKADDR_IN) };
334 let pktinfo = WinSock::IN_PKTINFO {
335 ipi_addr: src_ip.sin_addr,
336 ipi_ifindex: 0,
337 };
338 encoder.push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, pktinfo);
339 }
340 WinSock::AF_INET6 => {
341 let src_ip = unsafe { ptr::read(ip.as_ptr() as *const WinSock::SOCKADDR_IN6) };
342 let pktinfo = WinSock::IN6_PKTINFO {
343 ipi6_addr: src_ip.sin6_addr,
344 ipi6_ifindex: unsafe { src_ip.Anonymous.sin6_scope_id },
345 };
346 encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, pktinfo);
347 }
348 _ => {
349 return Err(io::Error::from(io::ErrorKind::InvalidInput));
350 }
351 }
352 }
353
354 let ecn = transmit.ecn.map_or(0, |x| x as c_int);
356 let is_ipv4 = transmit.destination.is_ipv4()
358 || matches!(transmit.destination.ip(), IpAddr::V6(addr) if addr.to_ipv4_mapped().is_some());
359 if is_ipv4 {
360 encoder.push(WinSock::IPPROTO_IP, WinSock::IP_ECN, ecn);
361 } else {
362 encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, ecn);
363 }
364
365 if let Some(segment_size) = transmit.segment_size {
367 encoder.push(
368 WinSock::IPPROTO_UDP,
369 WinSock::UDP_SEND_MSG_SIZE,
370 segment_size as u32,
371 );
372 }
373
374 encoder.finish();
375
376 let mut len = 0;
377 let rc = unsafe {
378 WinSock::WSASendMsg(
379 socket.0.as_raw_socket() as usize,
380 &wsa_msg,
381 0,
382 &mut len,
383 ptr::null_mut(),
384 None,
385 )
386 };
387
388 match rc {
389 0 => Ok(()),
390 _ => Err(io::Error::last_os_error()),
391 }
392}
393
394fn set_socket_option(
395 socket: &impl AsRawSocket,
396 level: i32,
397 name: i32,
398 value: u32,
399) -> io::Result<()> {
400 let rc = unsafe {
401 WinSock::setsockopt(
402 socket.as_raw_socket() as usize,
403 level,
404 name,
405 &value as *const _ as _,
406 mem::size_of_val(&value) as _,
407 )
408 };
409
410 match rc == 0 {
411 true => Ok(()),
412 false => Err(io::Error::last_os_error()),
413 }
414}
415
/// Number of datagram slots handled per receive call; `UdpSocketState::recv` in this file
/// fills exactly one `RecvMeta` entry per call.
pub(crate) const BATCH_SIZE: usize = 1;
/// Size in bytes of the control message buffer used for both sending and receiving.
/// `UdpSocketState::new` asserts that it can hold an `IN6_PKTINFO`, a `c_int` and a `u32`
/// control message together.
const CMSG_LEN: usize = 128;
/// Value passed to `setsockopt` to switch a boolean socket option on.
const OPTION_ON: u32 = 1;
420
421static WSARECVMSG_PTR: Lazy<WinSock::LPFN_WSARECVMSG> = Lazy::new(|| {
423 let s = unsafe { WinSock::socket(WinSock::AF_INET as _, WinSock::SOCK_DGRAM as _, 0) };
424 if s == WinSock::INVALID_SOCKET {
425 debug!(
426 "ignoring WSARecvMsg function pointer due to socket creation error: {}",
427 io::Error::last_os_error()
428 );
429 return None;
430 }
431
432 let guid = WinSock::WSAID_WSARECVMSG;
435 let mut wsa_recvmsg_ptr = None;
436 let mut len = 0;
437
438 let rc = unsafe {
440 WinSock::WSAIoctl(
441 s as _,
442 WinSock::SIO_GET_EXTENSION_FUNCTION_POINTER,
443 &guid as *const _ as *const _,
444 mem::size_of_val(&guid) as u32,
445 &mut wsa_recvmsg_ptr as *mut _ as *mut _,
446 mem::size_of_val(&wsa_recvmsg_ptr) as u32,
447 &mut len,
448 ptr::null_mut(),
449 None,
450 )
451 };
452
453 if rc == -1 {
454 debug!(
455 "ignoring WSARecvMsg function pointer due to ioctl error: {}",
456 io::Error::last_os_error()
457 );
458 } else if len as usize != mem::size_of::<WinSock::LPFN_WSARECVMSG>() {
459 debug!("ignoring WSARecvMsg function pointer due to pointer size mismatch");
460 wsa_recvmsg_ptr = None;
461 }
462
463 unsafe {
464 WinSock::closesocket(s);
465 }
466
467 wsa_recvmsg_ptr
468});
469
470static MAX_GSO_SEGMENTS: Lazy<usize> = Lazy::new(|| {
471 let socket = match std::net::UdpSocket::bind("[::]:0")
472 .or_else(|_| std::net::UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)))
473 {
474 Ok(socket) => socket,
475 Err(_) => return 1,
476 };
477 const GSO_SIZE: c_uint = 1500;
478 match set_socket_option(
479 &socket,
480 WinSock::IPPROTO_UDP,
481 WinSock::UDP_SEND_MSG_SIZE,
482 GSO_SIZE,
483 ) {
484 Ok(()) => 512,
486 Err(_) => 1,
487 }
488});