1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
use crate::{ReplyError, Result};
use tokio::io::ErrorKind as IOErrorKind;
use tokio::time::timeout;
use tokio::net::{TcpStream, ToSocketAddrs};
use std::time::Duration;

/// Fills a buffer from a stream and yields it, so the bytes can be
/// destructured into named fields in a single expression.
///
/// `$array` is the initial buffer value (typically a fixed-size array such as
/// `[0u8; 2]`). The macro awaits `$stream.read_exact(..)` on that buffer, so it
/// must be used inside an async context where `$stream` provides a
/// `read_exact` method. It evaluates to the awaited call's result with the
/// success value replaced by the filled buffer; an error is passed through
/// untouched, which is why callers normally follow the macro with `?`.
///
/// # Examples (before)
///
/// ```ignore
/// let mut buf = [0u8; 2];
/// stream.read_exact(&mut buf).await?;
/// let [version, method_len] = buf;
///
/// assert_eq!(version, 0x05);
/// ```
///
/// # Examples (after)
///
/// ```ignore
/// let [version, method_len] = read_exact!(stream, [0u8; 2])?;
///
/// assert_eq!(version, 0x05);
/// ```
#[macro_export]
macro_rules! read_exact {
    ($stream: expr, $array: expr) => {{
        let mut buffer = $array;
        // Only the buffer contents matter to callers, so the value reported by
        // a successful `read_exact` is discarded in favour of the buffer.
        let outcome = $stream.read_exact(&mut buffer).await;
        outcome.map(|_| buffer)
    }};
}

/// Extracts the value from a `std::task::Poll`, or propagates `Pending`.
///
/// Evaluates `$poll` once. If it is `Poll::Ready(value)`, the macro evaluates
/// to `value`. If it is `Poll::Pending`, the macro returns `Poll::Pending`
/// from the enclosing function, so it can only be used where that is a valid
/// return value. A trailing comma after the expression is accepted.
#[macro_export]
macro_rules! ready {
    ($poll:expr $(,)?) => {
        match $poll {
            // Not ready yet: bail out of the surrounding function.
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}

/// Opens a TCP connection to `addr`, giving up after `request_timeout_s`
/// seconds.
///
/// The connection attempt itself is delegated to [`tcp_connect`].
///
/// # Errors
///
/// Returns `ReplyError::ConnectionTimeout` (converted into the crate error
/// type) when the time limit elapses before the attempt finishes; otherwise
/// returns whatever [`tcp_connect`] produced.
pub async fn tcp_connect_with_timeout<T>(addr: T, request_timeout_s: u64) -> Result<TcpStream>
where
    T: ToSocketAddrs,
{
    let limit = Duration::from_secs(request_timeout_s);
    timeout(limit, tcp_connect(addr))
        .await
        // The outer error only signals that the deadline passed.
        .unwrap_or_else(|_elapsed| Err(ReplyError::ConnectionTimeout.into()))
}

/// Opens a TCP connection to `addr`, translating selected I/O failures into
/// `ReplyError` values.
///
/// # Errors
///
/// The I/O error kind of a failed connect is mapped as follows before being
/// converted into the crate error type:
///
/// * `ConnectionRefused` -> `ReplyError::ConnectionRefused`
/// * `ConnectionAborted` or `ConnectionReset` -> `ReplyError::ConnectionNotAllowed`
/// * `NotConnected` -> `ReplyError::NetworkUnreachable`
///
/// Any other I/O error is converted into the crate error type as-is.
pub async fn tcp_connect<T>(addr: T) -> Result<TcpStream>
where
    T: ToSocketAddrs,
{
    let io_err = match TcpStream::connect(addr).await {
        Ok(stream) => return Ok(stream),
        Err(io_err) => io_err,
    };

    // TODO (from the original author): match other TCP errors with ReplyError.
    let reply = match io_err.kind() {
        IOErrorKind::ConnectionRefused => ReplyError::ConnectionRefused,
        IOErrorKind::ConnectionAborted | IOErrorKind::ConnectionReset => {
            ReplyError::ConnectionNotAllowed
        }
        IOErrorKind::NotConnected => ReplyError::NetworkUnreachable,
        // NOTE(review): the original author wondered whether this fallback
        // should instead map to a "General failure" reply.
        _ => return Err(io_err.into()),
    };

    Err(reply.into())
}