solana_streamer/
recvmmsg.rs

//! The `recvmmsg` module provides an implementation of the `recvmmsg()` API, with a
//! `recv_from()`-based fallback for targets other than Linux.
2
3pub use solana_perf::packet::NUM_RCVMMSGS;
4#[cfg(target_os = "linux")]
5use {
6    crate::msghdr::create_msghdr,
7    itertools::izip,
8    libc::{iovec, mmsghdr, sockaddr_storage, socklen_t, AF_INET, AF_INET6, MSG_WAITFORONE},
9    std::{
10        mem::{self, MaybeUninit},
11        net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
12        os::unix::io::AsRawFd,
13    },
14};
15use {
16    crate::packet::{Meta, Packet},
17    std::{cmp, io, net::UdpSocket},
18};
19
/// Portable stand-in for `recvmmsg()` used on targets other than Linux.
///
/// Fills up to `min(NUM_RCVMMSGS, packets.len())` packets by calling
/// `recv_from()` once per packet, recording the datagram size and sender
/// address in each packet's metadata.
///
/// Once the first datagram has arrived the socket is switched to non-blocking
/// mode (and left that way), so the remaining reads only pick up what is
/// already queued. An error on the very first read is returned to the caller;
/// an error on any later read simply ends the batch.
///
/// Returns the number of packets that were filled.
#[cfg(not(target_os = "linux"))]
pub fn recv_mmsg(socket: &UdpSocket, packets: &mut [Packet]) -> io::Result</*num packets:*/ usize> {
    debug_assert!(packets.iter().all(|pkt| pkt.meta() == &Meta::default()));
    let limit = cmp::min(NUM_RCVMMSGS, packets.len());
    let mut received = 0;
    for packet in packets.iter_mut().take(limit) {
        packet.meta_mut().size = 0;
        let (num_bytes, sender) = match socket.recv_from(packet.buffer_mut()) {
            Ok(datagram) => datagram,
            Err(err) => {
                if received == 0 {
                    // Nothing was read at all: surface the failure.
                    return Err(err);
                }
                // We already have packets to hand back; stop here.
                break;
            }
        };
        packet.meta_mut().size = num_bytes;
        packet.meta_mut().set_socket_addr(&sender);
        if received == 0 {
            // After the first packet, do not block waiting for more.
            socket.set_nonblocking(true)?;
        }
        received += 1;
    }
    Ok(received)
}
46
47#[cfg(target_os = "linux")]
48fn cast_socket_addr(addr: &sockaddr_storage, hdr: &mmsghdr) -> Option<SocketAddr> {
49    use libc::{sa_family_t, sockaddr_in, sockaddr_in6};
50    const SOCKADDR_IN_SIZE: usize = std::mem::size_of::<sockaddr_in>();
51    const SOCKADDR_IN6_SIZE: usize = std::mem::size_of::<sockaddr_in6>();
52    if addr.ss_family == AF_INET as sa_family_t
53        && hdr.msg_hdr.msg_namelen == SOCKADDR_IN_SIZE as socklen_t
54    {
55        // ref: https://github.com/rust-lang/socket2/blob/65085d9dff270e588c0fbdd7217ec0b392b05ef2/src/sockaddr.rs#L167-L172
56        let addr = unsafe { &*(addr as *const _ as *const sockaddr_in) };
57        return Some(SocketAddr::V4(SocketAddrV4::new(
58            Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()),
59            u16::from_be(addr.sin_port),
60        )));
61    }
62    if addr.ss_family == AF_INET6 as sa_family_t
63        && hdr.msg_hdr.msg_namelen == SOCKADDR_IN6_SIZE as socklen_t
64    {
65        // ref: https://github.com/rust-lang/socket2/blob/65085d9dff270e588c0fbdd7217ec0b392b05ef2/src/sockaddr.rs#L174-L189
66        let addr = unsafe { &*(addr as *const _ as *const sockaddr_in6) };
67        return Some(SocketAddr::V6(SocketAddrV6::new(
68            Ipv6Addr::from(addr.sin6_addr.s6_addr),
69            u16::from_be(addr.sin6_port),
70            addr.sin6_flowinfo,
71            addr.sin6_scope_id,
72        )));
73    }
74    error!(
75        "recvmmsg unexpected ss_family:{} msg_namelen:{}",
76        addr.ss_family, hdr.msg_hdr.msg_namelen
77    );
78    None
79}
80
81#[cfg(target_os = "linux")]
82pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result</*num packets:*/ usize> {
83    // Should never hit this, but bail if the caller didn't provide any Packets
84    // to receive into
85    if packets.is_empty() {
86        return Ok(0);
87    }
88    // Assert that there are no leftovers in packets.
89    debug_assert!(packets.iter().all(|pkt| pkt.meta() == &Meta::default()));
90    const SOCKADDR_STORAGE_SIZE: socklen_t = mem::size_of::<sockaddr_storage>() as socklen_t;
91
92    let mut iovs = [MaybeUninit::uninit(); NUM_RCVMMSGS];
93    let mut addrs = [MaybeUninit::zeroed(); NUM_RCVMMSGS];
94    let mut hdrs = [MaybeUninit::uninit(); NUM_RCVMMSGS];
95
96    let sock_fd = sock.as_raw_fd();
97    let count = cmp::min(iovs.len(), packets.len());
98
99    for (packet, hdr, iov, addr) in
100        izip!(packets.iter_mut(), &mut hdrs, &mut iovs, &mut addrs).take(count)
101    {
102        let buffer = packet.buffer_mut();
103        iov.write(iovec {
104            iov_base: buffer.as_mut_ptr() as *mut libc::c_void,
105            iov_len: buffer.len(),
106        });
107
108        let msg_hdr = create_msghdr(addr, SOCKADDR_STORAGE_SIZE, iov);
109
110        hdr.write(mmsghdr {
111            msg_len: 0,
112            msg_hdr,
113        });
114    }
115
116    let mut ts = libc::timespec {
117        tv_sec: 1,
118        tv_nsec: 0,
119    };
120    // TODO: remove .try_into().unwrap() once rust libc fixes recvmmsg types for musl
121    #[allow(clippy::useless_conversion)]
122    let nrecv = unsafe {
123        libc::recvmmsg(
124            sock_fd,
125            hdrs[0].assume_init_mut(),
126            count as u32,
127            MSG_WAITFORONE.try_into().unwrap(),
128            &mut ts,
129        )
130    };
131    let nrecv = if nrecv < 0 {
132        return Err(io::Error::last_os_error());
133    } else {
134        usize::try_from(nrecv).unwrap()
135    };
136    for (addr, hdr, pkt) in izip!(addrs, hdrs, packets.iter_mut()).take(nrecv) {
137        // SAFETY: We initialized `count` elements of `hdrs` above. `count` is
138        // passed to recvmmsg() as the limit of messages that can be read. So,
139        // `nrevc <= count` which means we initialized this `hdr` and
140        // recvmmsg() will have updated it appropriately
141        let hdr_ref = unsafe { hdr.assume_init_ref() };
142        // SAFETY: Similar to above, we initialized this `addr` and recvmmsg()
143        // will have populated it
144        let addr_ref = unsafe { addr.assume_init_ref() };
145        pkt.meta_mut().size = hdr_ref.msg_len as usize;
146        if let Some(addr) = cast_socket_addr(addr_ref, hdr_ref) {
147            pkt.meta_mut().set_socket_addr(&addr);
148        }
149    }
150
151    for (iov, addr, hdr) in izip!(&mut iovs, &mut addrs, &mut hdrs).take(count) {
152        // SAFETY: We initialized `count` elements of each array above
153        //
154        // It may be that `packets.len() != NUM_RCVMMSGS`; thus, some elements
155        // in `iovs` / `addrs` / `hdrs` may not get initialized. So, we must
156        // manually drop `count` elements from each array instead of being able
157        // to convert [MaybeUninit<T>] to [T] and letting `Drop` do the work
158        // for us when these items go out of scope at the end of the function
159        unsafe {
160            iov.assume_init_drop();
161            addr.assume_init_drop();
162            hdr.assume_init_drop();
163        }
164    }
165
166    Ok(nrecv)
167}
168
#[cfg(test)]
mod tests {
    use {
        crate::{packet::PACKET_DATA_SIZE, recvmmsg::*},
        solana_net_utils::{bind_to, bind_to_localhost},
        std::{
            net::{SocketAddr, UdpSocket},
            time::{Duration, Instant},
        },
    };

