1use std::io;
11use std::mem;
12use std::net::SocketAddr;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16
17use async_trait::async_trait;
18use futures_io::{AsyncRead, AsyncWrite};
19use futures_util::stream::Stream;
20use futures_util::{self, future::Future, ready, FutureExt};
21use tracing::debug;
22
23use crate::xfer::{SerialMessage, StreamReceiver};
24use crate::BufDnsStreamHandle;
25use crate::Time;
26
27pub trait DnsTcpStream: AsyncRead + AsyncWrite + Unpin + Send + Sync + Sized + 'static {
29 type Time: Time;
31}
32
33#[async_trait]
35pub trait Connect: DnsTcpStream {
36 async fn connect(addr: SocketAddr) -> io::Result<Self> {
38 Self::connect_with_bind(addr, None).await
39 }
40
41 async fn connect_with_bind(addr: SocketAddr, bind_addr: Option<SocketAddr>)
43 -> io::Result<Self>;
44}
45
/// Progress of the outbound message currently being written.
///
/// A message goes `LenBytes` -> `Bytes` -> `Flushing`; once the flush completes the write
/// slot is emptied so the next queued message can be taken.
enum WriteTcpState {
    /// Writing the two-byte big-endian length prefix.
    LenBytes {
        /// How many prefix bytes have been written so far.
        pos: usize,
        /// The encoded length prefix.
        length: [u8; 2],
        /// The message body, written after the prefix.
        bytes: Vec<u8>,
    },
    /// Writing the message body.
    Bytes {
        /// How many body bytes have been written so far.
        pos: usize,
        /// The message body.
        bytes: Vec<u8>,
    },
    /// The body has been written; flushing the socket.
    Flushing,
}
67
/// Progress of the inbound message currently being read.
///
/// Reading alternates between collecting the two-byte big-endian length prefix and
/// collecting that many message bytes.
pub(crate) enum ReadTcpState {
    /// Collecting the length prefix.
    LenBytes {
        /// How many prefix bytes have been read so far.
        pos: usize,
        /// Buffer for the prefix.
        bytes: [u8; 2],
    },
    /// Collecting the message body.
    Bytes {
        /// How many body bytes have been read so far.
        pos: usize,
        /// Buffer sized to the length announced by the prefix.
        bytes: Vec<u8>,
    },
}
85
86#[must_use = "futures do nothing unless polled"]
88pub struct TcpStream<S: DnsTcpStream> {
89 socket: S,
90 outbound_messages: StreamReceiver,
91 send_state: Option<WriteTcpState>,
92 read_state: ReadTcpState,
93 peer_addr: SocketAddr,
94}
95
96impl<S: Connect> TcpStream<S> {
97 #[allow(clippy::new_ret_no_self, clippy::type_complexity)]
105 pub fn new(
106 name_server: SocketAddr,
107 ) -> (
108 impl Future<Output = Result<Self, io::Error>> + Send,
109 BufDnsStreamHandle,
110 ) {
111 Self::with_timeout(name_server, Duration::from_secs(5))
112 }
113
114 #[allow(clippy::type_complexity)]
121 pub fn with_timeout(
122 name_server: SocketAddr,
123 timeout: Duration,
124 ) -> (
125 impl Future<Output = Result<Self, io::Error>> + Send,
126 BufDnsStreamHandle,
127 ) {
128 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
129
130 let stream_fut = Self::connect(name_server, None, timeout, outbound_messages);
133
134 (stream_fut, message_sender)
135 }
136
137 #[allow(clippy::type_complexity)]
145 pub fn with_bind_addr_and_timeout(
146 name_server: SocketAddr,
147 bind_addr: Option<SocketAddr>,
148 timeout: Duration,
149 ) -> (
150 impl Future<Output = Result<Self, io::Error>> + Send,
151 BufDnsStreamHandle,
152 ) {
153 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
154 let stream_fut = Self::connect(name_server, bind_addr, timeout, outbound_messages);
155
156 (stream_fut, message_sender)
157 }
158
159 async fn connect(
160 name_server: SocketAddr,
161 bind_addr: Option<SocketAddr>,
162 timeout: Duration,
163 outbound_messages: StreamReceiver,
164 ) -> Result<Self, io::Error> {
165 let tcp = S::connect_with_bind(name_server, bind_addr);
166 Self::connect_with_future(tcp, name_server, timeout, outbound_messages).await
167 }
168}
169
170impl<S: DnsTcpStream> TcpStream<S> {
171 pub fn peer_addr(&self) -> SocketAddr {
173 self.peer_addr
174 }
175
176 fn pollable_split(
177 &mut self,
178 ) -> (
179 &mut S,
180 &mut StreamReceiver,
181 &mut Option<WriteTcpState>,
182 &mut ReadTcpState,
183 ) {
184 (
185 &mut self.socket,
186 &mut self.outbound_messages,
187 &mut self.send_state,
188 &mut self.read_state,
189 )
190 }
191
192 pub fn from_stream(stream: S, peer_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
201 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
202 let stream = Self::from_stream_with_receiver(stream, peer_addr, outbound_messages);
203 (stream, message_sender)
204 }
205
206 pub fn from_stream_with_receiver(
208 socket: S,
209 peer_addr: SocketAddr,
210 outbound_messages: StreamReceiver,
211 ) -> Self {
212 Self {
213 socket,
214 outbound_messages,
215 send_state: None,
216 read_state: ReadTcpState::LenBytes {
217 pos: 0,
218 bytes: [0u8; 2],
219 },
220 peer_addr,
221 }
222 }
223
224 #[allow(clippy::type_complexity)]
232 pub fn with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
233 future: F,
234 name_server: SocketAddr,
235 timeout: Duration,
236 ) -> (
237 impl Future<Output = Result<Self, io::Error>> + Send,
238 BufDnsStreamHandle,
239 ) {
240 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
241 let stream_fut = Self::connect_with_future(future, name_server, timeout, outbound_messages);
242
243 (stream_fut, message_sender)
244 }
245
246 async fn connect_with_future<F: Future<Output = Result<S, io::Error>> + Send + 'static>(
247 future: F,
248 name_server: SocketAddr,
249 timeout: Duration,
250 outbound_messages: StreamReceiver,
251 ) -> Result<Self, io::Error> {
252 S::Time::timeout(timeout, future)
253 .map(move |tcp_stream: Result<Result<S, io::Error>, _>| {
254 tcp_stream
255 .and_then(|tcp_stream| tcp_stream)
256 .map(|tcp_stream| {
257 debug!("TCP connection established to: {}", name_server);
258 Self {
259 socket: tcp_stream,
260 outbound_messages,
261 send_state: None,
262 read_state: ReadTcpState::LenBytes {
263 pos: 0,
264 bytes: [0u8; 2],
265 },
266 peer_addr: name_server,
267 }
268 })
269 })
270 .await
271 }
272}
273
274impl<S: DnsTcpStream> Stream for TcpStream<S> {
275 type Item = io::Result<SerialMessage>;
276
277 #[allow(clippy::cognitive_complexity)]
278 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
279 let peer = self.peer_addr;
280 let (socket, outbound_messages, send_state, read_state) = self.pollable_split();
281 let mut socket = Pin::new(socket);
282 let mut outbound_messages = Pin::new(outbound_messages);
283
284 loop {
288 if send_state.is_some() {
290 match send_state {
292 Some(WriteTcpState::LenBytes {
293 ref mut pos,
294 ref length,
295 ..
296 }) => {
297 let wrote = ready!(socket.as_mut().poll_write(cx, &length[*pos..]))?;
298 *pos += wrote;
299 }
300 Some(WriteTcpState::Bytes {
301 ref mut pos,
302 ref bytes,
303 }) => {
304 let wrote = ready!(socket.as_mut().poll_write(cx, &bytes[*pos..]))?;
305 *pos += wrote;
306 }
307 Some(WriteTcpState::Flushing) => {
308 ready!(socket.as_mut().poll_flush(cx))?;
309 }
310 _ => (),
311 }
312
313 let current_state = send_state.take();
315
316 match current_state {
318 Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
319 if pos < length.len() {
320 *send_state = Some(WriteTcpState::LenBytes { pos, length, bytes });
321 } else {
322 *send_state = Some(WriteTcpState::Bytes { pos: 0, bytes });
323 }
324 }
325 Some(WriteTcpState::Bytes { pos, bytes }) => {
326 if pos < bytes.len() {
327 *send_state = Some(WriteTcpState::Bytes { pos, bytes });
328 } else {
329 *send_state = Some(WriteTcpState::Flushing);
332 }
333 }
334 Some(WriteTcpState::Flushing) => {
335 send_state.take();
337 }
338 None => (),
339 };
340 } else {
341 match outbound_messages.as_mut().poll_next(cx)
343 {
345 Poll::Ready(Some(message)) => {
347 let (buffer, dst) = message.into();
349
350 if peer != dst {
353 return Poll::Ready(Some(Err(io::Error::new(
354 io::ErrorKind::InvalidData,
355 format!("mismatched peer: {peer} and dst: {dst}"),
356 ))));
357 }
358
359 let len = u16::to_be_bytes(buffer.len() as u16);
362
363 debug!("sending message len: {} to: {}", buffer.len(), dst);
364 *send_state = Some(WriteTcpState::LenBytes {
365 pos: 0,
366 length: len,
367 bytes: buffer,
368 });
369 }
370 Poll::Pending => break,
373 Poll::Ready(None) => {
374 debug!("no messages to send");
375 break;
376 }
377 }
378 }
379 }
380
381 let mut ret_buf: Option<Vec<u8>> = None;
382
383 while ret_buf.is_none() {
386 let new_state: Option<ReadTcpState> = match read_state {
389 ReadTcpState::LenBytes {
390 ref mut pos,
391 ref mut bytes,
392 } => {
393 let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
395 if read == 0 {
396 debug!("zero bytes read, stream closed?");
398 if *pos == 0 {
401 return Poll::Ready(None);
403 } else {
404 return Poll::Ready(Some(Err(io::Error::new(
405 io::ErrorKind::BrokenPipe,
406 "closed while reading length",
407 ))));
408 }
409 }
410 debug!("in ReadTcpState::LenBytes: {}", pos);
411 *pos += read;
412
413 if *pos < bytes.len() {
414 debug!("remain ReadTcpState::LenBytes: {}", pos);
415 None
416 } else {
417 let length = u16::from_be_bytes(*bytes);
418 debug!("got length: {}", length);
419 let mut bytes = vec![0; length as usize];
420 bytes.resize(length as usize, 0);
421
422 debug!("move ReadTcpState::Bytes: {}", bytes.len());
423 Some(ReadTcpState::Bytes { pos: 0, bytes })
424 }
425 }
426 ReadTcpState::Bytes {
427 ref mut pos,
428 ref mut bytes,
429 } => {
430 let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
431 if read == 0 {
432 debug!("zero bytes read for message, stream closed?");
434
435 return Poll::Ready(Some(Err(io::Error::new(
438 io::ErrorKind::BrokenPipe,
439 "closed while reading message",
440 ))));
441 }
442
443 debug!("in ReadTcpState::Bytes: {}", bytes.len());
444 *pos += read;
445
446 if *pos < bytes.len() {
447 debug!("remain ReadTcpState::Bytes: {}", bytes.len());
448 None
449 } else {
450 debug!("reset ReadTcpState::LenBytes: {}", 0);
451 Some(ReadTcpState::LenBytes {
452 pos: 0,
453 bytes: [0u8; 2],
454 })
455 }
456 }
457 };
458
459 if let Some(state) = new_state {
462 if let ReadTcpState::Bytes { pos, bytes } = mem::replace(read_state, state) {
463 debug!("returning bytes");
464 assert_eq!(pos, bytes.len());
465 ret_buf = Some(bytes);
466 }
467 }
468 }
469
470 if let Some(buffer) = ret_buf {
472 debug!("returning buffer");
473 let src_addr = self.peer_addr;
474 Poll::Ready(Some(Ok(SerialMessage::new(buffer, src_addr))))
475 } else {
476 debug!("bottomed out");
477 Poll::Pending
480 }
481 }
482}
483
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use crate::iocompat::AsyncIoTokioAsStd;
    use crate::tests::tcp_stream_test;

    /// Fresh Tokio runtime for driving one test case.
    fn runtime() -> Runtime {
        Runtime::new().expect("failed to create tokio runtime")
    }

    #[test]
    fn test_tcp_stream_ipv4() {
        tcp_stream_test::<AsyncIoTokioAsStd<TokioTcpStream>, Runtime>(
            IpAddr::V4(Ipv4Addr::LOCALHOST),
            runtime(),
        )
    }

    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_tcp_stream_ipv6() {
        tcp_stream_test::<AsyncIoTokioAsStd<TokioTcpStream>, Runtime>(
            IpAddr::V6(Ipv6Addr::LOCALHOST),
            runtime(),
        )
    }
}