1#[cfg(target_os = "linux")]
4use {
5 itertools::izip,
6 libc::{iovec, mmsghdr, msghdr, sockaddr_in, sockaddr_in6, sockaddr_storage, socklen_t},
7 std::{
8 mem::{self, MaybeUninit},
9 os::unix::io::AsRawFd,
10 ptr,
11 },
12};
13use {
14 solana_sdk::transport::TransportError,
15 std::{
16 borrow::Borrow,
17 io,
18 iter::repeat,
19 net::{SocketAddr, UdpSocket},
20 },
21 thiserror::Error,
22};
23
24#[derive(Debug, Error)]
25pub enum SendPktsError {
26 #[error("IO Error, some packets could not be sent")]
28 IoError(io::Error, usize),
29}
30
31impl From<SendPktsError> for TransportError {
32 fn from(err: SendPktsError) -> Self {
33 Self::Custom(format!("{err:?}"))
34 }
35}
36
/// Sends each `(payload, destination)` pair in `packets` on `sock`, issuing one
/// `send_to` call per packet (portable fallback used on non-Linux targets).
///
/// Every packet is attempted, even after an earlier one has failed.
///
/// # Errors
///
/// Returns [`SendPktsError::IoError`] carrying the first error encountered and
/// the number of packets whose `send_to` call failed.
#[cfg(not(target_os = "linux"))]
pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
where
    S: Borrow<SocketAddr>,
    T: AsRef<[u8]>,
{
    // Count the failures while remembering only the earliest error.
    let (num_failed, first_error) = packets
        .iter()
        .filter_map(|(payload, dest)| sock.send_to(payload.as_ref(), dest.borrow()).err())
        .fold((0, None), |(count, first), err| (count + 1, first.or(Some(err))));

    match first_error {
        None => Ok(()),
        Some(err) => Err(SendPktsError::IoError(err, num_failed)),
    }
}
60
61#[cfg(target_os = "linux")]
62fn mmsghdr_for_packet(
63 packet: &[u8],
64 dest: &SocketAddr,
65 iov: &mut MaybeUninit<iovec>,
66 addr: &mut MaybeUninit<sockaddr_storage>,
67 hdr: &mut MaybeUninit<mmsghdr>,
68) {
69 const SIZE_OF_SOCKADDR_IN: usize = mem::size_of::<sockaddr_in>();
70 const SIZE_OF_SOCKADDR_IN6: usize = mem::size_of::<sockaddr_in6>();
71
72 iov.write(iovec {
73 iov_base: packet.as_ptr() as *mut libc::c_void,
74 iov_len: packet.len(),
75 });
76
77 let msg_namelen = match dest {
78 SocketAddr::V4(socket_addr_v4) => {
79 unsafe {
80 ptr::write(
81 addr.as_mut_ptr() as *mut _,
82 *nix::sys::socket::SockaddrIn::from(*socket_addr_v4).as_ref(),
83 );
84 }
85 SIZE_OF_SOCKADDR_IN as socklen_t
86 }
87 SocketAddr::V6(socket_addr_v6) => {
88 unsafe {
89 ptr::write(
90 addr.as_mut_ptr() as *mut _,
91 *nix::sys::socket::SockaddrIn6::from(*socket_addr_v6).as_ref(),
92 );
93 }
94 SIZE_OF_SOCKADDR_IN6 as socklen_t
95 }
96 };
97
98 hdr.write(mmsghdr {
99 msg_len: 0,
100 msg_hdr: msghdr {
101 msg_name: addr as *mut _ as *mut _,
102 msg_namelen,
103 msg_iov: iov.as_mut_ptr(),
104 msg_iovlen: 1,
105 msg_control: ptr::null::<libc::c_void>() as *mut _,
106 msg_controllen: 0,
107 msg_flags: 0,
108 },
109 });
110}
111
112#[cfg(target_os = "linux")]
113fn sendmmsg_retry(sock: &UdpSocket, hdrs: &mut [mmsghdr]) -> Result<(), SendPktsError> {
114 let sock_fd = sock.as_raw_fd();
115 let mut total_sent = 0;
116 let mut erropt = None;
117
118 let mut pkts = &mut *hdrs;
119 while !pkts.is_empty() {
120 let npkts = match unsafe { libc::sendmmsg(sock_fd, &mut pkts[0], pkts.len() as u32, 0) } {
121 -1 => {
122 if erropt.is_none() {
123 erropt = Some(io::Error::last_os_error());
124 }
125 1_usize
127 }
128 n => {
129 total_sent += n as usize;
132 n as usize
133 }
134 };
135 pkts = &mut pkts[npkts..];
136 }
137
138 if let Some(err) = erropt {
139 Err(SendPktsError::IoError(err, hdrs.len() - total_sent))
140 } else {
141 Ok(())
142 }
143}
144
145#[cfg(target_os = "linux")]
146pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
147where
148 S: Borrow<SocketAddr>,
149 T: AsRef<[u8]>,
150{
151 let size = packets.len();
152 let mut iovs = vec![MaybeUninit::uninit(); size];
153 let mut addrs = vec![MaybeUninit::zeroed(); size];
154 let mut hdrs = vec![MaybeUninit::uninit(); size];
155 for ((pkt, dest), hdr, iov, addr) in izip!(packets, &mut hdrs, &mut iovs, &mut addrs) {
156 mmsghdr_for_packet(pkt.as_ref(), dest.borrow(), iov, addr, hdr);
157 }
158 let _iovs = unsafe { mem::transmute::<Vec<MaybeUninit<iovec>>, Vec<iovec>>(iovs) };
161 let _addrs = unsafe {
162 mem::transmute::<Vec<MaybeUninit<sockaddr_storage>>, Vec<sockaddr_storage>>(addrs)
163 };
164 let mut hdrs = unsafe { mem::transmute::<Vec<MaybeUninit<mmsghdr>>, Vec<mmsghdr>>(hdrs) };
165
166 sendmmsg_retry(sock, &mut hdrs)
167}
168
169pub fn multi_target_send<S, T>(
170 sock: &UdpSocket,
171 packet: T,
172 dests: &[S],
173) -> Result<(), SendPktsError>
174where
175 S: Borrow<SocketAddr>,
176 T: AsRef<[u8]>,
177{
178 let dests = dests.iter().map(Borrow::borrow);
179 let pkts: Vec<_> = repeat(&packet).zip(dests).collect();
180 batch_send(sock, &pkts)
181}
182
#[cfg(test)]
mod tests {
    use {
        crate::{
            packet::Packet,
            recvmmsg::recv_mmsg,
            sendmmsg::{batch_send, multi_target_send, SendPktsError},
        },
        assert_matches::assert_matches,
        solana_sdk::packet::PACKET_DATA_SIZE,
        std::{
            io::ErrorKind,
            net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
        },
    };

