1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
use super::network::SocketAddressFamily;
use super::{HostInputStream, HostOutputStream, StreamError};
use crate::preview2::host::network::util;
use crate::preview2::{
    with_ambient_tokio_runtime, AbortOnDropJoinHandle, InputStream, OutputStream, Subscribe,
};
use anyhow::{Error, Result};
use cap_net_ext::{AddressFamily, Blocking};
use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike};
use rustix::net::sockopt;
use std::io;
use std::mem;
use std::sync::Arc;
use tokio::io::Interest;

/// The state of a TCP socket.
///
/// This represents the various states a socket can be in during the
/// activities of binding, listening, accepting, and connecting.
pub(crate) enum TcpState {
    /// The initial state for a newly-created socket.
    Default,

    /// Binding started via `start_bind`.
    ///
    /// `Subscribe::ready` on the socket returns immediately in this state.
    BindStarted,

    /// Binding finished via `finish_bind`. The socket has an address but
    /// is not yet listening for connections.
    Bound,

    /// Listening started via `start_listen`.
    ///
    /// `Subscribe::ready` on the socket returns immediately in this state.
    ListenStarted,

    /// The socket is now listening and waiting for an incoming connection.
    Listening,

    /// An outgoing connection is started via `start_connect`.
    Connecting,

    /// An outgoing connection is ready to be established.
    ///
    /// `Subscribe::ready` on the socket returns immediately in this state.
    ConnectReady,

    /// An outgoing connection was attempted but failed.
    ConnectFailed,

    /// An outgoing connection has been established.
    Connected,
}

/// A host TCP socket together with the state WASI tracks for it.
///
/// The tokio stream is held behind an `Arc` so that the input and output
/// stream objects created for this socket can share the same host socket.
pub struct TcpSocket {
    /// Reference-counted handle to the host socket; cloned into async tasks
    /// and into the read/write stream halves.
    pub(crate) inner: Arc<tokio::net::TcpStream>,

    /// Where the socket currently is in the bind/listen/accept/connect lifecycle.
    pub(crate) tcp_state: TcpState,

    /// Requested listen queue length; `None` selects the system default.
    pub(crate) listen_backlog_size: Option<i32>,

    /// Address family this socket was created with.
    pub(crate) family: SocketAddressFamily,

    // Not every platform copies the options below from a listening socket to
    // the sockets it accepts. Values that were explicitly set are therefore
    // recorded here so they can be reapplied to each accepted client.
    #[cfg(target_os = "macos")]
    pub(crate) receive_buffer_size: Option<usize>,
    #[cfg(target_os = "macos")]
    pub(crate) send_buffer_size: Option<usize>,
    #[cfg(target_os = "macos")]
    pub(crate) hop_limit: Option<u8>,
    #[cfg(target_os = "macos")]
    pub(crate) keep_alive_idle_time: Option<std::time::Duration>,
}

/// Input half of a TCP socket that reads from the shared tokio stream.
pub(crate) struct TcpReadStream {
    /// Host socket, shared with the owning `TcpSocket` and the write half.
    stream: Arc<tokio::net::TcpStream>,
    /// Becomes `true` once end-of-stream or a read error has been observed.
    closed: bool,
}

impl TcpReadStream {
    /// Builds a read half over `stream` that starts out open.
    fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
        let closed = false;
        Self { stream, closed }
    }
}

#[async_trait::async_trait]
impl HostInputStream for TcpReadStream {
    /// Performs a non-blocking read of at most `size` bytes.
    ///
    /// An empty buffer comes back in three cases: `size` is zero, no data is
    /// available yet, or this read is the first to observe end-of-stream.
    /// After end-of-stream or any other I/O error, every later call returns
    /// `StreamError::Closed`.
    fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
        if self.closed {
            return Err(StreamError::Closed);
        }
        if size == 0 {
            return Ok(bytes::Bytes::new());
        }

        let mut chunk = bytes::BytesMut::with_capacity(size);
        let received = match self.stream.try_read_buf(&mut chunk) {
            // `WouldBlock` means "nothing to read right now", which is distinct
            // from the peer having closed the connection.
            Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => 0,
            Err(err) => {
                self.closed = true;
                return Err(StreamError::LastOperationFailed(err.into()));
            }
            Ok(count) => {
                // Reading zero bytes into a non-empty buffer signals end-of-stream.
                if count == 0 {
                    self.closed = true;
                }
                count
            }
        };

        chunk.truncate(received);
        Ok(chunk.freeze())
    }
}

