solana_streamer/nonblocking/
recvmmsg.rs

//! The `recvmmsg` module provides a nonblocking, `recvmmsg()`-style API for tokio UDP sockets
2
3use {
4    crate::{
5        packet::{Meta, Packet},
6        recvmmsg::NUM_RCVMMSGS,
7    },
8    std::{cmp, io},
9    tokio::net::UdpSocket,
10};
11
12/// Pulls some packets from the socket into the specified container
13/// returning how many packets were read
14pub async fn recv_mmsg(
15    socket: &UdpSocket,
16    packets: &mut [Packet],
17) -> io::Result</*num packets:*/ usize> {
18    debug_assert!(packets.iter().all(|pkt| pkt.meta() == &Meta::default()));
19    let count = cmp::min(NUM_RCVMMSGS, packets.len());
20    socket.readable().await?;
21    let mut i = 0;
22    for p in packets.iter_mut().take(count) {
23        p.meta_mut().size = 0;
24        match socket.try_recv_from(p.buffer_mut()) {
25            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
26                break;
27            }
28            Err(e) => {
29                return Err(e);
30            }
31            Ok((nrecv, from)) => {
32                p.meta_mut().size = nrecv;
33                p.meta_mut().set_socket_addr(&from);
34            }
35        }
36        i += 1;
37    }
38    Ok(i)
39}
40
41/// Reads the exact number of packets required to fill `packets`
42pub async fn recv_mmsg_exact(
43    socket: &UdpSocket,
44    packets: &mut [Packet],
45) -> io::Result</*num packets:*/ usize> {
46    let total = packets.len();
47    let mut remaining = total;
48    while remaining != 0 {
49        let first = total - remaining;
50        let res = recv_mmsg(socket, &mut packets[first..]).await?;
51        remaining -= res;
52    }
53    Ok(packets.len())
54}
55
#[cfg(test)]
mod tests {
    use {
        crate::{nonblocking::recvmmsg::*, packet::PACKET_DATA_SIZE},
        solana_net_utils::{bind_to_async, bind_to_localhost_async},
        std::{net::SocketAddr, time::Instant},
        tokio::net::UdpSocket,
    };

    /// `(reader socket, reader address, sender socket, sender address)`.
    type TestConfig = (UdpSocket, SocketAddr, UdpSocket, SocketAddr);

    /// Number of messages used as the baseline batch size in these tests.
    const TEST_NUM_MSGS: usize = 32;

    /// Parses `ip_str` as a socket address and binds a reader and a sender
    /// socket on its IP and port.
    ///
    /// Returns an error if `ip_str` is not a valid socket address or if either
    /// bind fails (for example when the address family is unavailable).
    async fn test_setup_reader_sender(ip_str: &str) -> io::Result<TestConfig> {
        let sock_addr: SocketAddr = ip_str
            .parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        let reader = bind_to_async(sock_addr.ip(), sock_addr.port(), /*reuseport:*/ false).await?;
        let addr = reader.local_addr()?;
        let sender = bind_to_async(sock_addr.ip(), sock_addr.port(), /*reuseport:*/ false).await?;
        let saddr = sender.local_addr()?;
        Ok((reader, addr, sender, saddr))
    }

    /// Sends `count` zero-filled datagrams of `PACKET_DATA_SIZE` bytes to `addr`.
    async fn send_packets(sender: &UdpSocket, addr: &SocketAddr, count: usize) {
        let data = [0u8; PACKET_DATA_SIZE];
        for _ in 0..count {
            sender.send_to(&data[..], addr).await.unwrap();
        }
    }

    /// Asserts that every packet is `PACKET_DATA_SIZE` bytes long and was
    /// received from `expected_from`.
    #[track_caller]
    fn assert_full_packets_from(packets: &[Packet], expected_from: SocketAddr) {
        for packet in packets {
            assert_eq!(packet.meta().size, PACKET_DATA_SIZE);
            assert_eq!(packet.meta().socket_addr(), expected_from);
        }
    }

    /// Resets every packet's metadata to `Meta::default()`, as `recv_mmsg`
    /// asserts in debug builds.
    fn reset_meta(packets: &mut [Packet]) {
        packets
            .iter_mut()
            .for_each(|pkt| *pkt.meta_mut() = Meta::default());
    }