    /// Binds a UDP socket to an ephemeral port on 127.0.0.1 and returns it
    /// together with its local address.
    fn bind_loopback() -> (UdpSocket, SocketAddr) {
        let sock = UdpSocket::bind("127.0.0.1:0").expect("bind");
        let addr = sock.local_addr().unwrap();
        (sock, addr)
    }

    /// Asserts that `result` is an `IoError` whose recorded error is
    /// `PermissionDenied` and whose failure count equals `expected_failures`.
    fn assert_permission_denied(result: Result<(), SendPktsError>, expected_failures: usize) {
        match result {
            Err(SendPktsError::IoError(ioerror, num_failed)) => {
                assert_matches!(ioerror.kind(), ErrorKind::PermissionDenied);
                assert_eq!(num_failed, expected_failures);
            }
            Ok(()) => panic!(),
        }
    }

    #[test]
    pub fn test_send_mmsg_one_dest() {
        let (reader, addr) = bind_loopback();
        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");

        let packets: Vec<_> = (0..32).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
        let packet_refs: Vec<_> = packets.iter().map(|p| (&p[..], &addr)).collect();

        assert_eq!(batch_send(&sender, &packet_refs[..]).ok(), Some(()));

        let mut received = vec![Packet::default(); 32];
        assert_eq!(recv_mmsg(&reader, &mut received[..]).unwrap(), 32);
    }

