solana_streamer/
sendmmsg.rs

1//! The `sendmmsg` module provides sendmmsg() API implementation
2
3#[cfg(target_os = "linux")]
4use {
5    itertools::izip,
6    libc::{iovec, mmsghdr, msghdr, sockaddr_in, sockaddr_in6, sockaddr_storage, socklen_t},
7    std::{
8        mem::{self, MaybeUninit},
9        os::unix::io::AsRawFd,
10        ptr,
11    },
12};
13use {
14    solana_sdk::transport::TransportError,
15    std::{
16        borrow::Borrow,
17        io,
18        iter::repeat,
19        net::{SocketAddr, UdpSocket},
20    },
21    thiserror::Error,
22};
23
24#[derive(Debug, Error)]
25pub enum SendPktsError {
26    /// IO Error during send: first error, num failed packets
27    #[error("IO Error, some packets could not be sent")]
28    IoError(io::Error, usize),
29}
30
31impl From<SendPktsError> for TransportError {
32    fn from(err: SendPktsError) -> Self {
33        Self::Custom(format!("{err:?}"))
34    }
35}
36
/// Sends each `(payload, destination)` pair in `packets` with an individual
/// `send_to` call on `sock`.
///
/// Every packet is attempted, in order, even after a failure.
///
/// # Errors
///
/// Returns [`SendPktsError::IoError`] holding the first I/O error seen and
/// the total number of packets whose send failed.
#[cfg(not(target_os = "linux"))]
pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
where
    S: Borrow<SocketAddr>,
    T: AsRef<[u8]>,
{
    // Fold over all packets so that no send is skipped; only the first error
    // is retained, later ones merely bump the failure counter.
    let (first_error, failed) = packets.iter().fold(
        (None, 0usize),
        |(first_error, failed), (payload, dest)| match sock.send_to(payload.as_ref(), dest.borrow())
        {
            Ok(_) => (first_error, failed),
            Err(e) => (first_error.or(Some(e)), failed + 1),
        },
    );

    match first_error {
        Some(err) => Err(SendPktsError::IoError(err, failed)),
        None => Ok(()),
    }
}
60
61#[cfg(target_os = "linux")]
62fn mmsghdr_for_packet(
63    packet: &[u8],
64    dest: &SocketAddr,
65    iov: &mut MaybeUninit<iovec>,
66    addr: &mut MaybeUninit<sockaddr_storage>,
67    hdr: &mut MaybeUninit<mmsghdr>,
68) {
69    const SIZE_OF_SOCKADDR_IN: usize = mem::size_of::<sockaddr_in>();
70    const SIZE_OF_SOCKADDR_IN6: usize = mem::size_of::<sockaddr_in6>();
71
72    iov.write(iovec {
73        iov_base: packet.as_ptr() as *mut libc::c_void,
74        iov_len: packet.len(),
75    });
76
77    let msg_namelen = match dest {
78        SocketAddr::V4(socket_addr_v4) => {
79            unsafe {
80                ptr::write(
81                    addr.as_mut_ptr() as *mut _,
82                    *nix::sys::socket::SockaddrIn::from(*socket_addr_v4).as_ref(),
83                );
84            }
85            SIZE_OF_SOCKADDR_IN as socklen_t
86        }
87        SocketAddr::V6(socket_addr_v6) => {
88            unsafe {
89                ptr::write(
90                    addr.as_mut_ptr() as *mut _,
91                    *nix::sys::socket::SockaddrIn6::from(*socket_addr_v6).as_ref(),
92                );
93            }
94            SIZE_OF_SOCKADDR_IN6 as socklen_t
95        }
96    };
97
98    hdr.write(mmsghdr {
99        msg_len: 0,
100        msg_hdr: msghdr {
101            msg_name: addr as *mut _ as *mut _,
102            msg_namelen,
103            msg_iov: iov.as_mut_ptr(),
104            msg_iovlen: 1,
105            msg_control: ptr::null::<libc::c_void>() as *mut _,
106            msg_controllen: 0,
107            msg_flags: 0,
108        },
109    });
110}
111
112#[cfg(target_os = "linux")]
113fn sendmmsg_retry(sock: &UdpSocket, hdrs: &mut [mmsghdr]) -> Result<(), SendPktsError> {
114    let sock_fd = sock.as_raw_fd();
115    let mut total_sent = 0;
116    let mut erropt = None;
117
118    let mut pkts = &mut *hdrs;
119    while !pkts.is_empty() {
120        let npkts = match unsafe { libc::sendmmsg(sock_fd, &mut pkts[0], pkts.len() as u32, 0) } {
121            -1 => {
122                if erropt.is_none() {
123                    erropt = Some(io::Error::last_os_error());
124                }
125                // skip over the failing packet
126                1_usize
127            }
128            n => {
129                // if we fail to send all packets we advance to the failing
130                // packet and retry in order to capture the error code
131                total_sent += n as usize;
132                n as usize
133            }
134        };
135        pkts = &mut pkts[npkts..];
136    }
137
138    if let Some(err) = erropt {
139        Err(SendPktsError::IoError(err, hdrs.len() - total_sent))
140    } else {
141        Ok(())
142    }
143}
144
145#[cfg(target_os = "linux")]
146pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
147where
148    S: Borrow<SocketAddr>,
149    T: AsRef<[u8]>,
150{
151    let size = packets.len();
152    let mut iovs = vec![MaybeUninit::uninit(); size];
153    let mut addrs = vec![MaybeUninit::zeroed(); size];
154    let mut hdrs = vec![MaybeUninit::uninit(); size];
155    for ((pkt, dest), hdr, iov, addr) in izip!(packets, &mut hdrs, &mut iovs, &mut addrs) {
156        mmsghdr_for_packet(pkt.as_ref(), dest.borrow(), iov, addr, hdr);
157    }
158    // mmsghdr_for_packet() performs initialization so we can safely transmute
159    // the Vecs to their initialized counterparts
160    let _iovs = unsafe { mem::transmute::<Vec<MaybeUninit<iovec>>, Vec<iovec>>(iovs) };
161    let _addrs = unsafe {
162        mem::transmute::<Vec<MaybeUninit<sockaddr_storage>>, Vec<sockaddr_storage>>(addrs)
163    };
164    let mut hdrs = unsafe { mem::transmute::<Vec<MaybeUninit<mmsghdr>>, Vec<mmsghdr>>(hdrs) };
165
166    sendmmsg_retry(sock, &mut hdrs)
167}
168
169pub fn multi_target_send<S, T>(
170    sock: &UdpSocket,
171    packet: T,
172    dests: &[S],
173) -> Result<(), SendPktsError>
174where
175    S: Borrow<SocketAddr>,
176    T: AsRef<[u8]>,
177{
178    let dests = dests.iter().map(Borrow::borrow);
179    let pkts: Vec<_> = repeat(&packet).zip(dests).collect();
180    batch_send(sock, &pkts)
181}
182
#[cfg(test)]
mod tests {
    use {
        crate::{
            packet::Packet,
            recvmmsg::recv_mmsg,
            sendmmsg::{batch_send, multi_target_send, SendPktsError},
        },
        assert_matches::assert_matches,
        solana_sdk::packet::PACKET_DATA_SIZE,
        std::{
            io::ErrorKind,
            net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
        },
    };

