solana_streamer/
sendmmsg.rs

1//! The `sendmmsg` module provides sendmmsg() API implementation
2
3#[cfg(target_os = "linux")]
4use {
5    crate::msghdr::create_msghdr,
6    itertools::izip,
7    libc::{iovec, mmsghdr, sockaddr_in, sockaddr_in6, sockaddr_storage, socklen_t},
8    std::{
9        mem::{self, MaybeUninit},
10        os::unix::io::AsRawFd,
11        ptr,
12    },
13};
14use {
15    solana_transaction_error::TransportError,
16    std::{
17        borrow::Borrow,
18        io,
19        iter::repeat,
20        net::{SocketAddr, UdpSocket},
21    },
22    thiserror::Error,
23};
24
25#[derive(Debug, Error)]
26pub enum SendPktsError {
27    /// IO Error during send: first error, num failed packets
28    #[error("IO Error, some packets could not be sent")]
29    IoError(io::Error, usize),
30}
31
32impl From<SendPktsError> for TransportError {
33    fn from(err: SendPktsError) -> Self {
34        Self::Custom(format!("{err:?}"))
35    }
36}
37
/// Sends every `(payload, destination)` pair in `packets` through `sock`, one
/// `send_to` call per packet.
///
/// A failed send does not stop the loop: the remaining packets are still
/// attempted.
///
/// # Errors
///
/// Returns [`SendPktsError::IoError`] holding the first `io::Error` seen and
/// the total number of packets whose send failed.
#[cfg(not(target_os = "linux"))]
pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
where
    S: Borrow<SocketAddr>,
    T: AsRef<[u8]>,
{
    let mut first_error: Option<io::Error> = None;
    let mut failures: usize = 0;

    for (payload, dest) in packets {
        match sock.send_to(payload.as_ref(), dest.borrow()) {
            Ok(_) => {}
            Err(e) => {
                failures += 1;
                // Only the earliest error is reported to the caller.
                first_error = first_error.or(Some(e));
            }
        }
    }

    match first_error {
        Some(err) => Err(SendPktsError::IoError(err, failures)),
        None => Ok(()),
    }
}
61
62#[cfg(target_os = "linux")]
63fn mmsghdr_for_packet(
64    packet: &[u8],
65    dest: &SocketAddr,
66    iov: &mut MaybeUninit<iovec>,
67    addr: &mut MaybeUninit<sockaddr_storage>,
68    hdr: &mut MaybeUninit<mmsghdr>,
69) {
70    const SIZE_OF_SOCKADDR_IN: usize = mem::size_of::<sockaddr_in>();
71    const SIZE_OF_SOCKADDR_IN6: usize = mem::size_of::<sockaddr_in6>();
72    const SIZE_OF_SOCKADDR_STORAGE: usize = mem::size_of::<sockaddr_storage>();
73    const SOCKADDR_IN_PADDING: usize = SIZE_OF_SOCKADDR_STORAGE - SIZE_OF_SOCKADDR_IN;
74    const SOCKADDR_IN6_PADDING: usize = SIZE_OF_SOCKADDR_STORAGE - SIZE_OF_SOCKADDR_IN6;
75
76    iov.write(iovec {
77        iov_base: packet.as_ptr() as *mut libc::c_void,
78        iov_len: packet.len(),
79    });
80
81    let msg_namelen = match dest {
82        SocketAddr::V4(socket_addr_v4) => {
83            let ptr: *mut sockaddr_in = addr.as_mut_ptr() as *mut _;
84            unsafe {
85                ptr::write(
86                    ptr,
87                    *nix::sys::socket::SockaddrIn::from(*socket_addr_v4).as_ref(),
88                );
89                // Zero the remaining bytes after sockaddr_in
90                ptr::write_bytes(
91                    (ptr as *mut u8).add(SIZE_OF_SOCKADDR_IN),
92                    0,
93                    SOCKADDR_IN_PADDING,
94                );
95            }
96            SIZE_OF_SOCKADDR_IN as socklen_t
97        }
98        SocketAddr::V6(socket_addr_v6) => {
99            let ptr: *mut sockaddr_in6 = addr.as_mut_ptr() as *mut _;
100            unsafe {
101                ptr::write(
102                    ptr,
103                    *nix::sys::socket::SockaddrIn6::from(*socket_addr_v6).as_ref(),
104                );
105                // Zero the remaining bytes after sockaddr_in6
106                ptr::write_bytes(
107                    (ptr as *mut u8).add(SIZE_OF_SOCKADDR_IN6),
108                    0,
109                    SOCKADDR_IN6_PADDING,
110                );
111            }
112            SIZE_OF_SOCKADDR_IN6 as socklen_t
113        }
114    };
115
116    let msg_hdr = create_msghdr(addr, msg_namelen, iov);
117
118    hdr.write(mmsghdr {
119        msg_len: 0,
120        msg_hdr,
121    });
122}
123
124#[cfg(target_os = "linux")]
125fn sendmmsg_retry(sock: &UdpSocket, hdrs: &mut [mmsghdr]) -> Result<(), SendPktsError> {
126    let sock_fd = sock.as_raw_fd();
127    let mut total_sent = 0;
128    let mut erropt = None;
129
130    let mut pkts = &mut *hdrs;
131    while !pkts.is_empty() {
132        let npkts = match unsafe { libc::sendmmsg(sock_fd, &mut pkts[0], pkts.len() as u32, 0) } {
133            -1 => {
134                if erropt.is_none() {
135                    erropt = Some(io::Error::last_os_error());
136                }
137                // skip over the failing packet
138                1_usize
139            }
140            n => {
141                // if we fail to send all packets we advance to the failing
142                // packet and retry in order to capture the error code
143                total_sent += n as usize;
144                n as usize
145            }
146        };
147        pkts = &mut pkts[npkts..];
148    }
149
150    if let Some(err) = erropt {
151        Err(SendPktsError::IoError(err, hdrs.len() - total_sent))
152    } else {
153        Ok(())
154    }
155}
156
/// Maximum number of packets handled by one `batch_send_max_iov` call, taken
/// from `libc::UIO_MAXIOV`. `batch_send` splits its input into chunks of at
/// most this many packets.
#[cfg(target_os = "linux")]
const MAX_IOV: usize = libc::UIO_MAXIOV as usize;
159
160#[cfg(target_os = "linux")]
161pub fn batch_send_max_iov<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
162where
163    S: Borrow<SocketAddr>,
164    T: AsRef<[u8]>,
165{
166    assert!(packets.len() <= MAX_IOV);
167
168    let mut iovs = [MaybeUninit::uninit(); MAX_IOV];
169    let mut addrs = [MaybeUninit::uninit(); MAX_IOV];
170    let mut hdrs = [MaybeUninit::uninit(); MAX_IOV];
171
172    // izip! will iterate packets.len() times, leaving hdrs, iovs, and addrs initialized only up to packets.len()
173    for ((pkt, dest), hdr, iov, addr) in izip!(packets, &mut hdrs, &mut iovs, &mut addrs) {
174        mmsghdr_for_packet(pkt.as_ref(), dest.borrow(), iov, addr, hdr);
175    }
176
177    // SAFETY: The first `packets.len()` elements of `hdrs`, `iovs`, and `addrs` are
178    // guaranteed to be initialized by `mmsghdr_for_packet` before this loop.
179    let hdrs_slice =
180        unsafe { std::slice::from_raw_parts_mut(hdrs.as_mut_ptr() as *mut mmsghdr, packets.len()) };
181
182    let result = sendmmsg_retry(sock, hdrs_slice);
183
184    // SAFETY: The first `packets.len()` elements of `hdrs`, `iovs`, and `addrs` are
185    // guaranteed to be initialized by `mmsghdr_for_packet` before this loop.
186    for (hdr, iov, addr) in izip!(&mut hdrs, &mut iovs, &mut addrs).take(packets.len()) {
187        unsafe {
188            hdr.assume_init_drop();
189            iov.assume_init_drop();
190            addr.assume_init_drop();
191        }
192    }
193
194    result
195}
196
197#[cfg(target_os = "linux")]
198pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
199where
200    S: Borrow<SocketAddr>,
201    T: AsRef<[u8]>,
202{
203    for chunk in packets.chunks(MAX_IOV) {
204        batch_send_max_iov(sock, chunk)?;
205    }
206    Ok(())
207}
208
209pub fn multi_target_send<S, T>(
210    sock: &UdpSocket,
211    packet: T,
212    dests: &[S],
213) -> Result<(), SendPktsError>
214where
215    S: Borrow<SocketAddr>,
216    T: AsRef<[u8]>,
217{
218    let dests = dests.iter().map(Borrow::borrow);
219    let pkts: Vec<_> = repeat(&packet).zip(dests).collect();
220    batch_send(sock, &pkts)
221}
222
#[cfg(test)]
mod tests {
    use {
        crate::{
            packet::Packet,
            recvmmsg::recv_mmsg,
            sendmmsg::{batch_send, multi_target_send, SendPktsError},
        },
        assert_matches::assert_matches,
        solana_net_utils::{bind_to_localhost, bind_to_unspecified},
        solana_packet::PACKET_DATA_SIZE,
        std::{
            io::ErrorKind,
            net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
        },
    };

