1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright (C) 2021 Quickwit, Inc.
//
// Quickwit is offered under the AGPL v3.0 and as commercial software.
// For commercial licensing, contact us at hello@quickwit.io.
//
// AGPL:
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

use std::net::{SocketAddr, TcpListener, ToSocketAddrs};

/// Finds a random available TCP port.
///
/// Binds a listener to port 0 on the IPv4 loopback address, which asks the OS to assign a
/// free port, and reports the port that was assigned. The listener is dropped before this
/// function returns, so the port is only known to have been free at the time of the call.
pub fn find_available_tcp_port() -> anyhow::Result<u16> {
    let loopback_any_port = SocketAddr::from(([127, 0, 0, 1], 0u16));
    let assigned_port = TcpListener::bind(loopback_any_port)?.local_addr()?.port();
    Ok(assigned_port)
}

/// Resolves `addr` into a single `SocketAddr`.
///
/// Only the first address yielded by `ToSocketAddrs::to_socket_addrs` is kept. When a
/// hostname resolves to several addresses, the remaining ones are discarded.
///
/// # Errors
///
/// Returns an error if resolution itself fails, or if it succeeds but yields no address.
pub fn get_socket_addr<T: ToSocketAddrs + std::fmt::Debug>(addr: &T) -> anyhow::Result<SocketAddr> {
    let mut resolved_addrs = addr.to_socket_addrs()?;
    match resolved_addrs.next() {
        Some(first_addr) => Ok(first_addr),
        None => Err(anyhow::anyhow!("Failed to resolve address `{:?}`.", addr)),
    }
}

/// Returns true if `addr` looks like a socket address carrying an explicit port.
///
/// Recognized shapes:
/// - `[IPv6]:port`: the text after the closing bracket must start with `:`.
/// - `host:port`: the part before the last colon contains no other colon. This is how IPv4
///   addresses and hostnames carry a port.
///
/// A bare IPv6 address (several colons, no brackets, e.g. `::1`) is reported as having no
/// port. This is only a syntactic heuristic: if the address is malformed to begin with, the
/// result may be either `true` or `false`, and the subsequent resolution step must reject it.
///
/// Never panics, including on empty or non-ASCII input.
fn contains_port(addr: &str) -> bool {
    // `[IPv6]:port`: whatever follows the closing bracket must begin with the port separator.
    if let Some((_, after_bracket)) = addr.rsplit_once(']') {
        return after_bracket.starts_with(':');
    }
    // `host:port`: if the part before the last colon still contains a colon, this is a bare
    // IPv6 address rather than a host followed by a port.
    if let Some((host, _port)) = addr.rsplit_once(':') {
        return !host.contains(':');
    }
    false
}

/// Parses a socket address, falling back to `default_port` when `addr` carries no port.
///
/// Supported forms:
/// - `IPv4` and `IPv4:port`
/// - `IPv6` (bare, without brackets)
/// - `[IPv6]` and `[IPv6]:port`: an IPv6 address contains colons, so it is customary to
///   require brackets to attach a port to it unambiguously.
/// - `hostname` and `hostname:port`
///
/// Because the result is a resolved `SocketAddr`, a hostname is resolved via DNS once, at
/// call time. If it resolves to several addresses, only the first one is kept.
///
/// # Errors
///
/// Returns an error if the address cannot be parsed or resolved.
pub fn parse_socket_addr_with_default_port(
    addr: &str,
    default_port: u16,
) -> anyhow::Result<SocketAddr> {
    if contains_port(addr) {
        return get_socket_addr(&addr);
    }
    // The `(&str, u16)` form only accepts a plain IP address or a hostname. It does not
    // accept the bracketed `[IPv6]` notation, so unwrap the brackets when they enclose a
    // valid IPv6 address. Anything else is passed through untouched.
    let host = addr
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .filter(|inner| inner.parse::<std::net::Ipv6Addr>().is_ok())
        .unwrap_or(addr);
    get_socket_addr(&(host, default_port))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `addr` with a default port of 1337 and checks the outcome.
    ///
    /// `expected_opt` is `Some(socket_addr)` when parsing must succeed and yield exactly that
    /// address, or `None` when parsing must fail.
    fn test_parse_socket_addr_helper(addr: &str, expected_opt: Option<&str>) {
        let socket_addr_res = parse_socket_addr_with_default_port(addr, 1337);
        match (expected_opt, socket_addr_res) {
            (Some(expected), Ok(socket_addr)) => {
                let expected_socket_addr: SocketAddr = expected
                    .parse()
                    .expect("the expected value in a test case must be a valid socket address");
                assert_eq!(socket_addr, expected_socket_addr);
            }
            (Some(_), Err(error)) => {
                panic!("Parsing `{}` was expected to succeed: {:?}", addr, error);
            }
            (None, Ok(socket_addr)) => {
                panic!(
                    "Parsing `{}` was expected to fail, got `{}`",
                    addr, socket_addr
                );
            }
            (None, Err(_)) => {}
        }
    }

    #[test]
    fn test_parse_socket_addr_with_ips() {
        test_parse_socket_addr_helper("127.0.0.1", Some("127.0.0.1:1337"));
        test_parse_socket_addr_helper("127.0.0.1:100", Some("127.0.0.1:100"));
        test_parse_socket_addr_helper("127.0..1:100", None);
        test_parse_socket_addr_helper(
            "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
            Some("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1337"),
        );
        test_parse_socket_addr_helper("2001:0db8:85a3:0000:0000:8a2e:0370:7334:1000", None);
        test_parse_socket_addr_helper(
            "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1000",
            Some("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1000"),
        );
        test_parse_socket_addr_helper("[2001:0db8:1000", None);
        test_parse_socket_addr_helper("2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1000", None);
    }

    // This test requires DNS resolution, and therefore network access.
    #[test]
    fn test_parse_socket_addr_with_resolution() {
        let socket_addr = parse_socket_addr_with_default_port("google.com:1000", 1337).unwrap();
        assert_eq!(socket_addr.port(), 1000);
        let socket_addr = parse_socket_addr_with_default_port("google.com", 1337).unwrap();
        assert_eq!(socket_addr.port(), 1337);
    }
}