hickory_proto/tcp/tcp_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This module contains all the TCP structures for demuxing TCP into streams of DNS packets.
9
10use std::io;
11use std::mem;
12use std::net::SocketAddr;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16
17use async_trait::async_trait;
18use futures_io::{AsyncRead, AsyncWrite};
19use futures_util::stream::Stream;
20use futures_util::{self, future::Future, ready, FutureExt};
21use tracing::debug;
22
23use crate::xfer::{SerialMessage, StreamReceiver};
24use crate::BufDnsStreamHandle;
25use crate::Time;
26
27/// Trait for TCP connection
28pub trait DnsTcpStream: AsyncRead + AsyncWrite + Unpin + Send + Sync + Sized + 'static {
29    /// Timer type to use with this TCP stream type
30    type Time: Time;
31}
32
33/// Trait for TCP connection
34#[async_trait]
35pub trait Connect: DnsTcpStream {
36    /// connect to tcp
37    async fn connect(addr: SocketAddr) -> io::Result<Self> {
38        Self::connect_with_bind(addr, None).await
39    }
40
41    /// connect to tcp with address to connect from
42    async fn connect_with_bind(addr: SocketAddr, bind_addr: Option<SocketAddr>)
43        -> io::Result<Self>;
44}
45
/// Current state while writing one message to the remote end of the TCP connection.
///
/// A message moves through `LenBytes` -> `Bytes` -> `Flushing`; once the flush completes,
/// no write is in progress until the next outbound message is dequeued.
enum WriteTcpState {
    /// Currently writing the two-byte, big-endian length prefix of the buffer.
    LenBytes {
        /// Number of length-prefix bytes already written
        pos: usize,
        /// Encoded length prefix to write
        length: [u8; 2],
        /// Message buffer to write once the length prefix has been fully written
        bytes: Vec<u8>,
    },
    /// Currently writing the message buffer to the remote
    Bytes {
        /// Number of buffer bytes already written
        pos: usize,
        /// Message buffer to write to the remote
        bytes: Vec<u8>,
    },
    /// The whole message has been written; flushing the socket
    Flushing,
}
67
/// Progress of reading one length-prefixed DNS message from the TCP stream.
pub(crate) enum ReadTcpState {
    /// Accumulating the two-byte length prefix that precedes each message.
    LenBytes {
        /// Storage for the length prefix as it arrives
        bytes: [u8; 2],
        /// How many prefix bytes have been received so far
        pos: usize,
    },
    /// Accumulating the message body, whose size was announced by the length prefix.
    Bytes {
        /// Storage for the message body, allocated to the announced size
        bytes: Vec<u8>,
        /// How many body bytes have been received so far
        pos: usize,
    },
}
85
86/// A Stream used for sending data to and from a remote DNS endpoint (client or server).
87#[must_use = "futures do nothing unless polled"]
88pub struct TcpStream<S: DnsTcpStream> {
89    socket: S,
90    outbound_messages: StreamReceiver,
91    send_state: Option<WriteTcpState>,
92    read_state: ReadTcpState,
93    peer_addr: SocketAddr,
94}
95
96impl<S: Connect> TcpStream<S> {
97    /// Creates a new future of the eventually establish a IO stream connection or fail trying.
98    ///
99    /// Defaults to a 5 second timeout
100    ///
101    /// # Arguments
102    ///
103    /// * `name_server` - the IP and Port of the DNS server to connect to
104    #[allow(clippy::new_ret_no_self, clippy::type_complexity)]
105    pub fn new(
106        name_server: SocketAddr,
107    ) -> (
108        impl Future<Output = Result<Self, io::Error>> + Send,
109        BufDnsStreamHandle,
110    ) {
111        Self::with_timeout(name_server, Duration::from_secs(5))
112    }
113
114    /// Creates a new future of the eventually establish a IO stream connection or fail trying
115    ///
116    /// # Arguments
117    ///
118    /// * `name_server` - the IP and Port of the DNS server to connect to
119    /// * `timeout` - connection timeout
120    #[allow(clippy::type_complexity)]
121    pub fn with_timeout(
122        name_server: SocketAddr,
123        timeout: Duration,
124    ) -> (
125        impl Future<Output = Result<Self, io::Error>> + Send,
126        BufDnsStreamHandle,
127    ) {
128        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
129
130        // This set of futures collapses the next tcp socket into a stream which can be used for
131        //  sending and receiving tcp packets.
132        let stream_fut = Self::connect(name_server, None, timeout, outbound_messages);
133
134        (stream_fut, message_sender)
135    }
136
137    /// Creates a new future of the eventually establish a IO stream connection or fail trying
138    ///
139    /// # Arguments
140    ///
141    /// * `name_server` - the IP and Port of the DNS server to connect to
142    /// * `bind_addr` - the IP and port to connect from
143    /// * `timeout` - connection timeout
144    #[allow(clippy::type_complexity)]
145    pub fn with_bind_addr_and_timeout(
146        name_server: SocketAddr,
147        bind_addr: Option<SocketAddr>,
148        timeout: Duration,
149    ) -> (
150        impl Future<Output = Result<Self, io::Error>> + Send,
151        BufDnsStreamHandle,
152    ) {
153        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
154        let stream_fut = Self::connect(name_server, bind_addr, timeout, outbound_messages);
155
156        (stream_fut, message_sender)
157    }
158
159    async fn connect(
160        name_server: SocketAddr,
161        bind_addr: Option<SocketAddr>,
162        timeout: Duration,
163        outbound_messages: StreamReceiver,
164    ) -> Result<Self, io::Error> {
165        let tcp = S::connect_with_bind(name_server, bind_addr);
166        Self::connect_with_future(tcp, name_server, timeout, outbound_messages).await
167    }
168}
169
170impl<S: DnsTcpStream> TcpStream<S> {
171    /// Returns the address of the peer connection.
172    pub fn peer_addr(&self) -> SocketAddr {
173        self.peer_addr
174    }
175
176    fn pollable_split(
177        &mut self,
178    ) -> (
179        &mut S,
180        &mut StreamReceiver,
181        &mut Option<WriteTcpState>,
182        &mut ReadTcpState,
183    ) {
184        (
185            &mut self.socket,
186            &mut self.outbound_messages,
187            &mut self.send_state,
188            &mut self.read_state,
189        )
190    }
191
192    /// Initializes a TcpStream.
193    ///
194    /// This is intended for use with a TcpListener and Incoming.
195    ///
196    /// # Arguments
197    ///
198    /// * `stream` - the established IO stream for communication
199    /// * `peer_addr` - sources address of the stream
200    pub fn from_stream(stream: S, peer_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
201        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
202        let stream = Self::from_stream_with_receiver(stream, peer_addr, outbound_messages);
203        (stream, message_sender)
204    }
205
206    /// Wraps a stream where a sender and receiver have already been established
207    pub fn from_stream_with_receiver(
208        socket: S,
209        peer_addr: SocketAddr,
210        outbound_messages: StreamReceiver,
211    ) -> Self {
212        Self {
213            socket,
214            outbound_messages,
215            send_state: None,
216            read_state: ReadTcpState::LenBytes {
217                pos: 0,
218                bytes: [0u8; 2],
219            },
220            peer_addr,
221        }
222    }
223
224    /// Creates a new future of the eventually establish a IO stream connection or fail trying
225    ///
226    /// # Arguments
227    ///
228    /// * `future` - underlying stream future which this tcp stream relies on
229    /// * `name_server` - the IP and Port of the DNS server to connect to
230    /// * `timeout` - connection timeout
231    #[allow(clippy::type_complexity)]
232    pub fn with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
233        future: F,
234        name_server: SocketAddr,
235        timeout: Duration,
236    ) -> (
237        impl Future<Output = Result<Self, io::Error>> + Send,
238        BufDnsStreamHandle,
239    ) {
240        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
241        let stream_fut = Self::connect_with_future(future, name_server, timeout, outbound_messages);
242
243        (stream_fut, message_sender)
244    }
245
246    async fn connect_with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
247        future: F,
248        name_server: SocketAddr,
249        timeout: Duration,
250        outbound_messages: StreamReceiver,
251    ) -> Result<Self, io::Error> {
252        S::Time::timeout(timeout, future)
253            .map(move |tcp_stream: Result<Result<S, io::Error>, _>| {
254                tcp_stream
255                    .and_then(|tcp_stream| tcp_stream)
256                    .map(|tcp_stream| {
257                        debug!("TCP connection established to: {}", name_server);
258                        Self {
259                            socket: tcp_stream,
260                            outbound_messages,
261                            send_state: None,
262                            read_state: ReadTcpState::LenBytes {
263                                pos: 0,
264                                bytes: [0u8; 2],
265                            },
266                            peer_addr: name_server,
267                        }
268                    })
269            })
270            .await
271    }
272}
273
274impl<S: DnsTcpStream> Stream for TcpStream<S> {
275    type Item = io::Result<SerialMessage>;
276
277    #[allow(clippy::cognitive_complexity)]
278    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
279        let peer = self.peer_addr;
280        let (socket, outbound_messages, send_state, read_state) = self.pollable_split();
281        let mut socket = Pin::new(socket);
282        let mut outbound_messages = Pin::new(outbound_messages);
283
284        // this will not accept incoming data while there is data to send
285        //  makes this self throttling.
286        // TODO: it might be interesting to try and split the sending and receiving futures.
287        loop {
288            // in the case we are sending, send it all?
289            if send_state.is_some() {
290                // sending...
291                match send_state {
292                    Some(WriteTcpState::LenBytes {
293                        ref mut pos,
294                        ref length,
295                        ..
296                    }) => {
297                        let wrote = ready!(socket.as_mut().poll_write(cx, &length[*pos..]))?;
298                        *pos += wrote;
299                    }
300                    Some(WriteTcpState::Bytes {
301                        ref mut pos,
302                        ref bytes,
303                    }) => {
304                        let wrote = ready!(socket.as_mut().poll_write(cx, &bytes[*pos..]))?;
305                        *pos += wrote;
306                    }
307                    Some(WriteTcpState::Flushing) => {
308                        ready!(socket.as_mut().poll_flush(cx))?;
309                    }
310                    _ => (),
311                }
312
313                // get current state
314                let current_state = send_state.take();
315
316                // switch states
317                match current_state {
318                    Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
319                        if pos < length.len() {
320                            *send_state = Some(WriteTcpState::LenBytes { pos, length, bytes });
321                        } else {
322                            *send_state = Some(WriteTcpState::Bytes { pos: 0, bytes });
323                        }
324                    }
325                    Some(WriteTcpState::Bytes { pos, bytes }) => {
326                        if pos < bytes.len() {
327                            *send_state = Some(WriteTcpState::Bytes { pos, bytes });
328                        } else {
329                            // At this point we successfully delivered the entire message.
330                            //  flush
331                            *send_state = Some(WriteTcpState::Flushing);
332                        }
333                    }
334                    Some(WriteTcpState::Flushing) => {
335                        // At this point we successfully delivered the entire message.
336                        send_state.take();
337                    }
338                    None => (),
339                };
340            } else {
341                // then see if there is more to send
342                match outbound_messages.as_mut().poll_next(cx)
343                    // .map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))?
344                {
345                    // already handled above, here to make sure the poll() pops the next message
346                    Poll::Ready(Some(message)) => {
347                        // if there is no peer, this connection should die...
348                        let (buffer, dst) = message.into();
349
350                        // This is an error if the destination is not our peer (this is TCP after all)
351                        //  This will kill the connection...
352                        if peer != dst {
353                            return Poll::Ready(Some(Err(io::Error::new(
354                                io::ErrorKind::InvalidData,
355                                format!("mismatched peer: {peer} and dst: {dst}"),
356                            ))));
357                        }
358
359                        // will return if the socket will block
360                        // the length is 16 bits
361                        let len = u16::to_be_bytes(buffer.len() as u16);
362
363                        debug!("sending message len: {} to: {}", buffer.len(), dst);
364                        *send_state = Some(WriteTcpState::LenBytes {
365                            pos: 0,
366                            length: len,
367                            bytes: buffer,
368                        });
369                    }
370                    // now we get to drop through to the receives...
371                    // TODO: should we also return None if there are no more messages to send?
372                    Poll::Pending => break,
373                    Poll::Ready(None) => {
374                        debug!("no messages to send");
375                        break;
376                    }
377                }
378            }
379        }
380
381        let mut ret_buf: Option<Vec<u8>> = None;
382
383        // this will loop while there is data to read, or the data has been read, or an IO
384        //  event would block
385        while ret_buf.is_none() {
386            // Evaluates the next state. If None is the result, then no state change occurs,
387            //  if Some(_) is returned, then that will be used as the next state.
388            let new_state: Option<ReadTcpState> = match read_state {
389                ReadTcpState::LenBytes {
390                    ref mut pos,
391                    ref mut bytes,
392                } => {
393                    // debug!("reading length {}", bytes.len());
394                    let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
395                    if read == 0 {
396                        // the Stream was closed!
397                        debug!("zero bytes read, stream closed?");
398                        //try!(self.socket.shutdown(Shutdown::Both)); // TODO: add generic shutdown function
399
400                        if *pos == 0 {
401                            // Since this is the start of the next message, we have a clean end
402                            return Poll::Ready(None);
403                        } else {
404                            return Poll::Ready(Some(Err(io::Error::new(
405                                io::ErrorKind::BrokenPipe,
406                                "closed while reading length",
407                            ))));
408                        }
409                    }
410                    debug!("in ReadTcpState::LenBytes: {}", pos);
411                    *pos += read;
412
413                    if *pos < bytes.len() {
414                        debug!("remain ReadTcpState::LenBytes: {}", pos);
415                        None
416                    } else {
417                        let length = u16::from_be_bytes(*bytes);
418                        debug!("got length: {}", length);
419                        let mut bytes = vec![0; length as usize];
420                        bytes.resize(length as usize, 0);
421
422                        debug!("move ReadTcpState::Bytes: {}", bytes.len());
423                        Some(ReadTcpState::Bytes { pos: 0, bytes })
424                    }
425                }
426                ReadTcpState::Bytes {
427                    ref mut pos,
428                    ref mut bytes,
429                } => {
430                    let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
431                    if read == 0 {
432                        // the Stream was closed!
433                        debug!("zero bytes read for message, stream closed?");
434
435                        // Since this is the start of the next message, we have a clean end
436                        // try!(self.socket.shutdown(Shutdown::Both));  // TODO: add generic shutdown function
437                        return Poll::Ready(Some(Err(io::Error::new(
438                            io::ErrorKind::BrokenPipe,
439                            "closed while reading message",
440                        ))));
441                    }
442
443                    debug!("in ReadTcpState::Bytes: {}", bytes.len());
444                    *pos += read;
445
446                    if *pos < bytes.len() {
447                        debug!("remain ReadTcpState::Bytes: {}", bytes.len());
448                        None
449                    } else {
450                        debug!("reset ReadTcpState::LenBytes: {}", 0);
451                        Some(ReadTcpState::LenBytes {
452                            pos: 0,
453                            bytes: [0u8; 2],
454                        })
455                    }
456                }
457            };
458
459            // this will move to the next state,
460            //  if it was a completed receipt of bytes, then it will move out the bytes
461            if let Some(state) = new_state {
462                if let ReadTcpState::Bytes { pos, bytes } = mem::replace(read_state, state) {
463                    debug!("returning bytes");
464                    assert_eq!(pos, bytes.len());
465                    ret_buf = Some(bytes);
466                }
467            }
468        }
469
470        // if the buffer is ready, return it, if not we're Pending
471        if let Some(buffer) = ret_buf {
472            debug!("returning buffer");
473            let src_addr = self.peer_addr;
474            Poll::Ready(Some(Ok(SerialMessage::new(buffer, src_addr))))
475        } else {
476            debug!("bottomed out");
477            // at a minimum the outbound_messages should have been polled,
478            //  which will wake this future up later...
479            Poll::Pending
480        }
481    }
482}
483
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use crate::iocompat::AsyncIoTokioAsStd;
    use crate::tests::tcp_stream_test;

    /// Runs the shared TCP stream test against `server_ip` on a freshly created tokio runtime.
    fn run_tcp_stream_test(server_ip: IpAddr) {
        let runtime = Runtime::new().expect("failed to create tokio runtime");
        tcp_stream_test::<AsyncIoTokioAsStd<TokioTcpStream>, Runtime>(server_ip, runtime)
    }

    #[test]
    fn test_tcp_stream_ipv4() {
        run_tcp_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST))
    }

    #[test]
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    fn test_tcp_stream_ipv6() {
        run_tcp_stream_test(IpAddr::V6(Ipv6Addr::LOCALHOST))
    }
}