hickory_proto/tcp/tcp_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This module contains all the TCP structures for demuxing TCP into streams of DNS packets.
9
10use alloc::vec::Vec;
11use core::mem;
12use core::pin::Pin;
13use core::task::{Context, Poll};
14use core::time::Duration;
15use std::io;
16use std::net::SocketAddr;
17
18use futures_io::{AsyncRead, AsyncWrite};
19use futures_util::stream::Stream;
20use futures_util::{self, FutureExt, future::Future, ready};
21use tracing::debug;
22
23use crate::BufDnsStreamHandle;
24use crate::runtime::Time;
25use crate::xfer::{SerialMessage, StreamReceiver};
26
27/// Trait for TCP connection
28pub trait DnsTcpStream: AsyncRead + AsyncWrite + Unpin + Send + Sync + Sized + 'static {
29    /// Timer type to use with this TCP stream type
30    type Time: Time;
31}
32
/// Progress of the outbound message currently being written to the remote end.
///
/// A message moves from `LenBytes` to `Bytes` to `Flushing`, after which no write is in
/// flight.
enum WriteTcpState {
    /// Writing the two-byte length prefix that precedes the message.
    LenBytes {
        /// Number of prefix bytes already written.
        pos: usize,
        /// The encoded length prefix.
        length: [u8; 2],
        /// Message body to write once the prefix has been fully written.
        bytes: Vec<u8>,
    },
    /// Writing the message body.
    Bytes {
        /// Number of body bytes already written.
        pos: usize,
        /// The message body.
        bytes: Vec<u8>,
    },
    /// The full message has been written; waiting for the socket flush to complete.
    Flushing,
}
54
/// Progress of the inbound message currently being read from the remote end.
pub(crate) enum ReadTcpState {
    /// Reading the two-byte length prefix of the next message.
    LenBytes {
        /// Number of prefix bytes received so far.
        pos: usize,
        /// Storage for the length prefix.
        bytes: [u8; 2],
    },
    /// Reading the message body whose length was given by the prefix.
    Bytes {
        /// Number of body bytes received so far.
        pos: usize,
        /// Storage for the message body, sized from the prefix.
        bytes: Vec<u8>,
    },
}
72
73/// A Stream used for sending data to and from a remote DNS endpoint (client or server).
74#[must_use = "futures do nothing unless polled"]
75pub struct TcpStream<S: DnsTcpStream> {
76    socket: S,
77    outbound_messages: StreamReceiver,
78    send_state: Option<WriteTcpState>,
79    read_state: ReadTcpState,
80    peer_addr: SocketAddr,
81}
82
83impl<S: DnsTcpStream> TcpStream<S> {
84    /// Returns the address of the peer connection.
85    pub fn peer_addr(&self) -> SocketAddr {
86        self.peer_addr
87    }
88
89    fn pollable_split(
90        &mut self,
91    ) -> (
92        &mut S,
93        &mut StreamReceiver,
94        &mut Option<WriteTcpState>,
95        &mut ReadTcpState,
96    ) {
97        (
98            &mut self.socket,
99            &mut self.outbound_messages,
100            &mut self.send_state,
101            &mut self.read_state,
102        )
103    }
104
105    /// Initializes a TcpStream.
106    ///
107    /// This is intended for use with a TcpListener and Incoming.
108    ///
109    /// # Arguments
110    ///
111    /// * `stream` - the established IO stream for communication
112    /// * `peer_addr` - sources address of the stream
113    pub fn from_stream(stream: S, peer_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
114        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
115        let stream = Self::from_stream_with_receiver(stream, peer_addr, outbound_messages);
116        (stream, message_sender)
117    }
118
119    /// Wraps a stream where a sender and receiver have already been established
120    pub fn from_stream_with_receiver(
121        socket: S,
122        peer_addr: SocketAddr,
123        outbound_messages: StreamReceiver,
124    ) -> Self {
125        Self {
126            socket,
127            outbound_messages,
128            send_state: None,
129            read_state: ReadTcpState::LenBytes {
130                pos: 0,
131                bytes: [0u8; 2],
132            },
133            peer_addr,
134        }
135    }
136
137    /// Creates a new future of the eventually establish a IO stream connection or fail trying
138    ///
139    /// # Arguments
140    ///
141    /// * `future` - underlying stream future which this tcp stream relies on
142    /// * `name_server` - the IP and Port of the DNS server to connect to
143    /// * `timeout` - connection timeout
144    #[allow(clippy::type_complexity)]
145    pub fn with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
146        future: F,
147        name_server: SocketAddr,
148        timeout: Duration,
149    ) -> (
150        impl Future<Output = Result<Self, io::Error>> + Send,
151        BufDnsStreamHandle,
152    ) {
153        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
154        let stream_fut = Self::connect_with_future(future, name_server, timeout, outbound_messages);
155
156        (stream_fut, message_sender)
157    }
158
159    async fn connect_with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
160        future: F,
161        name_server: SocketAddr,
162        timeout: Duration,
163        outbound_messages: StreamReceiver,
164    ) -> Result<Self, io::Error> {
165        S::Time::timeout(timeout, future)
166            .map(move |tcp_stream: Result<Result<S, io::Error>, _>| {
167                tcp_stream
168                    .and_then(|tcp_stream| tcp_stream)
169                    .map(|tcp_stream| {
170                        debug!("TCP connection established to: {}", name_server);
171                        Self {
172                            socket: tcp_stream,
173                            outbound_messages,
174                            send_state: None,
175                            read_state: ReadTcpState::LenBytes {
176                                pos: 0,
177                                bytes: [0u8; 2],
178                            },
179                            peer_addr: name_server,
180                        }
181                    })
182            })
183            .await
184    }
185}
186
187impl<S: DnsTcpStream> Stream for TcpStream<S> {
188    type Item = io::Result<SerialMessage>;
189
190    #[allow(clippy::cognitive_complexity)]
191    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192        let peer = self.peer_addr;
193        let (socket, outbound_messages, send_state, read_state) = self.pollable_split();
194        let mut socket = Pin::new(socket);
195        let mut outbound_messages = Pin::new(outbound_messages);
196
197        // this will not accept incoming data while there is data to send
198        //  makes this self throttling.
199        // TODO: it might be interesting to try and split the sending and receiving futures.
200        loop {
201            // in the case we are sending, send it all?
202            if send_state.is_some() {
203                // sending...
204                match send_state {
205                    Some(WriteTcpState::LenBytes { pos, length, .. }) => {
206                        let wrote = ready!(socket.as_mut().poll_write(cx, &length[*pos..]))?;
207                        *pos += wrote;
208                    }
209                    Some(WriteTcpState::Bytes { pos, bytes }) => {
210                        let wrote = ready!(socket.as_mut().poll_write(cx, &bytes[*pos..]))?;
211                        *pos += wrote;
212                    }
213                    Some(WriteTcpState::Flushing) => {
214                        ready!(socket.as_mut().poll_flush(cx))?;
215                    }
216                    _ => (),
217                }
218
219                // get current state
220                let current_state = send_state.take();
221
222                // switch states
223                match current_state {
224                    Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
225                        if pos < length.len() {
226                            *send_state = Some(WriteTcpState::LenBytes { pos, length, bytes });
227                        } else {
228                            *send_state = Some(WriteTcpState::Bytes { pos: 0, bytes });
229                        }
230                    }
231                    Some(WriteTcpState::Bytes { pos, bytes }) => {
232                        if pos < bytes.len() {
233                            *send_state = Some(WriteTcpState::Bytes { pos, bytes });
234                        } else {
235                            // At this point we successfully delivered the entire message.
236                            //  flush
237                            *send_state = Some(WriteTcpState::Flushing);
238                        }
239                    }
240                    Some(WriteTcpState::Flushing) => {
241                        // At this point we successfully delivered the entire message.
242                        send_state.take();
243                    }
244                    None => (),
245                };
246            } else {
247                // then see if there is more to send
248                match outbound_messages.as_mut().poll_next(cx)
249                    // .map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))?
250                {
251                    // already handled above, here to make sure the poll() pops the next message
252                    Poll::Ready(Some(message)) => {
253                        // if there is no peer, this connection should die...
254                        let (buffer, dst) = message.into();
255
256                        // This is an error if the destination is not our peer (this is TCP after all)
257                        //  This will kill the connection...
258                        if peer != dst {
259                            return Poll::Ready(Some(Err(io::Error::new(
260                                io::ErrorKind::InvalidData,
261                                format!("mismatched peer: {peer} and dst: {dst}"),
262                            ))));
263                        }
264
265                        // will return if the socket will block
266                        // the length is 16 bits
267                        let len = u16::to_be_bytes(buffer.len() as u16);
268
269                        debug!("sending message len: {} to: {}", buffer.len(), dst);
270                        *send_state = Some(WriteTcpState::LenBytes {
271                            pos: 0,
272                            length: len,
273                            bytes: buffer,
274                        });
275                    }
276                    // now we get to drop through to the receives...
277                    // TODO: should we also return None if there are no more messages to send?
278                    Poll::Pending => break,
279                    Poll::Ready(None) => {
280                        debug!("no messages to send");
281                        break;
282                    }
283                }
284            }
285        }
286
287        let mut ret_buf: Option<Vec<u8>> = None;
288
289        // this will loop while there is data to read, or the data has been read, or an IO
290        //  event would block
291        while ret_buf.is_none() {
292            // Evaluates the next state. If None is the result, then no state change occurs,
293            //  if Some(_) is returned, then that will be used as the next state.
294            let new_state: Option<ReadTcpState> = match read_state {
295                ReadTcpState::LenBytes { pos, bytes } => {
296                    // debug!("reading length {}", bytes.len());
297                    let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
298                    if read == 0 {
299                        // the Stream was closed!
300                        debug!("zero bytes read, stream closed?");
301                        //try!(self.socket.shutdown(Shutdown::Both)); // TODO: add generic shutdown function
302
303                        if *pos == 0 {
304                            // Since this is the start of the next message, we have a clean end
305                            return Poll::Ready(None);
306                        } else {
307                            return Poll::Ready(Some(Err(io::Error::new(
308                                io::ErrorKind::BrokenPipe,
309                                "closed while reading length",
310                            ))));
311                        }
312                    }
313                    debug!("in ReadTcpState::LenBytes: {}", pos);
314                    *pos += read;
315
316                    if *pos < bytes.len() {
317                        debug!("remain ReadTcpState::LenBytes: {}", pos);
318                        None
319                    } else {
320                        let length = u16::from_be_bytes(*bytes);
321                        debug!("got length: {}", length);
322                        let mut bytes = vec![0; length as usize];
323                        bytes.resize(length as usize, 0);
324
325                        debug!("move ReadTcpState::Bytes: {}", bytes.len());
326                        Some(ReadTcpState::Bytes { pos: 0, bytes })
327                    }
328                }
329                ReadTcpState::Bytes { pos, bytes } => {
330                    let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
331                    if read == 0 {
332                        // the Stream was closed!
333                        debug!("zero bytes read for message, stream closed?");
334
335                        // Since this is the start of the next message, we have a clean end
336                        // try!(self.socket.shutdown(Shutdown::Both));  // TODO: add generic shutdown function
337                        return Poll::Ready(Some(Err(io::Error::new(
338                            io::ErrorKind::BrokenPipe,
339                            "closed while reading message",
340                        ))));
341                    }
342
343                    debug!("in ReadTcpState::Bytes: {}", bytes.len());
344                    *pos += read;
345
346                    if *pos < bytes.len() {
347                        debug!("remain ReadTcpState::Bytes: {}", bytes.len());
348                        None
349                    } else {
350                        debug!("reset ReadTcpState::LenBytes: {}", 0);
351                        Some(ReadTcpState::LenBytes {
352                            pos: 0,
353                            bytes: [0u8; 2],
354                        })
355                    }
356                }
357            };
358
359            // this will move to the next state,
360            //  if it was a completed receipt of bytes, then it will move out the bytes
361            if let Some(state) = new_state {
362                if let ReadTcpState::Bytes { pos, bytes } = mem::replace(read_state, state) {
363                    debug!("returning bytes");
364                    assert_eq!(pos, bytes.len());
365                    ret_buf = Some(bytes);
366                }
367            }
368        }
369
370        // if the buffer is ready, return it, if not we're Pending
371        if let Some(buffer) = ret_buf {
372            debug!("returning buffer");
373            let src_addr = self.peer_addr;
374            Poll::Ready(Some(Ok(SerialMessage::new(buffer, src_addr))))
375        } else {
376            debug!("bottomed out");
377            // at a minimum the outbound_messages should have been polled,
378            //  which will wake this future up later...
379            Poll::Pending
380        }
381    }
382}
383
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::tests::tcp_stream_test;

    /// Calls `subscribe`, then runs the shared `tcp_stream_test` against `addr` with a fresh
    /// Tokio runtime provider.
    async fn run_on_loopback(addr: IpAddr) {
        subscribe();
        tcp_stream_test(addr, TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv4() {
        run_on_loopback(IpAddr::from(Ipv4Addr::LOCALHOST)).await;
    }

    #[tokio::test]
    async fn test_tcp_stream_ipv6() {
        // `Ipv6Addr::LOCALHOST` is `::1`.
        run_on_loopback(IpAddr::from(Ipv6Addr::LOCALHOST)).await;
    }
}