    /// Sockets used by a test: `(reader, reader address, sender, sender address)`.
    type TestConfig = (UdpSocket, SocketAddr, UdpSocket, SocketAddr);

    /// Binds a reader and a sender socket to the IP and port given by `ip_str`.
    ///
    /// `ip_str` is parsed as a `SocketAddr`, so an IPv6 host must be written in
    /// brackets (e.g. `"[::1]:0"`). A string that does not parse is reported as
    /// `InvalidInput`.
    fn test_setup_reader_sender(ip_str: &str) -> io::Result<TestConfig> {
        let sock_addr: SocketAddr = ip_str
            .parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        let reader = bind_to(sock_addr.ip(), sock_addr.port(), /*reuseport:*/ false)?;
        let addr = reader.local_addr()?;
        let sender = bind_to(sock_addr.ip(), sock_addr.port(), /*reuseport:*/ false)?;
        let saddr = sender.local_addr()?;
        Ok((reader, addr, sender, saddr))
    }

    const TEST_NUM_MSGS: usize = 32;

    /// Sends `num` zero-filled datagrams of `PACKET_DATA_SIZE` bytes to `dest`.
    fn send_full_packets(sender: &UdpSocket, dest: SocketAddr, num: usize) {
        let data = [0; PACKET_DATA_SIZE];
        for _ in 0..num {
            sender.send_to(&data[..], dest).unwrap();
        }
    }