    /// Builds `count` zero-filled payloads of `PACKET_DATA_SIZE` bytes each.
    fn zeroed_payloads(count: usize) -> Vec<Vec<u8>> {
        (0..count).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect()
    }

    /// Pairs each payload with the destination at the same position.
    fn pair_up<'a>(
        payloads: &'a [Vec<u8>],
        dests: &[&'a SocketAddr],
    ) -> Vec<(&'a [u8], &'a SocketAddr)> {
        payloads
            .iter()
            .map(Vec::as_slice)
            .zip(dests.iter().copied())
            .collect()
    }

    /// Runs one `recv_mmsg` call into a 32-slot buffer and checks how many
    /// packets it reported.
    fn assert_received(reader: &UdpSocket, expected: usize) {
        let mut slots = vec![Packet::default(); 32];
        let recv = recv_mmsg(reader, &mut slots[..]).unwrap();
        assert_eq!(expected, recv);
    }

    /// Checks that `res` is a `PermissionDenied` IO error reporting exactly
    /// `expected_failed` failed packets.
    fn assert_permission_denied(res: Result<(), SendPktsError>, expected_failed: usize) {
        match res {
            Ok(()) => panic!(),
            Err(SendPktsError::IoError(ioerror, num_failed)) => {
                assert_matches!(ioerror.kind(), ErrorKind::PermissionDenied);
                assert_eq!(num_failed, expected_failed);
            }
        }
    }