    /// Binds a UDP socket to an OS-assigned port on the IPv4 loopback address.
    fn bind_loopback() -> UdpSocket {
        UdpSocket::bind("127.0.0.1:0").expect("bind")
    }

    /// Builds `count` zero-filled payloads of `PACKET_DATA_SIZE` bytes each.
    fn zeroed_payloads(count: usize) -> Vec<Vec<u8>> {
        vec![vec![0u8; PACKET_DATA_SIZE]; count]
    }

    /// Asserts that `result` is an `IoError` of kind `PermissionDenied`
    /// reporting exactly `expected_failures` failed packets.
    fn assert_permission_denied(result: Result<(), SendPktsError>, expected_failures: usize) {
        match result {
            Ok(()) => panic!(),
            Err(SendPktsError::IoError(ioerror, num_failed)) => {
                assert_matches!(ioerror.kind(), ErrorKind::PermissionDenied);
                assert_eq!(num_failed, expected_failures);
            }
        }
    }

    #[test]
    pub fn test_send_mmsg_one_dest() {
        let reader = bind_loopback();
        let addr = reader.local_addr().unwrap();
        let sender = bind_loopback();

        let payloads = zeroed_payloads(32);
        let packet_refs: Vec<_> = payloads
            .iter()
            .map(|payload| (payload.as_slice(), &addr))
            .collect();

        assert_matches!(batch_send(&sender, &packet_refs[..]), Ok(()));

        let mut received = vec![Packet::default(); 32];
        let recv = recv_mmsg(&reader, &mut received[..]).unwrap();
        assert_eq!(32, recv);
    }

