solana_streamer/
sendmmsg.rs

//! The `sendmmsg` module provides a batched UDP send API, implemented with `sendmmsg()` on Linux and with a per-packet `send_to` fallback on other platforms.
2
3#[cfg(target_os = "linux")]
4#[allow(deprecated)]
5use nix::sys::socket::InetAddr;
6#[cfg(target_os = "linux")]
7use {
8    itertools::izip,
9    libc::{iovec, mmsghdr, sockaddr_in, sockaddr_in6, sockaddr_storage},
10    std::os::unix::io::AsRawFd,
11};
12use {
13    solana_sdk::transport::TransportError,
14    std::{
15        borrow::Borrow,
16        io,
17        iter::repeat,
18        net::{SocketAddr, UdpSocket},
19    },
20    thiserror::Error,
21};
22
23#[derive(Debug, Error)]
24pub enum SendPktsError {
25    /// IO Error during send: first error, num failed packets
26    #[error("IO Error, some packets could not be sent")]
27    IoError(io::Error, usize),
28}
29
30impl From<SendPktsError> for TransportError {
31    fn from(err: SendPktsError) -> Self {
32        Self::Custom(format!("{:?}", err))
33    }
34}
35
/// Sends each `(payload, destination)` pair in `packets` with its own
/// `send_to` call. This is the fallback used when `sendmmsg` is unavailable.
///
/// Every packet is attempted, even after an earlier one has failed.
///
/// # Errors
/// Returns [`SendPktsError::IoError`] carrying the first I/O error encountered
/// and the total number of packets that failed to send.
#[cfg(not(target_os = "linux"))]
pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
where
    S: Borrow<SocketAddr>,
    T: AsRef<[u8]>,
{
    // Sends happen lazily, in order, as this iterator is consumed; it yields
    // only the errors.
    let mut failures = packets
        .iter()
        .filter_map(|(payload, dest)| sock.send_to(payload.as_ref(), dest.borrow()).err());

    match failures.next() {
        // `count()` drains the iterator, so the remaining sends still happen.
        Some(first_err) => Err(SendPktsError::IoError(first_err, 1 + failures.count())),
        None => Ok(()),
    }
}
59
60#[cfg(target_os = "linux")]
61fn mmsghdr_for_packet(
62    packet: &[u8],
63    dest: &SocketAddr,
64    iov: &mut iovec,
65    addr: &mut sockaddr_storage,
66    hdr: &mut mmsghdr,
67) {
68    const SIZE_OF_SOCKADDR_IN: usize = std::mem::size_of::<sockaddr_in>();
69    const SIZE_OF_SOCKADDR_IN6: usize = std::mem::size_of::<sockaddr_in6>();
70
71    *iov = iovec {
72        iov_base: packet.as_ptr() as *mut libc::c_void,
73        iov_len: packet.len(),
74    };
75    hdr.msg_hdr.msg_iov = iov;
76    hdr.msg_hdr.msg_iovlen = 1;
77    hdr.msg_hdr.msg_name = addr as *mut _ as *mut _;
78
79    #[allow(deprecated)]
80    match InetAddr::from_std(dest) {
81        InetAddr::V4(dest) => {
82            unsafe {
83                std::ptr::write(addr as *mut _ as *mut _, dest);
84            }
85            hdr.msg_hdr.msg_namelen = SIZE_OF_SOCKADDR_IN as u32;
86        }
87        InetAddr::V6(dest) => {
88            unsafe {
89                std::ptr::write(addr as *mut _ as *mut _, dest);
90            }
91            hdr.msg_hdr.msg_namelen = SIZE_OF_SOCKADDR_IN6 as u32;
92        }
93    };
94}
95
96#[cfg(target_os = "linux")]
97fn sendmmsg_retry(sock: &UdpSocket, hdrs: &mut [mmsghdr]) -> Result<(), SendPktsError> {
98    let sock_fd = sock.as_raw_fd();
99    let mut total_sent = 0;
100    let mut erropt = None;
101
102    let mut pkts = &mut *hdrs;
103    while !pkts.is_empty() {
104        let npkts = match unsafe { libc::sendmmsg(sock_fd, &mut pkts[0], pkts.len() as u32, 0) } {
105            -1 => {
106                if erropt.is_none() {
107                    erropt = Some(io::Error::last_os_error());
108                }
109                // skip over the failing packet
110                1_usize
111            }
112            n => {
113                // if we fail to send all packets we advance to the failing
114                // packet and retry in order to capture the error code
115                total_sent += n as usize;
116                n as usize
117            }
118        };
119        pkts = &mut pkts[npkts..];
120    }
121
122    if let Some(err) = erropt {
123        Err(SendPktsError::IoError(err, hdrs.len() - total_sent))
124    } else {
125        Ok(())
126    }
127}
128
129#[cfg(target_os = "linux")]
130pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
131where
132    S: Borrow<SocketAddr>,
133    T: AsRef<[u8]>,
134{
135    let size = packets.len();
136    #[allow(clippy::uninit_assumed_init)]
137    let iovec = std::mem::MaybeUninit::<iovec>::uninit();
138    let mut iovs = vec![unsafe { iovec.assume_init() }; size];
139    let mut addrs = vec![unsafe { std::mem::zeroed() }; size];
140    let mut hdrs = vec![unsafe { std::mem::zeroed() }; size];
141    for ((pkt, dest), hdr, iov, addr) in izip!(packets, &mut hdrs, &mut iovs, &mut addrs) {
142        mmsghdr_for_packet(pkt.as_ref(), dest.borrow(), iov, addr, hdr);
143    }
144    sendmmsg_retry(sock, &mut hdrs)
145}
146
147pub fn multi_target_send<S, T>(
148    sock: &UdpSocket,
149    packet: T,
150    dests: &[S],
151) -> Result<(), SendPktsError>
152where
153    S: Borrow<SocketAddr>,
154    T: AsRef<[u8]>,
155{
156    let dests = dests.iter().map(Borrow::borrow);
157    let pkts: Vec<_> = repeat(&packet).zip(dests).collect();
158    batch_send(sock, &pkts)
159}
160
#[cfg(test)]
mod tests {
    use {
        crate::{
            packet::Packet,
            recvmmsg::recv_mmsg,
            sendmmsg::{batch_send, multi_target_send, SendPktsError},
        },
        solana_sdk::packet::PACKET_DATA_SIZE,
        std::{
            io::ErrorKind,
            net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
        },
    };