    #[test]
    pub fn test_send_mmsg_one_dest() {
        let reader = bind_to_localhost().expect("bind");
        let addr = reader.local_addr().unwrap();
        let sender = bind_to_localhost().expect("bind");

        let packets = zeroed_payloads(32);
        let packet_refs: Vec<_> = packets.iter().map(|p| (&p[..], &addr)).collect();

        assert_eq!(batch_send(&sender, &packet_refs[..]).ok(), Some(()));

        assert_received(&reader, 32);
    }

    #[test]
    pub fn test_send_mmsg_multi_dest() {
        let reader = bind_to_localhost().expect("bind");
        let addr = reader.local_addr().unwrap();

        let reader2 = bind_to_localhost().expect("bind");
        let addr2 = reader2.local_addr().unwrap();

        let sender = bind_to_localhost().expect("bind");

        // First half of the batch goes to `reader`, second half to `reader2`.
        let packets = zeroed_payloads(32);
        let packet_refs: Vec<_> = packets
            .iter()
            .enumerate()
            .map(|(i, p)| {
                let dest = if i < 16 { &addr } else { &addr2 };
                (&p[..], dest)
            })
            .collect();

        assert_eq!(batch_send(&sender, &packet_refs[..]).ok(), Some(()));

        assert_received(&reader, 16);
        assert_received(&reader2, 16);
    }

