1#[cfg(target_os = "linux")]
4#[allow(deprecated)]
5use nix::sys::socket::InetAddr;
6#[cfg(target_os = "linux")]
7use {
8 itertools::izip,
9 libc::{iovec, mmsghdr, sockaddr_in, sockaddr_in6, sockaddr_storage},
10 std::os::unix::io::AsRawFd,
11};
12use {
13 solana_sdk::transport::TransportError,
14 std::{
15 borrow::Borrow,
16 io,
17 iter::repeat,
18 net::{SocketAddr, UdpSocket},
19 },
20 thiserror::Error,
21};
22
/// Error returned by [`batch_send`] and [`multi_target_send`] when at least one
/// packet of the batch could not be sent.
#[derive(Debug, Error)]
pub enum SendPktsError {
    /// IO error during send.
    ///
    /// Field 0 is the first `io::Error` observed while sending the batch; field 1
    /// is the total number of packets that failed to send.
    #[error("IO Error, some packets could not be sent")]
    IoError(io::Error, usize),
}
29
30impl From<SendPktsError> for TransportError {
31 fn from(err: SendPktsError) -> Self {
32 Self::Custom(format!("{:?}", err))
33 }
34}
35
/// Sends each `(payload, destination)` pair in `packets` from `sock`, one
/// `send_to` call per packet (portable fallback for non-Linux targets).
///
/// Every packet is attempted even if earlier ones fail.
///
/// # Errors
///
/// Returns [`SendPktsError::IoError`] holding the first error encountered and
/// the total number of packets whose send failed.
#[cfg(not(target_os = "linux"))]
pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
where
    S: Borrow<SocketAddr>,
    T: AsRef<[u8]>,
{
    // The iterator is driven sequentially by `fold`, so packets are sent in
    // order; only failures flow into the fold, which keeps the earliest error
    // and counts all of them.
    let (first_error, failures) = packets
        .iter()
        .filter_map(|(payload, target)| sock.send_to(payload.as_ref(), target.borrow()).err())
        .fold((None, 0_usize), |(first, count), err| {
            (first.or(Some(err)), count + 1)
        });

    match first_error {
        Some(err) => Err(SendPktsError::IoError(err, failures)),
        None => Ok(()),
    }
}
59
60#[cfg(target_os = "linux")]
61fn mmsghdr_for_packet(
62 packet: &[u8],
63 dest: &SocketAddr,
64 iov: &mut iovec,
65 addr: &mut sockaddr_storage,
66 hdr: &mut mmsghdr,
67) {
68 const SIZE_OF_SOCKADDR_IN: usize = std::mem::size_of::<sockaddr_in>();
69 const SIZE_OF_SOCKADDR_IN6: usize = std::mem::size_of::<sockaddr_in6>();
70
71 *iov = iovec {
72 iov_base: packet.as_ptr() as *mut libc::c_void,
73 iov_len: packet.len(),
74 };
75 hdr.msg_hdr.msg_iov = iov;
76 hdr.msg_hdr.msg_iovlen = 1;
77 hdr.msg_hdr.msg_name = addr as *mut _ as *mut _;
78
79 #[allow(deprecated)]
80 match InetAddr::from_std(dest) {
81 InetAddr::V4(dest) => {
82 unsafe {
83 std::ptr::write(addr as *mut _ as *mut _, dest);
84 }
85 hdr.msg_hdr.msg_namelen = SIZE_OF_SOCKADDR_IN as u32;
86 }
87 InetAddr::V6(dest) => {
88 unsafe {
89 std::ptr::write(addr as *mut _ as *mut _, dest);
90 }
91 hdr.msg_hdr.msg_namelen = SIZE_OF_SOCKADDR_IN6 as u32;
92 }
93 };
94}
95
96#[cfg(target_os = "linux")]
97fn sendmmsg_retry(sock: &UdpSocket, hdrs: &mut [mmsghdr]) -> Result<(), SendPktsError> {
98 let sock_fd = sock.as_raw_fd();
99 let mut total_sent = 0;
100 let mut erropt = None;
101
102 let mut pkts = &mut *hdrs;
103 while !pkts.is_empty() {
104 let npkts = match unsafe { libc::sendmmsg(sock_fd, &mut pkts[0], pkts.len() as u32, 0) } {
105 -1 => {
106 if erropt.is_none() {
107 erropt = Some(io::Error::last_os_error());
108 }
109 1_usize
111 }
112 n => {
113 total_sent += n as usize;
116 n as usize
117 }
118 };
119 pkts = &mut pkts[npkts..];
120 }
121
122 if let Some(err) = erropt {
123 Err(SendPktsError::IoError(err, hdrs.len() - total_sent))
124 } else {
125 Ok(())
126 }
127}
128
129#[cfg(target_os = "linux")]
130pub fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
131where
132 S: Borrow<SocketAddr>,
133 T: AsRef<[u8]>,
134{
135 let size = packets.len();
136 #[allow(clippy::uninit_assumed_init)]
137 let iovec = std::mem::MaybeUninit::<iovec>::uninit();
138 let mut iovs = vec![unsafe { iovec.assume_init() }; size];
139 let mut addrs = vec![unsafe { std::mem::zeroed() }; size];
140 let mut hdrs = vec![unsafe { std::mem::zeroed() }; size];
141 for ((pkt, dest), hdr, iov, addr) in izip!(packets, &mut hdrs, &mut iovs, &mut addrs) {
142 mmsghdr_for_packet(pkt.as_ref(), dest.borrow(), iov, addr, hdr);
143 }
144 sendmmsg_retry(sock, &mut hdrs)
145}
146
147pub fn multi_target_send<S, T>(
148 sock: &UdpSocket,
149 packet: T,
150 dests: &[S],
151) -> Result<(), SendPktsError>
152where
153 S: Borrow<SocketAddr>,
154 T: AsRef<[u8]>,
155{
156 let dests = dests.iter().map(Borrow::borrow);
157 let pkts: Vec<_> = repeat(&packet).zip(dests).collect();
158 batch_send(sock, &pkts)
159}
160
// Tests for `batch_send` / `multi_target_send`: successful delivery to sockets
// bound on ephemeral loopback ports, and the failure accounting reported
// through `SendPktsError::IoError`.
#[cfg(test)]
mod tests {
    use {
        crate::{
            packet::Packet,
            recvmmsg::recv_mmsg,
            sendmmsg::{batch_send, multi_target_send, SendPktsError},
        },
        solana_sdk::packet::PACKET_DATA_SIZE,
        std::{
            io::ErrorKind,
            net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
        },
    };

