1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
use std::io::{Error, ErrorKind, Result};
use std::net::SocketAddr;

use super::SocketAddrIterator;
use crate::error::LunaticError;
use crate::host;
use crate::net::TlsStream;

/// A TLS server, listening for connections.
///
/// After creating a [`TlsListener`] by [`bind`][`TlsListener::bind()`]ing it to
/// an address, it listens for incoming encrypted TCP (TLS) connections. These
/// can be accepted by calling [`accept()`][`TlsListener::accept()`].
///
/// The underlying Transmission Control Protocol is specified in
/// [IETF RFC 793].
///
/// [IETF RFC 793]: https://tools.ietf.org/html/rfc793
///
/// # Examples
///
/// ```no_run
/// use lunatic::{net, Mailbox, Process};
/// use std::io::{BufRead, BufReader, Write};
///
/// fn main() {
///     // `bind` takes the certificate and key material as raw bytes in
///     // addition to the address. The paths below are placeholders; use
///     // whatever encoding the host runtime expects.
///     let certs = std::fs::read("path/to/certs").unwrap();
///     let keys = std::fs::read("path/to/keys").unwrap();
///     let listener = net::TlsListener::bind("127.0.0.1:0", certs, keys).unwrap();
///     while let Ok((tls_stream, _peer)) = listener.accept() {
///         // Handle connections in a new process
///         Process::spawn(tls_stream, handle);
///     }
/// }
///
/// fn handle(mut tls_stream: net::TlsStream, _: Mailbox<()>) {
///     let mut buf_reader = BufReader::new(tls_stream.clone());
///     loop {
///         let mut buffer = String::new();
///         let read = buf_reader.read_line(&mut buffer).unwrap();
///         if buffer.contains("exit") || read == 0 {
///             return;
///         }
///         tls_stream.write(buffer.as_bytes()).unwrap();
///     }
/// }
/// ```
#[derive(Debug)]
pub struct TlsListener {
    // Identifier handed to the host networking API for every operation on
    // this listener (accepting, querying the local address, dropping).
    id: u64,
}

/// Tells the host to drop the listener when the guest-side handle goes out of
/// scope.
impl Drop for TlsListener {
    fn drop(&mut self) {
        let listener_id = self.id;
        // SAFETY: only a plain integer id crosses the host boundary here; no
        // guest memory is shared with the host for this call.
        unsafe {
            host::api::networking::drop_tls_listener(listener_id);
        }
    }
}

impl TlsListener {
    /// Creates a new [`TlsListener`] bound to the given address.
    ///
    /// Binding with a port number of 0 will request that the operating system
    /// assigns an available port to this listener.
    ///
    /// If `addr` yields multiple addresses, binding will be attempted with each
    /// of the addresses until one succeeds and returns the listener. If
    /// none of the addresses succeed in creating a listener, the error from
    /// the last attempt is returned.
    pub fn bind<A>(addr: A, certs: Vec<u8>, keys: Vec<u8>) -> Result<Self>
    where
        A: super::ToSocketAddrs,
    {
        let mut id = 0;
        for addr in addr.to_socket_addrs()? {
            let result = match addr {
                SocketAddr::V4(v4_addr) => {
                    let ip = v4_addr.ip().octets();
                    let port = v4_addr.port() as u32;
                    unsafe {
                        host::api::networking::tls_bind(
                            4,
                            ip.as_ptr(),
                            port,
                            0,
                            0,
                            &mut id as *mut u64,
                            certs.as_ptr() as *const u32,
                            certs.len(),
                            keys.as_ptr() as *const u32,
                            keys.len(),
                        )
                    }
                }
                SocketAddr::V6(v6_addr) => {
                    let ip = v6_addr.ip().octets();
                    let port = v6_addr.port() as u32;
                    let flow_info = v6_addr.flowinfo();
                    let scope_id = v6_addr.scope_id();
                    unsafe {
                        host::api::networking::tls_bind(
                            6,
                            ip.as_ptr(),
                            port,
                            flow_info,
                            scope_id,
                            &mut id as *mut u64,
                            certs.as_ptr() as *const u32,
                            certs.len(),
                            keys.as_ptr() as *const u32,
                            keys.len(),
                        )
                    }
                }
            };
            if result == 0 {
                return Ok(Self { id });
            }
        }
        let lunatic_error = LunaticError::Error(id);
        Err(Error::new(ErrorKind::Other, lunatic_error))
    }

    /// Accepts a new incoming connection.
    ///
    /// This will block and typically needs its own dedicated child process
    /// loop.
    ///
    /// Returns a TLS stream and the peer address.
    pub fn accept(&self) -> Result<(TlsStream, SocketAddr)> {
        let mut tls_stream_or_error_id = 0;
        let mut dns_iter_id = 0;
        let result = unsafe {
            host::api::networking::tls_accept(
                self.id,
                &mut tls_stream_or_error_id as *mut u64,
                &mut dns_iter_id as *mut u64,
            )
        };
        if result == 0 {
            let tls_stream = TlsStream::from(tls_stream_or_error_id);
            let mut dns_iter = SocketAddrIterator::from(dns_iter_id);
            let peer = dns_iter.next().expect("must contain one element");
            Ok((tls_stream, peer))
        } else {
            let lunatic_error = LunaticError::Error(tls_stream_or_error_id);
            Err(Error::new(ErrorKind::Other, lunatic_error))
        }
    }

    /// Returns the local address that this listener is bound to.
    ///
    /// This can be useful, for example, to identify when binding to port 0
    /// which port was assigned by the OS.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        let mut dns_iter_or_error_id = 0;
        let result = unsafe {
            host::api::networking::tls_local_addr(self.id, &mut dns_iter_or_error_id as *mut u64)
        };
        if result == 0 {
            let mut dns_iter = SocketAddrIterator::from(dns_iter_or_error_id);
            let addr = dns_iter.next().expect("must contain one element");
            Ok(addr)
        } else {
            let lunatic_error = LunaticError::Error(dns_iter_or_error_id);
            Err(Error::new(ErrorKind::Other, lunatic_error))
        }
    }
}