solana_net_utils/
ip_echo_server.rs

1use {
2    crate::{HEADER_LENGTH, IP_ECHO_SERVER_RESPONSE_LENGTH},
3    log::*,
4    serde_derive::{Deserialize, Serialize},
5    solana_sdk::deserialize_utils::default_on_eof,
6    std::{
7        io,
8        net::{IpAddr, SocketAddr},
9        num::NonZeroUsize,
10        time::Duration,
11    },
12    tokio::{
13        io::{AsyncReadExt, AsyncWriteExt},
14        net::{TcpListener, TcpStream},
15        runtime::{self, Runtime},
16        time::timeout,
17    },
18};
19
20pub type IpEchoServer = Runtime;
21
/// Minimum number of worker threads for the echo server runtime:
/// - one thread to monitor the `TcpListener` and spawn async tasks
/// - one thread to service the spawned tasks
// Built with a const `match` rather than `unsafe { NonZeroUsize::new_unchecked(2) }`:
// the value is checked at compile time, so no `unsafe` is required.
pub const MINIMUM_IP_ECHO_SERVER_THREADS: NonZeroUsize = match NonZeroUsize::new(2) {
    Some(threads) => threads,
    None => panic!("MINIMUM_IP_ECHO_SERVER_THREADS must be non-zero"),
};
/// Default number of worker threads for the echo server runtime.
///
/// IP echo requests require little computation and come in fairly infrequently,
/// so keep the number of server workers small to avoid overhead.
pub const DEFAULT_IP_ECHO_SERVER_THREADS: NonZeroUsize = MINIMUM_IP_ECHO_SERVER_THREADS;
/// Maximum number of TCP ports, and separately of UDP ports, that a single request may
/// ask the server to probe.
pub const MAX_PORT_COUNT_PER_MESSAGE: usize = 4;