    /// Sends 32 full-size packets to a single reader and expects one
    /// `recv_mmsg` call to return all 32.
    #[test]
    pub fn test_send_mmsg_one_dest() {
        let reader = UdpSocket::bind("127.0.0.1:0").expect("bind");
        let addr = reader.local_addr().unwrap();
        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");

        let packets: Vec<_> = (0..32).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
        let packet_refs: Vec<_> = packets.iter().map(|p| (&p[..], &addr)).collect();

        let sent = batch_send(&sender, &packet_refs[..]).ok();
        assert_eq!(sent, Some(()));

        let mut packets = vec![Packet::default(); 32];
        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
        assert_eq!(32, recv);
    }

    /// Splits one 32-packet batch across two readers (first 16 to `reader`,
    /// last 16 to `reader2`) and expects each to receive exactly 16.
    #[test]
    pub fn test_send_mmsg_multi_dest() {
        let reader = UdpSocket::bind("127.0.0.1:0").expect("bind");
        let addr = reader.local_addr().unwrap();

        let reader2 = UdpSocket::bind("127.0.0.1:0").expect("bind");
        let addr2 = reader2.local_addr().unwrap();

        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");

        let packets: Vec<_> = (0..32).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
        let packet_refs: Vec<_> = packets
            .iter()
            .enumerate()
            .map(|(i, p)| {
                if i < 16 {
                    (&p[..], &addr)
                } else {
                    (&p[..], &addr2)
                }
            })
            .collect();

        let sent = batch_send(&sender, &packet_refs[..]).ok();
        assert_eq!(sent, Some(()));

        let mut packets = vec![Packet::default(); 32];
        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
        assert_eq!(16, recv);

        let mut packets = vec![Packet::default(); 32];
        let recv = recv_mmsg(&reader2, &mut packets[..]).unwrap();
        assert_eq!(16, recv);
    }

    /// Despite the name, this is a unicast fan-out: `multi_target_send` sends
    /// one default packet to each of four loopback readers, and each reader
    /// must receive exactly one packet.
    #[test]
    pub fn test_multicast_msg() {
        let reader = UdpSocket::bind("127.0.0.1:0").expect("bind");
        let addr = reader.local_addr().unwrap();

        let reader2 = UdpSocket::bind("127.0.0.1:0").expect("bind");
        let addr2 = reader2.local_addr().unwrap();

        let reader3 = UdpSocket::bind("127.0.0.1:0").expect("bind");
        let addr3 = reader3.local_addr().unwrap();

        let reader4 = UdpSocket::bind("127.0.0.1:0").expect("bind");
        let addr4 = reader4.local_addr().unwrap();

        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");

        let packet = Packet::default();

        let sent = multi_target_send(
            &sender,
            packet.data(..).unwrap(),
            &[&addr, &addr2, &addr3, &addr4],
        )
        .ok();
        assert_eq!(sent, Some(()));

        let mut packets = vec![Packet::default(); 32];
        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
        assert_eq!(1, recv);

        let mut packets = vec![Packet::default(); 32];
        let recv = recv_mmsg(&reader2, &mut packets[..]).unwrap();
        assert_eq!(1, recv);

        let mut packets = vec![Packet::default(); 32];
        let recv = recv_mmsg(&reader3, &mut packets[..]).unwrap();
        assert_eq!(1, recv);

        let mut packets = vec![Packet::default(); 32];
        let recv = recv_mmsg(&reader4, &mut packets[..]).unwrap();
        assert_eq!(1, recv);
    }

