hickory_proto/tcp/tcp_stream.rs
// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// https://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! This module contains all the TCP structures for demuxing TCP into streams of DNS packets.
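//!
//! DNS over TCP prefixes every message with a two-octet, big-endian length field
//! (RFC 1035, section 4.2.2); the read and write state machines below parse and emit
//! that framing. As a rough illustration, a 300-byte message appears on the wire as:
//!
//! ```text
//! [0x01, 0x2C]  [ ...300 bytes of DNS message... ]
//!  ^ length prefix: 300 == 0x012C, big-endian
//! ```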

use alloc::vec::Vec;
use core::mem;
use core::pin::Pin;
use core::task::{Context, Poll};
use core::time::Duration;
use std::io;
use std::net::SocketAddr;

use futures_io::{AsyncRead, AsyncWrite};
use futures_util::stream::Stream;
use futures_util::{self, FutureExt, future::Future, ready};
use tracing::debug;

use crate::BufDnsStreamHandle;
use crate::runtime::Time;
use crate::xfer::{SerialMessage, StreamReceiver};

/// Trait for TCP connection
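///
/// A minimal sketch of an implementation (illustrative, not compiled here): `MyIo` stands
/// in for any IO type that already satisfies the supertrait bounds above, and `MyTime` for
/// a [`Time`] implementation matching the async runtime that drives the stream.
///
/// ```rust,ignore
/// impl DnsTcpStream for MyIo {
///     // Tie this stream type to the runtime's timer, used e.g. for connect timeouts.
///     type Time = MyTime;
/// }
/// ```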
pub trait DnsTcpStream: AsyncRead + AsyncWrite + Unpin + Send + Sync + Sized + 'static {
    /// Timer type to use with this TCP stream type
    type Time: Time;
}

/// Current state while writing to the remote side of the TCP connection
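///
/// Writing proceeds `LenBytes` -> `Bytes` -> `Flushing`; once the flush completes the
/// state is cleared so the next queued message can be picked up.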
enum WriteTcpState {
    /// Currently writing the two-byte length prefix of the buffer
    LenBytes {
        /// Current position in the length buffer being written
        pos: usize,
        /// Length of the buffer
        length: [u8; 2],
        /// Buffer to write after the length
        bytes: Vec<u8>,
    },
    /// Currently writing the buffer to the remote
    Bytes {
        /// Current position in the buffer written
        pos: usize,
        /// Buffer to write to the remote
        bytes: Vec<u8>,
    },
    /// Currently flushing the bytes to the remote
    Flushing,
}

/// Current state of a TCP stream as it's being read.
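///
/// Reading alternates between `LenBytes` (the two-byte length prefix) and `Bytes` (the
/// message body); when a body is complete it is yielded to the caller and the state
/// returns to `LenBytes` for the next message.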
pub(crate) enum ReadTcpState {
    /// Currently reading the length of the TCP packet
    LenBytes {
        /// Current position in the buffer
        pos: usize,
        /// Buffer of the length to read
        bytes: [u8; 2],
    },
    /// Currently reading the bytes of the DNS packet
    Bytes {
        /// Current position while reading the buffer
        pos: usize,
        /// buffer being read into
        bytes: Vec<u8>,
    },
}

/// A Stream used for sending data to and from a remote DNS endpoint (client or server).
#[must_use = "futures do nothing unless polled"]
pub struct TcpStream<S: DnsTcpStream> {
    socket: S,
    outbound_messages: StreamReceiver,
    send_state: Option<WriteTcpState>,
    read_state: ReadTcpState,
    peer_addr: SocketAddr,
}

impl<S: DnsTcpStream> TcpStream<S> {
    /// Returns the address of the peer connection.
    pub fn peer_addr(&self) -> SocketAddr {
        self.peer_addr
    }

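    /// Splits mutable borrows of the fields that `poll_next` needs to drive
    /// independently: the socket, the outbound message queue, and the send/read
    /// state machines.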
    fn pollable_split(
        &mut self,
    ) -> (
        &mut S,
        &mut StreamReceiver,
        &mut Option<WriteTcpState>,
        &mut ReadTcpState,
    ) {
        (
            &mut self.socket,
            &mut self.outbound_messages,
            &mut self.send_state,
            &mut self.read_state,
        )
    }

    /// Initializes a TcpStream.
    ///
    /// This is intended for use with a TcpListener and Incoming.
    ///
    /// # Arguments
    ///
    /// * `stream` - the established IO stream for communication
    /// * `peer_addr` - source address of the stream
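    ///
    /// A rough usage sketch (illustrative, not compiled here), where `accepted` and
    /// `peer_addr` stand in for an accepted IO type implementing [`DnsTcpStream`] and
    /// its remote address:
    ///
    /// ```rust,ignore
    /// let (mut dns_stream, mut sender) = TcpStream::from_stream(accepted, peer_addr);
    /// // `sender` queues serialized messages for the peer; polling the stream drives
    /// // those writes and yields inbound messages.
    /// while let Some(message) = dns_stream.next().await {
    ///     // handle the inbound SerialMessage (or IO error)
    /// }
    /// ```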
    pub fn from_stream(stream: S, peer_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
        let stream = Self::from_stream_with_receiver(stream, peer_addr, outbound_messages);
        (stream, message_sender)
    }

    /// Wraps a stream where a sender and receiver have already been established
    pub fn from_stream_with_receiver(
        socket: S,
        peer_addr: SocketAddr,
        outbound_messages: StreamReceiver,
    ) -> Self {
        Self {
            socket,
            outbound_messages,
            send_state: None,
            read_state: ReadTcpState::LenBytes {
                pos: 0,
                bytes: [0u8; 2],
            },
            peer_addr,
        }
    }

    /// Creates a new future that will eventually establish an IO stream connection, or fail trying
    ///
    /// # Arguments
    ///
    /// * `future` - underlying stream future which this tcp stream relies on
    /// * `name_server` - the IP and Port of the DNS server to connect to
    /// * `timeout` - connection timeout
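    ///
    /// A rough usage sketch (illustrative, not compiled here), where `connect` stands in
    /// for a future resolving to an IO type implementing [`DnsTcpStream`]:
    ///
    /// ```rust,ignore
    /// let (stream_future, sender) =
    ///     TcpStream::with_future(connect, name_server, Duration::from_secs(5));
    /// // Resolves once the connection is established, or errors if the timeout elapses.
    /// let dns_stream = stream_future.await?;
    /// ```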
    #[allow(clippy::type_complexity)]
    pub fn with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
        future: F,
        name_server: SocketAddr,
        timeout: Duration,
    ) -> (
        impl Future<Output = Result<Self, io::Error>> + Send,
        BufDnsStreamHandle,
    ) {
        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
        let stream_fut = Self::connect_with_future(future, name_server, timeout, outbound_messages);

        (stream_fut, message_sender)
    }

    async fn connect_with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
        future: F,
        name_server: SocketAddr,
        timeout: Duration,
        outbound_messages: StreamReceiver,
    ) -> Result<Self, io::Error> {
        S::Time::timeout(timeout, future)
            .map(move |tcp_stream: Result<Result<S, io::Error>, _>| {
                tcp_stream
                    .and_then(|tcp_stream| tcp_stream)
                    .map(|tcp_stream| {
                        debug!("TCP connection established to: {}", name_server);
                        Self {
                            socket: tcp_stream,
                            outbound_messages,
                            send_state: None,
                            read_state: ReadTcpState::LenBytes {
                                pos: 0,
                                bytes: [0u8; 2],
                            },
                            peer_addr: name_server,
                        }
                    })
            })
            .await
    }
}

impl<S: DnsTcpStream> Stream for TcpStream<S> {
    type Item = io::Result<SerialMessage>;

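    // Polling interleaves writes and reads: any queued outbound messages are written first
    // as length-prefixed frames, then inbound frames are read. `Poll::Ready(None)` is only
    // returned on a clean close at a frame boundary.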
    #[allow(clippy::cognitive_complexity)]
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let peer = self.peer_addr;
        let (socket, outbound_messages, send_state, read_state) = self.pollable_split();
        let mut socket = Pin::new(socket);
        let mut outbound_messages = Pin::new(outbound_messages);

        // this will not accept incoming data while there is data to send; this makes
        // the stream self-throttling.
        // TODO: it might be interesting to try and split the sending and receiving futures.
        loop {
            // in the case we are sending, send it all?
            if send_state.is_some() {
                // sending...
                match send_state {
                    Some(WriteTcpState::LenBytes { pos, length, .. }) => {
                        let wrote = ready!(socket.as_mut().poll_write(cx, &length[*pos..]))?;
                        *pos += wrote;
                    }
                    Some(WriteTcpState::Bytes { pos, bytes }) => {
                        let wrote = ready!(socket.as_mut().poll_write(cx, &bytes[*pos..]))?;
                        *pos += wrote;
                    }
                    Some(WriteTcpState::Flushing) => {
                        ready!(socket.as_mut().poll_flush(cx))?;
                    }
                    _ => (),
                }

                // get current state
                let current_state = send_state.take();

                // switch states
                match current_state {
                    Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
                        if pos < length.len() {
                            *send_state = Some(WriteTcpState::LenBytes { pos, length, bytes });
                        } else {
                            *send_state = Some(WriteTcpState::Bytes { pos: 0, bytes });
                        }
                    }
                    Some(WriteTcpState::Bytes { pos, bytes }) => {
                        if pos < bytes.len() {
                            *send_state = Some(WriteTcpState::Bytes { pos, bytes });
                        } else {
                            // At this point we successfully delivered the entire message.
                            // flush
                            *send_state = Some(WriteTcpState::Flushing);
                        }
                    }
                    Some(WriteTcpState::Flushing) => {
                        // At this point we successfully delivered the entire message.
                        send_state.take();
                    }
                    None => (),
                };
            } else {
                // then see if there is more to send
                match outbound_messages.as_mut().poll_next(cx)
                    // .map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))?
                {
                    // already handled above, here to make sure the poll() pops the next message
                    Poll::Ready(Some(message)) => {
                        // if there is no peer, this connection should die...
                        let (buffer, dst) = message.into();

                        // This is an error if the destination is not our peer (this is TCP after all)
                        // This will kill the connection...
                        if peer != dst {
                            return Poll::Ready(Some(Err(io::Error::new(
                                io::ErrorKind::InvalidData,
                                format!("mismatched peer: {peer} and dst: {dst}"),
                            ))));
                        }

                        // will return if the socket will block
                        // the length is 16 bits
                        let len = u16::to_be_bytes(buffer.len() as u16);

                        debug!("sending message len: {} to: {}", buffer.len(), dst);
                        *send_state = Some(WriteTcpState::LenBytes {
                            pos: 0,
                            length: len,
                            bytes: buffer,
                        });
                    }
                    // now we get to drop through to the receives...
                    // TODO: should we also return None if there are no more messages to send?
                    Poll::Pending => break,
                    Poll::Ready(None) => {
                        debug!("no messages to send");
                        break;
                    }
                }
            }
        }

        let mut ret_buf: Option<Vec<u8>> = None;

        // this will loop while there is data to read, or the data has been read, or an IO
        // event would block
        while ret_buf.is_none() {
            // Evaluates the next state. If None is the result, then no state change occurs,
            // if Some(_) is returned, then that will be used as the next state.
            let new_state: Option<ReadTcpState> = match read_state {
                ReadTcpState::LenBytes { pos, bytes } => {
                    // debug!("reading length {}", bytes.len());
                    let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
                    if read == 0 {
                        // the Stream was closed!
                        debug!("zero bytes read, stream closed?");
                        //try!(self.socket.shutdown(Shutdown::Both)); // TODO: add generic shutdown function

                        if *pos == 0 {
                            // Since this is the start of the next message, we have a clean end
                            return Poll::Ready(None);
                        } else {
                            return Poll::Ready(Some(Err(io::Error::new(
                                io::ErrorKind::BrokenPipe,
                                "closed while reading length",
                            ))));
                        }
                    }
                    debug!("in ReadTcpState::LenBytes: {}", pos);
                    *pos += read;

                    if *pos < bytes.len() {
                        debug!("remain ReadTcpState::LenBytes: {}", pos);
                        None
                    } else {
                        let length = u16::from_be_bytes(*bytes);
                        debug!("got length: {}", length);
                        let bytes = vec![0; length as usize];

                        debug!("move ReadTcpState::Bytes: {}", bytes.len());
                        Some(ReadTcpState::Bytes { pos: 0, bytes })
                    }
                }
                ReadTcpState::Bytes { pos, bytes } => {
                    let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
                    if read == 0 {
                        // the Stream was closed!
                        debug!("zero bytes read for message, stream closed?");

                        // The stream closed in the middle of a message, so this is not a clean end
                        // try!(self.socket.shutdown(Shutdown::Both)); // TODO: add generic shutdown function
                        return Poll::Ready(Some(Err(io::Error::new(
                            io::ErrorKind::BrokenPipe,
                            "closed while reading message",
                        ))));
                    }

                    debug!("in ReadTcpState::Bytes: {}", bytes.len());
                    *pos += read;

                    if *pos < bytes.len() {
                        debug!("remain ReadTcpState::Bytes: {}", bytes.len());
                        None
                    } else {
                        debug!("reset ReadTcpState::LenBytes: {}", 0);
                        Some(ReadTcpState::LenBytes {
                            pos: 0,
                            bytes: [0u8; 2],
                        })
                    }
                }
            };

            // this will move to the next state,
            // if it was a completed receipt of bytes, then it will move out the bytes
            if let Some(state) = new_state {
                if let ReadTcpState::Bytes { pos, bytes } = mem::replace(read_state, state) {
                    debug!("returning bytes");
                    assert_eq!(pos, bytes.len());
                    ret_buf = Some(bytes);
                }
            }
        }

        // if the buffer is ready, return it, if not we're Pending
        if let Some(buffer) = ret_buf {
            debug!("returning buffer");
            let src_addr = self.peer_addr;
            Poll::Ready(Some(Ok(SerialMessage::new(buffer, src_addr))))
        } else {
            debug!("bottomed out");
            // at a minimum the outbound_messages should have been polled,
            // which will wake this future up later...
            Poll::Pending
        }
    }
}

#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::tests::tcp_stream_test;

    #[tokio::test]
    async fn test_tcp_stream_ipv4() {
        subscribe();
        tcp_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST), TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv6() {
        subscribe();
        tcp_stream_test(
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
            TokioRuntimeProvider::new(),
        )
        .await;
    }
409}