    /// Asserts that every packet is full-sized and was sent from `expected`.
    fn assert_full_packets_from(packets: &[Packet], expected: SocketAddr) {
        for packet in packets {
            assert_eq!(packet.meta().size, PACKET_DATA_SIZE);
            assert_eq!(packet.meta().socket_addr(), expected);
        }
    }

    /// Resets the metadata of every packet so the slice can be reused for
    /// another `recv_mmsg()` call.
    fn reset_meta(packets: &mut [Packet]) {
        packets
            .iter_mut()
            .for_each(|pkt| *pkt.meta_mut() = Meta::default());
    }

    /// Fewer datagrams than packet slots: one call returns all of them.
    #[test]
    pub fn test_recv_mmsg_one_iter() {
        let test_one_iter = |(reader, addr, sender, saddr): TestConfig| {
            let sent = TEST_NUM_MSGS - 1;
            send_full_packets(&sender, addr, sent);

            let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
            let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
            assert_eq!(sent, recv);
            assert_full_packets_from(&packets[..recv], saddr);
        };

        test_one_iter(test_setup_reader_sender("127.0.0.1:0").unwrap());

        // IPv6 may be unavailable on the host; only warn in that case.
        match test_setup_reader_sender("[::1]:0") {
            Ok(config) => test_one_iter(config),
            Err(e) => warn!("Failed to configure IPv6: {:?}", e),
        }
    }

    /// More datagrams than packet slots: a second call picks up the rest.
    #[test]
    pub fn test_recv_mmsg_multi_iter() {
        let test_multi_iter = |(reader, addr, sender, saddr): TestConfig| {
            let sent = TEST_NUM_MSGS + 10;
            send_full_packets(&sender, addr, sent);

            let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
            let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
            assert_eq!(TEST_NUM_MSGS, recv);
            assert_full_packets_from(&packets[..recv], saddr);

            reset_meta(&mut packets);
            let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
            assert_eq!(sent - TEST_NUM_MSGS, recv);
            assert_full_packets_from(&packets[..recv], saddr);
        };

        test_multi_iter(test_setup_reader_sender("127.0.0.1:0").unwrap());

        // IPv6 may be unavailable on the host; only warn in that case.
        match test_setup_reader_sender("[::1]:0") {
            Ok(config) => test_multi_iter(config),
            Err(e) => warn!("Failed to configure IPv6: {:?}", e),
        }
    }

    /// A second call on a drained, non-blocking socket must come back well
    /// before the 5 second read timeout configured on the reader.
    #[test]
    pub fn test_recv_mmsg_multi_iter_timeout() {
        let reader = bind_to_localhost().expect("bind");
        let addr = reader.local_addr().unwrap();
        reader.set_read_timeout(Some(Duration::new(5, 0))).unwrap();
        reader.set_nonblocking(false).unwrap();
        let sender = bind_to_localhost().expect("bind");
        let saddr = sender.local_addr().unwrap();
        let sent = TEST_NUM_MSGS;
        send_full_packets(&sender, addr, sent);

        let start = Instant::now();
        let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
        assert_eq!(TEST_NUM_MSGS, recv);
        assert_full_packets_from(&packets[..recv], saddr);
        reader.set_nonblocking(true).unwrap();

        reset_meta(&mut packets);
        // The result is irrelevant here; only the elapsed time is checked.
        let _recv = recv_mmsg(&reader, &mut packets[..]);
        assert!(start.elapsed().as_secs() < 5);
    }

    /// Datagrams from two senders keep their own source address and arrive in
    /// the order they were sent.
    #[test]
    pub fn test_recv_mmsg_multi_addrs() {
        let reader = bind_to_localhost().expect("bind");
        let addr = reader.local_addr().unwrap();

        let sender1 = bind_to_localhost().expect("bind");
        let saddr1 = sender1.local_addr().unwrap();
        let sent1 = TEST_NUM_MSGS - 1;

        let sender2 = bind_to_localhost().expect("bind");
        let saddr2 = sender2.local_addr().unwrap();
        let sent2 = TEST_NUM_MSGS + 1;

        send_full_packets(&sender1, addr, sent1);
        send_full_packets(&sender2, addr, sent2);

        let mut packets = vec![Packet::default(); TEST_NUM_MSGS];

        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
        assert_eq!(TEST_NUM_MSGS, recv);
        // First batch: everything from sender1, then the start of sender2.
        assert_full_packets_from(&packets[..sent1], saddr1);
        assert_full_packets_from(&packets[sent1..recv], saddr2);

        reset_meta(&mut packets);
        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
        assert_eq!(sent1 + sent2 - TEST_NUM_MSGS, recv);
        // Second batch: the remainder, all from sender2.
        assert_full_packets_from(&packets[..recv], saddr2);
    }
}