    #[test]
    pub fn test_send_mmsg_multi_dest() {
        let (reader, addr) = bind_loopback();
        let (reader2, addr2) = bind_loopback();
        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");

        // The first half of the batch goes to `reader`, the second half to `reader2`.
        let packets: Vec<_> = (0..32).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
        let packet_refs: Vec<_> = packets
            .iter()
            .enumerate()
            .map(|(i, p)| (&p[..], if i < 16 { &addr } else { &addr2 }))
            .collect();

        assert_eq!(batch_send(&sender, &packet_refs[..]).ok(), Some(()));

        for reader in [&reader, &reader2] {
            let mut received = vec![Packet::default(); 32];
            assert_eq!(recv_mmsg(reader, &mut received[..]).unwrap(), 16);
        }
    }

    #[test]
    pub fn test_multicast_msg() {
        let (readers, addrs): (Vec<_>, Vec<_>) = (0..4).map(|_| bind_loopback()).unzip();
        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");

        let packet = Packet::default();
        let dests: Vec<&SocketAddr> = addrs.iter().collect();
        let sent = multi_target_send(&sender, packet.data(..).unwrap(), &dests).ok();
        assert_eq!(sent, Some(()));

        // Each destination should receive exactly one copy.
        for reader in &readers {
            let mut received = vec![Packet::default(); 32];
            assert_eq!(recv_mmsg(reader, &mut received[..]).unwrap(), 1);
        }
    }

    #[test]
    fn test_intermediate_failures_mismatched_bind() {
        let packets: Vec<_> = (0..3).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
        let ip4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ip6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080);
        // Only the middle (IPv6) destination is expected to fail from an
        // IPv4-bound socket.
        let dest_refs: Vec<_> = vec![&ip4, &ip6, &ip4];
        let packet_refs: Vec<_> = packets
            .iter()
            .map(|p| &p[..])
            .zip(dest_refs.iter().copied())
            .collect();

        let sender = UdpSocket::bind("0.0.0.0:0").expect("bind");
        assert_matches!(
            batch_send(&sender, &packet_refs[..]),
            Err(SendPktsError::IoError(_, 1))
        );
        assert_matches!(
            multi_target_send(&sender, &packets[0], &dest_refs),
            Err(SendPktsError::IoError(_, 1))
        );
    }

    #[test]
    fn test_intermediate_failures_unreachable_address() {
        let packets: Vec<_> = (0..5).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
        let local = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let broadcast = SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), 8080);
        let sender = UdpSocket::bind("0.0.0.0:0").expect("bind");
        let (l, b) = (&local, &broadcast);

        // Every broadcast destination is expected to fail with PermissionDenied
        // (presumably because the socket does not enable SO_BROADCAST), while
        // the loopback ones succeed; the expected count is the broadcast count.
        for (dests, expected_failures) in [
            ([l, b, l, b, l], 2),
            ([b, l, b, l, b], 3),
            ([l, l, b, b, l], 2),
        ] {
            let packet_refs: Vec<_> = packets.iter().map(|p| &p[..]).zip(dests).collect();
            assert_permission_denied(batch_send(&sender, &packet_refs[..]), expected_failures);
        }

        for (dests, expected_failures) in [([l, b, l, b, l], 2), ([b, l, b, l, b], 3)] {
            assert_permission_denied(
                multi_target_send(&sender, &packets[0], &dests[..]),
                expected_failures,
            );
        }
    }
}