    #[test]
    pub fn test_send_mmsg_multi_dest() {
        let reader = bind_loopback();
        let addr = reader.local_addr().unwrap();

        let reader2 = bind_loopback();
        let addr2 = reader2.local_addr().unwrap();

        let sender = bind_loopback();

        // The first 16 payloads go to `reader`, the remaining 16 to `reader2`.
        let payloads = zeroed_payloads(32);
        let packet_refs: Vec<_> = payloads
            .iter()
            .enumerate()
            .map(|(i, payload)| {
                let dest = if i < 16 { &addr } else { &addr2 };
                (payload.as_slice(), dest)
            })
            .collect();

        assert_matches!(batch_send(&sender, &packet_refs[..]), Ok(()));

        for receiver in [&reader, &reader2] {
            let mut received = vec![Packet::default(); 32];
            let recv = recv_mmsg(receiver, &mut received[..]).unwrap();
            assert_eq!(16, recv);
        }
    }

    #[test]
    pub fn test_multicast_msg() {
        let readers: Vec<UdpSocket> = (0..4).map(|_| bind_loopback()).collect();
        let addrs: Vec<SocketAddr> = readers
            .iter()
            .map(|reader| reader.local_addr().unwrap())
            .collect();
        let dests: Vec<&SocketAddr> = addrs.iter().collect();

        let sender = bind_loopback();

        let packet = Packet::default();

        let outcome = multi_target_send(&sender, packet.data(..).unwrap(), &dests);
        assert_matches!(outcome, Ok(()));

        // Each of the four readers must have received exactly one packet.
        for receiver in &readers {
            let mut received = vec![Packet::default(); 32];
            let recv = recv_mmsg(receiver, &mut received[..]).unwrap();
            assert_eq!(1, recv);
        }
    }

    #[test]
    fn test_intermediate_failures_mismatched_bind() {
        let payloads = zeroed_payloads(3);
        let ip4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ip6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080);

        // The sender is bound to an IPv4 address; the IPv6 destination in the
        // middle is the one send expected to fail.
        let dest_refs: Vec<_> = vec![&ip4, &ip6, &ip4];
        let packet_refs: Vec<_> = payloads
            .iter()
            .zip(dest_refs.iter())
            .map(|(payload, dest)| (payload.as_slice(), *dest))
            .collect();

        let sender = UdpSocket::bind("0.0.0.0:0").expect("bind");

        let res = batch_send(&sender, &packet_refs[..]);
        assert_matches!(res, Err(SendPktsError::IoError(_, /*num_failed*/ 1)));

        let res = multi_target_send(&sender, &payloads[0], &dest_refs);
        assert_matches!(res, Err(SendPktsError::IoError(_, /*num_failed*/ 1)));
    }

    #[test]
    fn test_intermediate_failures_unreachable_address() {
        let payloads = zeroed_payloads(5);
        let ipv4local = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ipv4broadcast = SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), 8080);
        let sender = UdpSocket::bind("0.0.0.0:0").expect("bind");

        // Destination patterns for batch_send, each paired with the number of
        // sends expected to fail (one per broadcast destination).
        let batch_cases = [
            // intermediate failures
            (
                [
                    &ipv4local,
                    &ipv4broadcast,
                    &ipv4local,
                    &ipv4broadcast,
                    &ipv4local,
                ],
                2,
            ),
            // leading and trailing failures
            (
                [
                    &ipv4broadcast,
                    &ipv4local,
                    &ipv4broadcast,
                    &ipv4local,
                    &ipv4broadcast,
                ],
                3,
            ),
            // consecutive intermediate failures
            (
                [
                    &ipv4local,
                    &ipv4local,
                    &ipv4broadcast,
                    &ipv4broadcast,
                    &ipv4local,
                ],
                2,
            ),
        ];
        for (dests, expected_failures) in batch_cases.iter() {
            let packet_refs: Vec<_> = payloads
                .iter()
                .zip(dests.iter())
                .map(|(payload, dest)| (payload.as_slice(), *dest))
                .collect();
            assert_permission_denied(batch_send(&sender, &packet_refs[..]), *expected_failures);
        }

        // The same kind of patterns for multi_target_send, always sending the
        // first payload.
        let multi_cases = [
            // intermediate failures
            (
                [
                    &ipv4local,
                    &ipv4broadcast,
                    &ipv4local,
                    &ipv4broadcast,
                    &ipv4local,
                ],
                2,
            ),
            // leading and trailing failures
            (
                [
                    &ipv4broadcast,
                    &ipv4local,
                    &ipv4broadcast,
                    &ipv4local,
                    &ipv4broadcast,
                ],
                3,
            ),
        ];
        for (dests, expected_failures) in multi_cases.iter() {
            assert_permission_denied(
                multi_target_send(&sender, &payloads[0], &dests[..]),
                *expected_failures,
            );
        }
    }
}