    #[test]
    pub fn test_multicast_msg() {
        let readers: Vec<UdpSocket> = (0..4)
            .map(|_| bind_to_localhost().expect("bind"))
            .collect();
        let addrs: Vec<SocketAddr> = readers
            .iter()
            .map(|reader| reader.local_addr().unwrap())
            .collect();
        let dest_refs: Vec<&SocketAddr> = addrs.iter().collect();

        let sender = bind_to_localhost().expect("bind");

        let packet = Packet::default();

        let sent = multi_target_send(&sender, packet.data(..).unwrap(), &dest_refs).ok();
        assert_eq!(sent, Some(()));

        // Every reader should have received exactly one copy.
        for reader in &readers {
            assert_received(reader, 1);
        }
    }

    #[test]
    fn test_intermediate_failures_mismatched_bind() {
        let packets = zeroed_payloads(3);
        let ip4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ip6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080);
        let dest_refs: Vec<_> = vec![&ip4, &ip6, &ip4];
        let packet_refs = pair_up(&packets, &dest_refs);

        let sender = bind_to_unspecified().expect("bind");
        let res = batch_send(&sender, &packet_refs[..]);
        assert_matches!(res, Err(SendPktsError::IoError(_, /*num_failed*/ 1)));
        let res = multi_target_send(&sender, &packets[0], &dest_refs);
        assert_matches!(res, Err(SendPktsError::IoError(_, /*num_failed*/ 1)));
    }

    #[test]
    fn test_intermediate_failures_unreachable_address() {
        let packets = zeroed_payloads(5);
        let ipv4local = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ipv4broadcast = SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), 8080);
        let sender = bind_to_unspecified().expect("bind");

        // test intermediate failures for batch_send
        let dests = [
            &ipv4local,
            &ipv4broadcast,
            &ipv4local,
            &ipv4broadcast,
            &ipv4local,
        ];
        let packet_refs = pair_up(&packets, &dests);
        assert_permission_denied(batch_send(&sender, &packet_refs[..]), 2);

        // test leading and trailing failures for batch_send
        let dests = [
            &ipv4broadcast,
            &ipv4local,
            &ipv4broadcast,
            &ipv4local,
            &ipv4broadcast,
        ];
        let packet_refs = pair_up(&packets, &dests);
        assert_permission_denied(batch_send(&sender, &packet_refs[..]), 3);

        // test consecutive intermediate failures for batch_send
        let dests = [
            &ipv4local,
            &ipv4local,
            &ipv4broadcast,
            &ipv4broadcast,
            &ipv4local,
        ];
        let packet_refs = pair_up(&packets, &dests);
        assert_permission_denied(batch_send(&sender, &packet_refs[..]), 2);

        // test intermediate failures for multi_target_send
        let dest_refs: Vec<_> = vec![
            &ipv4local,
            &ipv4broadcast,
            &ipv4local,
            &ipv4broadcast,
            &ipv4local,
        ];
        assert_permission_denied(multi_target_send(&sender, &packets[0], &dest_refs), 2);

        // test leading and trailing failures for multi_target_send
        let dest_refs: Vec<_> = vec![
            &ipv4broadcast,
            &ipv4local,
            &ipv4broadcast,
            &ipv4local,
            &ipv4broadcast,
        ];
        assert_permission_denied(multi_target_send(&sender, &packets[0], &dest_refs), 3);
    }
}