    /// An IPv4-bound sender targets an IPv6 address in the middle of a batch.
    /// The assertions only run when the call reports an error (`if let Err`).
    /// In that case exactly the single IPv6 packet must be counted as failed,
    /// i.e. the packets after it are still sent.
    #[test]
    fn test_intermediate_failures_mismatched_bind() {
        let packets: Vec<_> = (0..3).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
        let ip4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ip6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080);
        let packet_refs: Vec<_> = vec![
            (&packets[0][..], &ip4),
            (&packets[1][..], &ip6),
            (&packets[2][..], &ip4),
        ];
        let dest_refs: Vec<_> = vec![&ip4, &ip6, &ip4];

        let sender = UdpSocket::bind("0.0.0.0:0").expect("bind");
        if let Err(SendPktsError::IoError(_, num_failed)) = batch_send(&sender, &packet_refs[..]) {
            assert_eq!(num_failed, 1);
        }
        if let Err(SendPktsError::IoError(_, num_failed)) =
            multi_target_send(&sender, &packets[0], &dest_refs)
        {
            assert_eq!(num_failed, 1);
        }
    }

    /// Mixes loopback and limited-broadcast destinations. The sender never
    /// enables broadcast (no `set_broadcast` call here), and the test expects
    /// those sends to fail with `PermissionDenied`. Each case checks that the
    /// failure count equals the number of broadcast destinations, wherever they
    /// sit in the batch. As above, assertions only run when an error is
    /// reported.
    #[test]
    fn test_intermediate_failures_unreachable_address() {
        let packets: Vec<_> = (0..5).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
        let ipv4local = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ipv4broadcast = SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), 8080);
        let sender = UdpSocket::bind("0.0.0.0:0").expect("bind");

        // batch_send: isolated failures in the middle of the batch
        let packet_refs: Vec<_> = vec![
            (&packets[0][..], &ipv4local),
            (&packets[1][..], &ipv4broadcast),
            (&packets[2][..], &ipv4local),
            (&packets[3][..], &ipv4broadcast),
            (&packets[4][..], &ipv4local),
        ];
        if let Err(SendPktsError::IoError(ioerror, num_failed)) =
            batch_send(&sender, &packet_refs[..])
        {
            assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied));
            assert_eq!(num_failed, 2);
        }

        // batch_send: leading and trailing failures
        let packet_refs: Vec<_> = vec![
            (&packets[0][..], &ipv4broadcast),
            (&packets[1][..], &ipv4local),
            (&packets[2][..], &ipv4broadcast),
            (&packets[3][..], &ipv4local),
            (&packets[4][..], &ipv4broadcast),
        ];
        if let Err(SendPktsError::IoError(ioerror, num_failed)) =
            batch_send(&sender, &packet_refs[..])
        {
            assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied));
            assert_eq!(num_failed, 3);
        }

        // batch_send: consecutive failures in the middle of the batch
        let packet_refs: Vec<_> = vec![
            (&packets[0][..], &ipv4local),
            (&packets[1][..], &ipv4local),
            (&packets[2][..], &ipv4broadcast),
            (&packets[3][..], &ipv4broadcast),
            (&packets[4][..], &ipv4local),
        ];
        if let Err(SendPktsError::IoError(ioerror, num_failed)) =
            batch_send(&sender, &packet_refs[..])
        {
            assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied));
            assert_eq!(num_failed, 2);
        }

        // multi_target_send: isolated failures in the middle of the batch
        let dest_refs: Vec<_> = vec![
            &ipv4local,
            &ipv4broadcast,
            &ipv4local,
            &ipv4broadcast,
            &ipv4local,
        ];
        if let Err(SendPktsError::IoError(ioerror, num_failed)) =
            multi_target_send(&sender, &packets[0], &dest_refs)
        {
            assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied));
            assert_eq!(num_failed, 2);
        }

        // multi_target_send: leading and trailing failures
        let dest_refs: Vec<_> = vec![
            &ipv4broadcast,
            &ipv4local,
            &ipv4broadcast,
            &ipv4local,
            &ipv4broadcast,
        ];
        if let Err(SendPktsError::IoError(ioerror, num_failed)) =
            multi_target_send(&sender, &packets[0], &dest_refs)
        {
            assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied));
            assert_eq!(num_failed, 3);
        }
    }
}
376}