/// Timeout applied to individual socket operations performed while serving a connection.
const IO_TIMEOUT: Duration = Duration::from_secs(5);
33
34#[derive(Serialize, Deserialize, Default, Debug)]
35pub(crate) struct IpEchoServerMessage {
36    tcp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE], // Fixed size list of ports to avoid vec serde
37    udp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE], // Fixed size list of ports to avoid vec serde
38}
39
40#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
41pub struct IpEchoServerResponse {
42    // Public IP address of request echoed back to the node.
43    pub(crate) address: IpAddr,
44    // Cluster shred-version of the node running the server.
45    #[serde(deserialize_with = "default_on_eof")]
46    pub(crate) shred_version: Option<u16>,
47}
48
49impl IpEchoServerMessage {
50    pub fn new(tcp_ports: &[u16], udp_ports: &[u16]) -> Self {
51        let mut msg = Self::default();
52        assert!(tcp_ports.len() <= msg.tcp_ports.len());
53        assert!(udp_ports.len() <= msg.udp_ports.len());
54
55        msg.tcp_ports[..tcp_ports.len()].copy_from_slice(tcp_ports);
56        msg.udp_ports[..udp_ports.len()].copy_from_slice(udp_ports);
57        msg
58    }
59}
60
61pub(crate) fn ip_echo_server_request_length() -> usize {
62    const REQUEST_TERMINUS_LENGTH: usize = 1;
63    HEADER_LENGTH
64        + bincode::serialized_size(&IpEchoServerMessage::default()).unwrap() as usize
65        + REQUEST_TERMINUS_LENGTH
66}
67
68async fn process_connection(
69    mut socket: TcpStream,
70    peer_addr: SocketAddr,
71    shred_version: Option<u16>,
72) -> io::Result<()> {
73    info!("connection from {:?}", peer_addr);
74
75    let mut data = vec![0u8; ip_echo_server_request_length()];
76
77    let mut writer = {
78        let (mut reader, writer) = socket.split();
79        let _ = timeout(IO_TIMEOUT, reader.read_exact(&mut data)).await??;
80        writer
81    };
82
83    let request_header: String = data[0..HEADER_LENGTH].iter().map(|b| *b as char).collect();
84    if request_header != "\0\0\0\0" {
85        // Explicitly check for HTTP GET/POST requests to more gracefully handle
86        // the case where a user accidentally tried to use a gossip entrypoint in
87        // place of a JSON RPC URL:
88        if request_header == "GET " || request_header == "POST" {
89            // Send HTTP error response
90            timeout(
91                IO_TIMEOUT,
92                writer.write_all(b"HTTP/1.1 400 Bad Request\nContent-length: 0\n\n"),
93            )
94            .await??;
95            return Ok(());
96        }
97        return Err(io::Error::new(
98            io::ErrorKind::Other,
99            format!("Bad request header: {request_header}"),
100        ));
101    }
102
103    let msg =
104        bincode::deserialize::<IpEchoServerMessage>(&data[HEADER_LENGTH..]).map_err(|err| {
105            io::Error::new(
106                io::ErrorKind::Other,
107                format!("Failed to deserialize IpEchoServerMessage: {err:?}"),
108            )
109        })?;
110
111    trace!("request: {:?}", msg);
112
113    // Fire a datagram at each non-zero UDP port
114    match std::net::UdpSocket::bind("0.0.0.0:0") {
115        Ok(udp_socket) => {
116            for udp_port in &msg.udp_ports {
117                if *udp_port != 0 {
118                    match udp_socket.send_to(&[0], SocketAddr::from((peer_addr.ip(), *udp_port))) {
119                        Ok(_) => debug!("Successful send_to udp/{}", udp_port),
120                        Err(err) => info!("Failed to send_to udp/{}: {}", udp_port, err),
121                    }
122                }
123            }
124        }
125        Err(err) => {
126            warn!("Failed to bind local udp socket: {}", err);
127        }
128    }
129
130    // Try to connect to each non-zero TCP port
131    for tcp_port in &msg.tcp_ports {
132        if *tcp_port != 0 {
133            debug!("Connecting to tcp/{}", tcp_port);
134
135            let mut tcp_stream = timeout(
136                IO_TIMEOUT,
137                TcpStream::connect(&SocketAddr::new(peer_addr.ip(), *tcp_port)),
138            )
139            .await??;
140
141            debug!("Connection established to tcp/{}", *tcp_port);
142            tcp_stream.shutdown().await?;
143        }
144    }
145    let response = IpEchoServerResponse {
146        address: peer_addr.ip(),
147        shred_version,
148    };
149    // "\0\0\0\0" header is added to ensure a valid response will never
150    // conflict with the first four bytes of a valid HTTP response.
151    let mut bytes = vec![0u8; IP_ECHO_SERVER_RESPONSE_LENGTH];
152    bincode::serialize_into(&mut bytes[HEADER_LENGTH..], &response).unwrap();
153    trace!("response: {:?}", bytes);
154    writer.write_all(&bytes).await
155}
156
157async fn run_echo_server(tcp_listener: std::net::TcpListener, shred_version: Option<u16>) {
158    info!("bound to {:?}", tcp_listener.local_addr().unwrap());
159    let tcp_listener =
160        TcpListener::from_std(tcp_listener).expect("Failed to convert std::TcpListener");
161
162    loop {
163        match tcp_listener.accept().await {
164            Ok((socket, peer_addr)) => {
165                runtime::Handle::current().spawn(async move {
166                    if let Err(err) = process_connection(socket, peer_addr, shred_version).await {
167                        info!("session failed: {:?}", err);
168                    }
169                });
170            }
171            Err(err) => warn!("listener accept failed: {:?}", err),
172        }
173    }
174}
175
176/// Starts a simple TCP server on the given port that echos the IP address of any peer that
177/// connects.  Used by |get_public_ip_addr|
178pub fn ip_echo_server(
179    tcp_listener: std::net::TcpListener,
180    num_server_threads: NonZeroUsize,
181    // Cluster shred-version of the node running the server.
182    shred_version: Option<u16>,
183) -> IpEchoServer {
184    tcp_listener.set_nonblocking(true).unwrap();
185
186    let runtime = tokio::runtime::Builder::new_multi_thread()
187        .thread_name("solIpEchoSrvrRt")
188        .worker_threads(num_server_threads.get())
189        .enable_all()
190        .build()
191        .expect("new tokio runtime");
192    runtime.spawn(run_echo_server(tcp_listener, shred_version));
193    runtime
194}