// hickory_proto/tests/tcp.rs

1use alloc::string::ToString;
2use alloc::sync::Arc;
3use core::sync::atomic::AtomicBool;
4use std::io::{Read, Write};
5use std::net::{IpAddr, SocketAddr};
6use std::println;
7
8use futures_util::stream::StreamExt;
9
10use crate::runtime::RuntimeProvider;
11use crate::tcp::TcpStream;
12use crate::xfer::SerialMessage;
13use crate::xfer::dns_handle::DnsStreamHandle;
14
/// Payload carried by every test frame; the echo server checks it byte-for-byte.
const TEST_BYTES: &[u8; TEST_BYTES_LEN] = b"DEADBEEF";
/// Length of `TEST_BYTES`, and the value the echo server expects in each
/// frame's 2-byte length prefix. It is used as the array length of
/// `TEST_BYTES`, so the compiler rejects any mismatch between the two.
const TEST_BYTES_LEN: usize = 8;
/// Number of send/echo round trips performed per test connection.
const SEND_RECV_TIMES: usize = 4;
18
19fn tcp_server_setup(
20    server_name: &str,
21    server_addr: IpAddr,
22) -> (Arc<AtomicBool>, std::thread::JoinHandle<()>, SocketAddr) {
23    let succeeded = Arc::new(AtomicBool::new(false));
24    let succeeded_clone = succeeded.clone();
25    std::thread::Builder::new()
26        .name("thread_killer".to_string())
27        .spawn(move || {
28            let succeeded = succeeded_clone;
29            for _ in 0..15 {
30                std::thread::sleep(core::time::Duration::from_secs(1));
31                if succeeded.load(core::sync::atomic::Ordering::Relaxed) {
32                    return;
33                }
34            }
35
36            println!("Thread Killer has been awoken, killing process");
37            std::process::exit(-1);
38        })
39        .expect("Thread spawning failed");
40
41    // TODO: need a timeout on listen
42    let server = std::net::TcpListener::bind(SocketAddr::new(server_addr, 0))
43        .expect("Unable to bind a TCP socket");
44    let server_addr = server.local_addr().unwrap();
45
46    // an in and out server
47    let server_handle = std::thread::Builder::new()
48        .name(server_name.to_string())
49        .spawn(move || {
50            let (mut socket, _) = server.accept().expect("accept failed");
51
52            socket
53                .set_read_timeout(Some(core::time::Duration::from_secs(5)))
54                .unwrap(); // should receive something within 5 seconds...
55            socket
56                .set_write_timeout(Some(core::time::Duration::from_secs(5)))
57                .unwrap(); // should receive something within 5 seconds...
58
59            for _ in 0..SEND_RECV_TIMES {
60                // wait for some bytes...
61                let mut len_bytes = [0_u8; 2];
62                socket
63                    .read_exact(&mut len_bytes)
64                    .expect("SERVER: receive failed");
65                let length =
66                    (u16::from(len_bytes[0]) << 8) & 0xFF00 | u16::from(len_bytes[1]) & 0x00FF;
67                assert_eq!(length as usize, TEST_BYTES_LEN);
68
69                let mut buffer = [0_u8; TEST_BYTES_LEN];
70                socket.read_exact(&mut buffer).unwrap();
71
72                // println!("read bytes iter: {}", i);
73                assert_eq!(&buffer, TEST_BYTES);
74
75                // bounce them right back...
76                socket
77                    .write_all(&len_bytes)
78                    .expect("SERVER: send length failed");
79                socket
80                    .write_all(&buffer)
81                    .expect("SERVER: send buffer failed");
82                // println!("wrote bytes iter: {}", i);
83                std::thread::yield_now();
84            }
85        })
86        .unwrap();
87    (succeeded, server_handle, server_addr)
88}
89
90/// Test tcp_stream.
91pub async fn tcp_stream_test(server_addr: IpAddr, provider: impl RuntimeProvider) {
92    let (succeeded, server_handle, server_addr) =
93        tcp_server_setup("test_tcp_stream:server", server_addr);
94
95    // setup the client, which is going to run on the testing thread...
96
97    let tcp = provider
98        .connect_tcp(server_addr, None, None)
99        .await
100        .expect("connect failed");
101    let (mut stream, mut sender) = TcpStream::from_stream(tcp, server_addr);
102
103    for _ in 0..SEND_RECV_TIMES {
104        // test once
105        sender
106            .send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr))
107            .expect("send failed");
108
109        let (buffer, stream_tmp) = stream.into_future().await;
110        stream = stream_tmp;
111        let message = buffer
112            .expect("no buffer received")
113            .expect("error receiving buffer");
114        assert_eq!(message.bytes(), TEST_BYTES);
115    }
116
117    succeeded.store(true, core::sync::atomic::Ordering::Relaxed);
118    server_handle.join().expect("server thread failed");
119}
120
121/// Test tcp_client_stream.
122pub async fn tcp_client_stream_test(server_addr: IpAddr, provider: impl RuntimeProvider) {
123    let (succeeded, server_handle, server_addr) =
124        tcp_server_setup("test_tcp_client_stream:server", server_addr);
125
126    // setup the client, which is going to run on the testing thread...
127
128    let tcp = provider
129        .connect_tcp(server_addr, None, None)
130        .await
131        .expect("connect failed");
132    let (mut stream, mut sender) = TcpStream::from_stream(tcp, server_addr);
133
134    for _ in 0..SEND_RECV_TIMES {
135        // test once
136        sender
137            .send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr))
138            .expect("send failed");
139        let (buffer, stream_tmp) = stream.into_future().await;
140        stream = stream_tmp;
141        let buffer = buffer
142            .expect("no buffer received")
143            .expect("error receiving buffer");
144        assert_eq!(buffer.bytes(), TEST_BYTES);
145    }
146
147    succeeded.store(true, core::sync::atomic::Ordering::Relaxed);
148    server_handle.join().expect("server thread failed");
149}