    /// Binds a UDP socket to an ephemeral port on 127.0.0.1 and returns it
    /// together with its local address.
    fn bind_localhost() -> (UdpSocket, SocketAddr) {
        let sock = UdpSocket::bind("127.0.0.1:0").expect("bind");
        let addr = sock.local_addr().unwrap();
        (sock, addr)
    }

    /// Calls `recv_mmsg` on `sock` with room for 32 packets and returns the
    /// count it reports.
    fn recv_count(sock: &UdpSocket) -> usize {
        let mut packets = vec![Packet::default(); 32];
        recv_mmsg(sock, &mut packets[..]).unwrap()
    }

    /// Builds `num` zero-filled payloads of `PACKET_DATA_SIZE` bytes each.
    fn zeroed_payloads(num: usize) -> Vec<Vec<u8>> {
        (0..num).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect()
    }

    /// If `result` is an error, asserts that it is `PermissionDenied` and that
    /// `expected_failed` packets were reported as unsent. An `Ok` result is
    /// accepted without further checks.
    fn check_permission_denied(result: Result<(), SendPktsError>, expected_failed: usize) {
        if let Err(SendPktsError::IoError(err, num_failed)) = result {
            assert_eq!(err.kind(), ErrorKind::PermissionDenied);
            assert_eq!(num_failed, expected_failed);
        }
    }