#[async_trait::async_trait]
impl Subscribe for TcpReadStream {
    /// Waits until the socket is readable.
    ///
    /// A closed stream is reported ready at once, so the next `read` can
    /// return `StreamError::Closed`.
    async fn ready(&mut self) {
        if !self.closed {
            self.stream.readable().await.unwrap();
        }
    }
}

/// Byte budget that `check_write` reports when the socket is writable (1 GiB).
const SOCKET_READY_SIZE: usize = 1 << 30;

/// Output half of a TCP socket that writes to the shared tokio stream.
///
/// Writes that cannot finish immediately are completed by a background task,
/// whose progress is tracked in `last_write`.
pub(crate) struct TcpWriteStream {
    /// Host socket, shared with the owning `TcpSocket` and the read half.
    stream: Arc<tokio::net::TcpStream>,
    /// Status of the most recent write.
    last_write: LastWrite,
}

/// Status of the most recent write issued through a `TcpWriteStream`.
enum LastWrite {
    /// A background task is still writing the remainder of a previous write.
    Waiting(AbortOnDropJoinHandle<Result<()>>),
    /// The background write failed, and the error has not been reported to the
    /// caller yet. The next `check_write` reports it.
    Error(Error),
    /// A write failed and its error has already been reported. The stream is no
    /// longer usable: every further operation returns `StreamError::Closed`.
    /// This prevents new data from being written after a gap of bytes that never
    /// reached the peer.
    Closed,
    /// No write is outstanding; the stream may accept more data.
    Done,
}

impl TcpWriteStream {
    /// Wraps the shared socket in an output stream with no write outstanding.
    pub(crate) fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
        Self {
            stream,
            last_write: LastWrite::Done,
        }
    }

    /// Writes `bytes` in a background task.
    ///
    /// The task handle is kept so that a later `ready` can join it and
    /// `check_write` can report its outcome.
    ///
    /// # Panics
    ///
    /// Panics if a write is already outstanding, i.e. `last_write` is not `Done`.
    fn background_write(&mut self, mut bytes: bytes::Bytes) {
        assert!(matches!(self.last_write, LastWrite::Done));

        let stream = self.stream.clone();
        self.last_write = LastWrite::Waiting(crate::preview2::spawn(async move {
            // Note: we are not using the AsyncWrite impl here, and instead using the TcpStream
            // primitive try_write, which goes directly to attempt a write with mio. This has
            // two advantages: 1. this operation takes a &TcpStream instead of a &mut TcpStream
            // required to AsyncWrite, and 2. it eliminates any buffering in tokio we may need
            // to flush.
            while !bytes.is_empty() {
                stream.writable().await?;
                match stream.try_write(&bytes) {
                    Ok(n) => {
                        let _ = bytes.split_to(n);
                    }
                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
                    Err(e) => return Err(e.into()),
                }
            }

            Ok(())
        }));
    }
}

impl HostOutputStream for TcpWriteStream {
    /// Writes `bytes` to the socket without blocking.
    ///
    /// Any portion that cannot be written immediately is handed to a background
    /// task. Callers must then go through `ready` and `check_write` before writing
    /// again. A hard I/O error is returned as `LastOperationFailed` and closes the
    /// stream; later operations return `StreamError::Closed`.
    fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
        match self.last_write {
            LastWrite::Done => {}
            LastWrite::Closed => return Err(StreamError::Closed),
            LastWrite::Waiting(_) | LastWrite::Error(_) => {
                return Err(StreamError::Trap(anyhow::anyhow!(
                    "unpermitted: must call check_write first"
                )));
            }
        }
        while !bytes.is_empty() {
            match self.stream.try_write(&bytes) {
                Ok(n) => {
                    let _ = bytes.split_to(n);
                }

                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    // As `try_write` indicated that it would have blocked, we'll perform the write
                    // in the background to allow us to return immediately.
                    self.background_write(bytes);

                    return Ok(());
                }

                Err(e) => {
                    // Some of the caller's bytes may already be on the wire. Close the stream
                    // so that no later write can silently continue after the lost data.
                    self.last_write = LastWrite::Closed;
                    return Err(StreamError::LastOperationFailed(e.into()));
                }
            }
        }

        Ok(())
    }

    /// No-op apart from reporting a closed stream.
    ///
    /// There is no internal buffer to flush. `ready` joins any active background
    /// write, so following `flush` with `ready` has the desired effect.
    fn flush(&mut self) -> Result<(), StreamError> {
        match self.last_write {
            LastWrite::Closed => Err(StreamError::Closed),
            LastWrite::Waiting(_) | LastWrite::Error(_) | LastWrite::Done => Ok(()),
        }
    }

    /// Returns how many bytes may be written right now.
    ///
    /// - Returns 0 while a background write is pending or the socket is not writable.
    /// - Reports a failed background write exactly once as `LastOperationFailed`;
    ///   after that the stream stays closed.
    fn check_write(&mut self) -> Result<usize, StreamError> {
        // Take the current state out. On every error path it is left as `Closed`.
        match mem::replace(&mut self.last_write, LastWrite::Closed) {
            LastWrite::Waiting(task) => {
                self.last_write = LastWrite::Waiting(task);
                return Ok(0);
            }
            LastWrite::Done => self.last_write = LastWrite::Done,
            LastWrite::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
            LastWrite::Closed => return Err(StreamError::Closed),
        }

        let writable = self.stream.writable();
        futures::pin_mut!(writable);
        if super::poll_noop(writable).is_none() {
            return Ok(0);
        }
        Ok(SOCKET_READY_SIZE)
    }
}