    /// Sends fewer than `TEST_NUM_MSGS` packets and reads them all back with
    /// `recv_mmsg_exact`.
    async fn test_one_iter((reader, addr, sender, saddr): TestConfig) {
        let sent = TEST_NUM_MSGS - 1;
        send_packets(&sender, &addr, sent).await;

        let mut packets = vec![Packet::default(); sent];
        let recv = recv_mmsg_exact(&reader, &mut packets[..]).await.unwrap();
        assert_eq!(sent, recv);
        assert_full_packets_from(&packets[..recv], saddr);
    }

    #[tokio::test]
    async fn test_recv_mmsg_one_iter() {
        test_one_iter(test_setup_reader_sender("127.0.0.1:0").await.unwrap()).await;

        // IPv6 socket addresses must be written with brackets; the unbracketed
        // "::1:0" never parses as a `SocketAddr`, which would silently skip
        // this case on every host.
        match test_setup_reader_sender("[::1]:0").await {
            Ok(config) => test_one_iter(config).await,
            Err(e) => warn!("Failed to configure IPv6: {:?}", e),
        }
    }

    /// Sends more than `TEST_NUM_MSGS` packets and reads them back in two
    /// `recv_mmsg_exact` calls.
    async fn test_multi_iter((reader, addr, sender, saddr): TestConfig) {
        let sent = TEST_NUM_MSGS + 10;
        send_packets(&sender, &addr, sent).await;

        let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
        let recv = recv_mmsg_exact(&reader, &mut packets[..]).await.unwrap();
        assert_eq!(TEST_NUM_MSGS, recv);
        assert_full_packets_from(&packets[..recv], saddr);

        let mut packets = vec![Packet::default(); sent - TEST_NUM_MSGS];
        reset_meta(&mut packets);
        let recv = recv_mmsg_exact(&reader, &mut packets[..]).await.unwrap();
        assert_eq!(sent - TEST_NUM_MSGS, recv);
        assert_full_packets_from(&packets[..recv], saddr);
    }

    #[tokio::test]
    async fn test_recv_mmsg_multi_iter() {
        test_multi_iter(test_setup_reader_sender("127.0.0.1:0").await.unwrap()).await;

        // See `test_recv_mmsg_one_iter` for why the address is bracketed.
        match test_setup_reader_sender("[::1]:0").await {
            Ok(config) => test_multi_iter(config).await,
            Err(e) => warn!("Failed to configure IPv6: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_recv_mmsg_exact_multi_iter_timeout() {
        let reader = bind_to_localhost_async().await.expect("bind");
        let addr = reader.local_addr().unwrap();
        let sender = bind_to_localhost_async().await.expect("bind");
        let saddr = sender.local_addr().unwrap();
        let sent = TEST_NUM_MSGS;
        send_packets(&sender, &addr, sent).await;

        let start = Instant::now();
        let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
        let recv = recv_mmsg_exact(&reader, &mut packets[..]).await.unwrap();
        assert_eq!(TEST_NUM_MSGS, recv);
        assert_full_packets_from(&packets[..recv], saddr);

        reset_meta(&mut packets);
        // Everything that was sent has already been read. The result of this
        // extra call is deliberately ignored; only the elapsed time is checked.
        // NOTE(review): this relies on `recv_mmsg` returning promptly when no
        // datagram is queued; if it ever waited, the test would hang instead
        // of failing the assertion below.
        let _recv = recv_mmsg(&reader, &mut packets[..]).await;
        assert!(start.elapsed().as_secs() < 5);
    }

    #[tokio::test]
    async fn test_recv_mmsg_multi_addrs() {
        let reader = bind_to_localhost_async().await.expect("bind");
        let addr = reader.local_addr().unwrap();

        let sender1 = bind_to_localhost_async().await.expect("bind");
        let saddr1 = sender1.local_addr().unwrap();
        let sent1 = TEST_NUM_MSGS - 1;

        let sender2 = bind_to_localhost_async().await.expect("bind");
        let saddr2 = sender2.local_addr().unwrap();
        let sent2 = TEST_NUM_MSGS + 1;

        send_packets(&sender1, &addr, sent1).await;
        send_packets(&sender2, &addr, sent2).await;

        let mut packets = vec![Packet::default(); TEST_NUM_MSGS];

        // First batch: all of sender1's packets followed by the start of
        // sender2's.
        let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
        assert_eq!(TEST_NUM_MSGS, recv);
        assert_full_packets_from(&packets[..sent1], saddr1);
        assert_full_packets_from(&packets[sent1..recv], saddr2);

        // Second batch: the rest of sender2's packets.
        reset_meta(&mut packets);
        let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
        assert_eq!(sent1 + sent2 - TEST_NUM_MSGS, recv);
        assert_full_packets_from(&packets[..recv], saddr2);
    }
}