    #[test]
    pub fn test_send_mmsg_one_dest() {
        let (reader, addr) = bind_localhost();
        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");

        let payloads = zeroed_payloads(32);
        let packet_refs: Vec<_> = payloads.iter().map(|p| (&p[..], &addr)).collect();

        assert!(batch_send(&sender, &packet_refs[..]).is_ok());
        assert_eq!(recv_count(&reader), 32);
    }

    #[test]
    pub fn test_send_mmsg_multi_dest() {
        let (reader, addr) = bind_localhost();
        let (reader2, addr2) = bind_localhost();
        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");

        // The first 16 payloads go to `reader`, the remaining 16 to `reader2`.
        let payloads = zeroed_payloads(32);
        let packet_refs: Vec<_> = payloads
            .iter()
            .enumerate()
            .map(|(i, p)| (&p[..], if i < 16 { &addr } else { &addr2 }))
            .collect();

        assert!(batch_send(&sender, &packet_refs[..]).is_ok());
        assert_eq!(recv_count(&reader), 16);
        assert_eq!(recv_count(&reader2), 16);
    }

    #[test]
    pub fn test_multicast_msg() {
        let readers: Vec<(UdpSocket, SocketAddr)> = (0..4).map(|_| bind_localhost()).collect();
        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");

        let packet = Packet::default();
        let dests: Vec<&SocketAddr> = readers.iter().map(|(_, addr)| addr).collect();

        assert!(multi_target_send(&sender, packet.data(..).unwrap(), &dests).is_ok());
        for (reader, _) in &readers {
            assert_eq!(recv_count(reader), 1);
        }
    }

    #[test]
    fn test_intermediate_failures_mismatched_bind() {
        let payloads = zeroed_payloads(3);
        let ip4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ip6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080);
        // Only the IPv6 destination is expected to fail from the IPv4-bound sender.
        let dests = [&ip4, &ip6, &ip4];
        let packet_refs: Vec<_> = payloads.iter().map(|p| &p[..]).zip(dests).collect();

        let sender = UdpSocket::bind("0.0.0.0:0").expect("bind");
        if let Err(SendPktsError::IoError(_, num_failed)) = batch_send(&sender, &packet_refs[..]) {
            assert_eq!(num_failed, 1);
        }
        if let Err(SendPktsError::IoError(_, num_failed)) =
            multi_target_send(&sender, &payloads[0], &dests)
        {
            assert_eq!(num_failed, 1);
        }
    }

    #[test]
    fn test_intermediate_failures_unreachable_address() {
        let payloads = zeroed_payloads(5);
        let local = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let broadcast = SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), 8080);
        let sender = UdpSocket::bind("0.0.0.0:0").expect("bind");

        // (destinations, expected failed count) for batch_send; sends to the
        // broadcast address are expected to be denied.
        let batch_cases = [
            // intermediate failures
            ([&local, &broadcast, &local, &broadcast, &local], 2),
            // leading and trailing failures
            ([&broadcast, &local, &broadcast, &local, &broadcast], 3),
            // consecutive intermediate failures
            ([&local, &local, &broadcast, &broadcast, &local], 2),
        ];
        for (dests, expected_failed) in batch_cases {
            let packet_refs: Vec<_> = payloads.iter().map(|p| &p[..]).zip(dests).collect();
            check_permission_denied(batch_send(&sender, &packet_refs[..]), expected_failed);
        }

        // Same scenarios for multi_target_send, which repeats the first payload.
        let multi_cases = [
            // intermediate failures
            ([&local, &broadcast, &local, &broadcast, &local], 2),
            // leading and trailing failures
            ([&broadcast, &local, &broadcast, &local, &broadcast], 3),
        ];
        for (dests, expected_failed) in multi_cases {
            check_permission_denied(
                multi_target_send(&sender, &payloads[0], &dests),
                expected_failed,
            );
        }
    }
}