#[async_trait::async_trait]
impl Subscribe for TcpWriteStream {
    async fn ready(&mut self) {
        if let LastWrite::Waiting(task) = &mut self.last_write {
            self.last_write = match task.await {
                Ok(()) => LastWrite::Done,
                Err(e) => LastWrite::Error(e),
            };
        }
        if let LastWrite::Done = self.last_write {
            self.stream.writable().await.unwrap();
        }
    }
}

impl TcpSocket {
    /// Creates a fresh, non-blocking TCP socket in `family`.
    ///
    /// IPv6 sockets have `IPV6_V6ONLY` enabled before they are wrapped.
    pub fn new(family: AddressFamily) -> io::Result<Self> {
        // The async implementation relies on the host socket being non-blocking.
        let socket = util::tcp_socket(family, Blocking::No)?;

        let wasi_family = match family {
            AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
            AddressFamily::Ipv6 => {
                sockopt::set_ipv6_v6only(&socket, true)?;
                SocketAddressFamily::Ipv6
            }
        };

        Self::from_fd(socket, wasi_family)
    }

    /// Wraps an existing socket in a `TcpSocket` in the `Default` state.
    ///
    /// The socket must be in non-blocking mode.
    pub(crate) fn from_fd(
        fd: rustix::fd::OwnedFd,
        family: SocketAddressFamily,
    ) -> io::Result<Self> {
        Self::setup_tokio_tcp_stream(fd).map(|stream| Self {
            inner: Arc::new(stream),
            tcp_state: TcpState::Default,
            listen_backlog_size: None,
            family,
            #[cfg(target_os = "macos")]
            receive_buffer_size: None,
            #[cfg(target_os = "macos")]
            send_buffer_size: None,
            #[cfg(target_os = "macos")]
            hop_limit: None,
            #[cfg(target_os = "macos")]
            keep_alive_idle_time: None,
        })
    }

    /// Converts an owned socket descriptor into a tokio `TcpStream`.
    ///
    /// The stream is registered with the ambient tokio runtime.
    fn setup_tokio_tcp_stream(fd: rustix::fd::OwnedFd) -> io::Result<tokio::net::TcpStream> {
        let raw = fd.into_raw_socketlike();
        // SAFETY: `raw` was just released from an `OwnedFd`, so it refers to an open socket,
        // and ownership passes to the new `std::net::TcpStream` exactly once.
        // NOTE(review): callers are expected to pass a TCP socket; confirm at call sites.
        let std_stream = unsafe { std::net::TcpStream::from_raw_socketlike(raw) };
        with_ambient_tokio_runtime(|| tokio::net::TcpStream::try_from(std_stream))
    }

    /// Borrows the underlying tokio stream.
    pub fn tcp_socket(&self) -> &tokio::net::TcpStream {
        &*self.inner
    }

    /// Creates the input/output stream pair for this socket; both share the host socket.
    pub fn as_split(&self) -> (InputStream, OutputStream) {
        let reader = TcpReadStream::new(Arc::clone(&self.inner));
        let writer = TcpWriteStream::new(Arc::clone(&self.inner));
        (InputStream::Host(Box::new(reader)), Box::new(writer))
    }
}

#[async_trait::async_trait]
impl Subscribe for TcpSocket {
    /// Waits until the socket is readable or writable.
    ///
    /// The in-progress states `BindStarted`, `ListenStarted` and `ConnectReady`
    /// are reported ready without waiting on the socket.
    async fn ready(&mut self) {
        let immediately_ready = matches!(
            self.tcp_state,
            TcpState::BindStarted | TcpState::ListenStarted | TcpState::ConnectReady
        );
        if immediately_ready {
            return;
        }

        // FIXME: Add `Interest::ERROR` when we update to tokio 1.32.
        let interest = Interest::READABLE | Interest::WRITABLE;
        self.inner.ready(interest).await.unwrap();
    }
}