use std::fmt::{self, Display};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
#[cfg(feature = "tokio-runtime")]
use async_trait::async_trait;
use futures_util::{future::Future, stream::Stream, StreamExt, TryFutureExt};
use tracing::warn;
use crate::error::ProtoError;
#[cfg(feature = "tokio-runtime")]
use crate::iocompat::AsyncIoTokioAsStd;
use crate::tcp::{Connect, DnsTcpStream, TcpStream};
use crate::xfer::{DnsClientStream, SerialMessage};
use crate::BufDnsStreamHandle;
#[cfg(feature = "tokio-runtime")]
use crate::TokioTime;
/// A DNS client stream carried over a TCP connection.
///
/// Wraps a [`TcpStream`] and yields the [`SerialMessage`]s read from it.
/// Instances are obtained by awaiting a [`TcpClientConnect`] or by wrapping an
/// existing stream with `TcpClientStream::from_stream`.
// This type is a `Stream`, not a `Future`; the lint message says so.
#[must_use = "streams do nothing unless polled"]
pub struct TcpClientStream<S>
where
    S: DnsTcpStream,
{
    tcp_stream: TcpStream<S>,
}
impl<S: Connect> TcpClientStream<S> {
    /// Starts connecting to `name_server` with a 5 second timeout.
    ///
    /// Returns the pending connection together with the handle created for it.
    // `new` intentionally returns a connect future plus handle instead of `Self`.
    #[allow(clippy::new_ret_no_self)]
    pub fn new(name_server: SocketAddr) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
        Self::with_timeout(name_server, DEFAULT_TIMEOUT)
    }

    /// Starts connecting to `name_server`, giving up after `timeout`.
    ///
    /// No local bind address is used.
    pub fn with_timeout(
        name_server: SocketAddr,
        timeout: Duration,
    ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
        let bind_addr: Option<SocketAddr> = None;
        Self::with_bind_addr_and_timeout(name_server, bind_addr, timeout)
    }

    /// Starts connecting to `name_server`, optionally from `bind_addr`, giving up
    /// after `timeout`.
    ///
    /// Connection errors from the underlying stream future are converted into
    /// [`ProtoError`].
    pub fn with_bind_addr_and_timeout(
        name_server: SocketAddr,
        bind_addr: Option<SocketAddr>,
        timeout: Duration,
    ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
        let (connecting, handle) =
            TcpStream::<S>::with_bind_addr_and_timeout(name_server, bind_addr, timeout);
        let connected = connecting
            .map_ok(|stream| Self { tcp_stream: stream })
            .map_err(ProtoError::from);
        (TcpClientConnect(Box::pin(connected)), handle)
    }
}
impl<S: DnsTcpStream> TcpClientStream<S> {
    /// Wraps an already-established [`TcpStream`] as a client stream.
    pub fn from_stream(tcp_stream: TcpStream<S>) -> Self {
        Self { tcp_stream }
    }

    /// Builds a client stream from a caller-supplied connection `future`.
    ///
    /// `name_server` and `timeout` are forwarded to `TcpStream::with_future`.
    /// Errors from the resulting stream future are converted into [`ProtoError`].
    pub fn with_future<F: Future<Output = io::Result<S>> + Send + 'static>(
        future: F,
        name_server: SocketAddr,
        timeout: Duration,
    ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
        let (connecting, handle) = TcpStream::<S>::with_future(future, name_server, timeout);
        let connected = connecting
            .map_ok(|stream| Self { tcp_stream: stream })
            .map_err(ProtoError::from);
        (TcpClientConnect(Box::pin(connected)), handle)
    }
}
impl<S: DnsTcpStream> Display for TcpClientStream<S> {
    /// Formats the stream as `TCP(<peer address>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let peer = self.tcp_stream.peer_addr();
        write!(f, "TCP({})", peer)
    }
}
impl<S: DnsTcpStream> DnsClientStream for TcpClientStream<S> {
    type Time = S::Time;

    /// Returns the peer address of the wrapped TCP stream.
    fn name_server_addr(&self) -> SocketAddr {
        let Self { tcp_stream } = self;
        tcp_stream.peer_addr()
    }
}
impl<S: DnsTcpStream> Stream for TcpClientStream<S> {
    type Item = Result<SerialMessage, ProtoError>;

    /// Yields the next message read from the underlying TCP stream.
    ///
    /// A message whose address differs from the stream's peer address is
    /// still yielded, but a warning is logged first.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `try_ready_stream!` is a crate macro; presumably it returns early on
        // pending / end-of-stream / error and otherwise yields the message — confirm.
        let received = try_ready_stream!(self.tcp_stream.poll_next_unpin(cx));
        let name_server = self.tcp_stream.peer_addr();

        if received.addr() != name_server {
            warn!("{} does not match name_server: {}", received.addr(), name_server);
        }

        Poll::Ready(Some(Ok(received)))
    }
}
/// A future that resolves to a connected [`TcpClientStream`].
///
/// Returned by the `TcpClientStream` constructors alongside a
/// `BufDnsStreamHandle`. Nothing happens until it is polled, so dropping it
/// unawaited abandons the connection attempt; `#[must_use]` flags that mistake.
#[must_use = "futures do nothing unless polled"]
pub struct TcpClientConnect<S: DnsTcpStream>(
    Pin<Box<dyn Future<Output = Result<TcpClientStream<S>, ProtoError>> + Send + 'static>>,
);
impl<S: DnsTcpStream> Future for TcpClientConnect<S> {
    type Output = Result<TcpClientStream<S>, ProtoError>;

    /// Polls the boxed connection future it wraps.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The only field is a `Pin<Box<..>>`, which is `Unpin`, so unpinning `self` is sound.
        let Self(inner) = self.get_mut();
        inner.as_mut().poll(cx)
    }
}
#[cfg(feature = "tokio-runtime")]
use tokio::net::TcpStream as TokioTcpStream;
#[cfg(feature = "tokio-runtime")]
impl<T> DnsTcpStream for AsyncIoTokioAsStd<T>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Sync + Unpin + Sized + 'static,
{
    /// Tokio-wrapped streams use `TokioTime` as their time type.
    type Time = TokioTime;
}
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl Connect for AsyncIoTokioAsStd<TokioTcpStream> {
    /// Opens a Tokio TCP connection to `addr` (optionally bound to `bind_addr`)
    /// and wraps it in `AsyncIoTokioAsStd`.
    async fn connect_with_bind(
        addr: SocketAddr,
        bind_addr: Option<SocketAddr>,
    ) -> io::Result<Self> {
        let tokio_stream = super::tokio::connect_with_bind(&addr, &bind_addr).await?;
        Ok(AsyncIoTokioAsStd(tokio_stream))
    }
}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use super::AsyncIoTokioAsStd;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;
    use crate::tests::tcp_client_stream_test;

    /// Runs the shared TCP client stream test against `addr` on a fresh Tokio runtime.
    fn run_stream_test(addr: IpAddr) {
        let runtime = Runtime::new().expect("failed to create tokio runtime");
        tcp_client_stream_test::<AsyncIoTokioAsStd<TokioTcpStream>, Runtime>(addr, runtime)
    }

    #[test]
    fn test_tcp_stream_ipv4() {
        run_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST))
    }

    #[test]
    fn test_tcp_stream_ipv6() {
        run_stream_test(IpAddr::V6(Ipv6Addr::LOCALHOST))
    }
}