wasmtime_wasi/
tcp.rs

1use crate::bindings::sockets::tcp::ErrorCode;
2use crate::host::network;
3use crate::network::SocketAddressFamily;
4use crate::runtime::{with_ambient_tokio_runtime, AbortOnDropJoinHandle};
5use crate::{
6    DynInputStream, DynOutputStream, InputStream, OutputStream, Pollable, SocketError,
7    SocketResult, StreamError,
8};
9use anyhow::Result;
10use cap_net_ext::AddressFamily;
11use futures::Future;
12use io_lifetimes::views::SocketlikeView;
13use io_lifetimes::AsSocketlike;
14use rustix::io::Errno;
15use rustix::net::sockopt;
16use std::io;
17use std::mem;
18use std::net::{Shutdown, SocketAddr};
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::Poll;
22use tokio::sync::Mutex;
23
/// Listen queue size used until `set_listen_backlog_size` is called.
///
/// The number itself is taken from the Rust standard library.
const DEFAULT_BACKLOG: u32 = 128;
26
/// The state of a TCP socket.
///
/// This represents the various states a socket can be in during the
/// activities of binding, listening, accepting, and connecting. Transitions
/// move the owned OS socket out of one variant and into the next, using
/// `Closed` as a temporary placeholder while doing so.
enum TcpState {
    /// The initial state for a newly-created socket.
    Default(tokio::net::TcpSocket),

    /// Binding started via `start_bind`. The OS `bind` call has already
    /// succeeded by the time this state is entered.
    BindStarted(tokio::net::TcpSocket),

    /// Binding finished via `finish_bind`. The socket has an address but
    /// is not yet listening for connections.
    Bound(tokio::net::TcpSocket),

    /// Listening started via `start_listen`. The OS `listen` call is made
    /// later, by `finish_listen`.
    ListenStarted(tokio::net::TcpSocket),

    /// The socket is now listening and waiting for an incoming connection.
    Listening {
        listener: tokio::net::TcpListener,
        /// A connection accepted by `Pollable::ready`, held until the next
        /// `accept` call hands it out.
        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
    },

    /// An outgoing connection is started via `start_connect`. The boxed
    /// future resolves once the OS connect attempt finishes.
    Connecting(Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>),

    /// The connect future completed inside `Pollable::ready`. Its outcome is
    /// waiting to be collected by `finish_connect`.
    ConnectReady(io::Result<tokio::net::TcpStream>),

    /// An outgoing connection has been established, or this socket was
    /// produced by `accept`.
    Connected {
        /// The OS stream, shared with the `reader` and `writer` halves.
        stream: Arc<tokio::net::TcpStream>,

        // WASI is single threaded, so in practice these Mutexes should never be contended:
        reader: Arc<Mutex<TcpReader>>,
        writer: Arc<Mutex<TcpWriter>>,
    },

    /// No usable socket. Entered after a failed connect or listen, and used
    /// as a placeholder while a state transition moves the inner socket.
    Closed,
}
68
69impl std::fmt::Debug for TcpState {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        match self {
72            Self::Default(_) => f.debug_tuple("Default").finish(),
73            Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(),
74            Self::Bound(_) => f.debug_tuple("Bound").finish(),
75            Self::ListenStarted(_) => f.debug_tuple("ListenStarted").finish(),
76            Self::Listening { pending_accept, .. } => f
77                .debug_struct("Listening")
78                .field("pending_accept", pending_accept)
79                .finish(),
80            Self::Connecting(_) => f.debug_tuple("Connecting").finish(),
81            Self::ConnectReady(_) => f.debug_tuple("ConnectReady").finish(),
82            Self::Connected { .. } => f.debug_tuple("Connected").finish(),
83            Self::Closed => write!(f, "Closed"),
84        }
85    }
86}
87
/// A host TCP socket, plus associated bookkeeping.
///
/// The underlying OS socket is owned by `tcp_state`, which also records where
/// the socket is in its bind/listen/accept/connect progression.
pub struct TcpSocket {
    /// The current state in the bind/listen/accept/connect progression.
    tcp_state: TcpState,

    /// The desired listen queue size. It starts at `DEFAULT_BACKLOG`, is
    /// passed to `listen` by `finish_listen`, and is updated by
    /// `set_listen_backlog_size`.
    listen_backlog_size: u32,

    /// The address family chosen at creation. Addresses given to bind and
    /// connect are validated against it, and accepted sockets inherit it.
    family: SocketAddressFamily,

    // The socket options below are not automatically inherited from the listener
    // on all platforms. So we keep track of which options have been explicitly
    // set and manually apply those values to newly accepted clients.
    #[cfg(target_os = "macos")]
    receive_buffer_size: Option<usize>,
    #[cfg(target_os = "macos")]
    send_buffer_size: Option<usize>,
    #[cfg(target_os = "macos")]
    hop_limit: Option<u8>,
    #[cfg(target_os = "macos")]
    keep_alive_idle_time: Option<std::time::Duration>,
}
110
111impl TcpSocket {
112    /// Create a new socket in the given family.
113    pub fn new(family: AddressFamily) -> io::Result<Self> {
114        with_ambient_tokio_runtime(|| {
115            let (socket, family) = match family {
116                AddressFamily::Ipv4 => {
117                    let socket = tokio::net::TcpSocket::new_v4()?;
118                    (socket, SocketAddressFamily::Ipv4)
119                }
120                AddressFamily::Ipv6 => {
121                    let socket = tokio::net::TcpSocket::new_v6()?;
122                    sockopt::set_ipv6_v6only(&socket, true)?;
123                    (socket, SocketAddressFamily::Ipv6)
124                }
125            };
126
127            Self::from_state(TcpState::Default(socket), family)
128        })
129    }
130
131    /// Create a `TcpSocket` from an existing socket.
132    fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result<Self> {
133        Ok(Self {
134            tcp_state: state,
135            listen_backlog_size: DEFAULT_BACKLOG,
136            family,
137            #[cfg(target_os = "macos")]
138            receive_buffer_size: None,
139            #[cfg(target_os = "macos")]
140            send_buffer_size: None,
141            #[cfg(target_os = "macos")]
142            hop_limit: None,
143            #[cfg(target_os = "macos")]
144            keep_alive_idle_time: None,
145        })
146    }
147
148    fn as_std_view(&self) -> SocketResult<SocketlikeView<'_, std::net::TcpStream>> {
149        use crate::bindings::sockets::network::ErrorCode;
150
151        match &self.tcp_state {
152            TcpState::Default(socket) | TcpState::Bound(socket) => {
153                Ok(socket.as_socketlike_view::<std::net::TcpStream>())
154            }
155            TcpState::Connected { stream, .. } => {
156                Ok(stream.as_socketlike_view::<std::net::TcpStream>())
157            }
158            TcpState::Listening { listener, .. } => {
159                Ok(listener.as_socketlike_view::<std::net::TcpStream>())
160            }
161
162            TcpState::BindStarted(..)
163            | TcpState::ListenStarted(..)
164            | TcpState::Connecting(..)
165            | TcpState::ConnectReady(..)
166            | TcpState::Closed => Err(ErrorCode::InvalidState.into()),
167        }
168    }
169}
170
171impl TcpSocket {
    /// Begins binding the socket to `local_address`. This is the first half of
    /// the `start_bind` / `finish_bind` pair.
    ///
    /// The OS `bind` call is performed here. On success the socket moves from
    /// `Default` to `BindStarted`, and `finish_bind` completes the transition.
    ///
    /// # Errors
    /// - `EALREADY` if a bind is already in progress (`BindStarted`).
    /// - `EISCONN` for any other non-`Default` state.
    /// - Errors from `validate_unicast` / `validate_address_family`.
    /// - Errors from setting `SO_REUSEADDR` or from the OS `bind` call.
    ///   `EAFNOSUPPORT` is rewritten to `InvalidInput`, and on Windows
    ///   `WSAENOBUFS` is rewritten to `AddrInUse`.
    pub fn start_bind(&mut self, local_address: SocketAddr) -> io::Result<()> {
        let tokio_socket = match &self.tcp_state {
            TcpState::Default(socket) => socket,
            TcpState::BindStarted(..) => return Err(Errno::ALREADY.into()),
            _ => return Err(Errno::ISCONN.into()),
        };

        network::util::validate_unicast(&local_address)?;
        network::util::validate_address_family(&local_address, &self.family)?;

        {
            // Automatically bypass the TIME_WAIT state when the user is trying
            // to bind to a specific port:
            let reuse_addr = local_address.port() > 0;

            // Unconditionally (re)set SO_REUSEADDR, even when the value is false.
            // This ensures we're not accidentally affected by any socket option
            // state left behind by a previous failed call to this method (start_bind).
            network::util::set_tcp_reuseaddr(&tokio_socket, reuse_addr)?;

            // Perform the OS bind call.
            tokio_socket.bind(local_address).map_err(|error| {
                match Errno::from_io_error(&error) {
                    // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html:
                    // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket
                    //
                    // The most common reasons for this error should have already
                    // been handled by our own validation slightly higher up in this
                    // function. This error mapping is here just in case there is
                    // an edge case we didn't catch.
                    Some(Errno::AFNOSUPPORT) =>  io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "The specified address is not a valid address for the address family of the specified socket",
                    ),

                    // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS
                    // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted.
                    #[cfg(windows)]
                    Some(Errno::NOBUFS) => io::Error::new(io::ErrorKind::AddrInUse, "no more free local ports"),

                    _ => error,
                }
            })?;

            // The bind succeeded, so move the socket into `BindStarted`. The
            // `Default` check at the top of the function guarantees the match.
            self.tcp_state = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
                TcpState::Default(socket) => TcpState::BindStarted(socket),
                _ => unreachable!(),
            };

            Ok(())
        }
    }
224
225    pub fn finish_bind(&mut self) -> SocketResult<()> {
226        match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
227            TcpState::BindStarted(socket) => {
228                self.tcp_state = TcpState::Bound(socket);
229                Ok(())
230            }
231            current_state => {
232                // Reset the state so that the outside world doesn't see this socket as closed
233                self.tcp_state = current_state;
234                Err(ErrorCode::NotInProgress.into())
235            }
236        }
237    }
238
239    pub fn start_connect(&mut self, remote_address: SocketAddr) -> SocketResult<()> {
240        match self.tcp_state {
241            TcpState::Default(..) | TcpState::Bound(..) => {}
242
243            TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
244                return Err(ErrorCode::ConcurrencyConflict.into())
245            }
246
247            _ => return Err(ErrorCode::InvalidState.into()),
248        };
249
250        network::util::validate_unicast(&remote_address)?;
251        network::util::validate_remote_address(&remote_address)?;
252        network::util::validate_address_family(&remote_address, &self.family)?;
253
254        let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
255            std::mem::replace(&mut self.tcp_state, TcpState::Closed)
256        else {
257            unreachable!();
258        };
259
260        let future = tokio_socket.connect(remote_address);
261
262        self.tcp_state = TcpState::Connecting(Box::pin(future));
263        Ok(())
264    }
265
266    pub fn finish_connect(&mut self) -> SocketResult<(DynInputStream, DynOutputStream)> {
267        let previous_state = std::mem::replace(&mut self.tcp_state, TcpState::Closed);
268        let result = match previous_state {
269            TcpState::ConnectReady(result) => result,
270            TcpState::Connecting(mut future) => {
271                let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
272                match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
273                    Poll::Ready(result) => result,
274                    Poll::Pending => {
275                        self.tcp_state = TcpState::Connecting(future);
276                        return Err(ErrorCode::WouldBlock.into());
277                    }
278                }
279            }
280            previous_state => {
281                self.tcp_state = previous_state;
282                return Err(ErrorCode::NotInProgress.into());
283            }
284        };
285
286        match result {
287            Ok(stream) => {
288                let stream = Arc::new(stream);
289                let reader = Arc::new(Mutex::new(TcpReader::new(stream.clone())));
290                let writer = Arc::new(Mutex::new(TcpWriter::new(stream.clone())));
291                self.tcp_state = TcpState::Connected {
292                    stream,
293                    reader: reader.clone(),
294                    writer: writer.clone(),
295                };
296                let input: DynInputStream = Box::new(TcpReadStream(reader));
297                let output: DynOutputStream = Box::new(TcpWriteStream(writer));
298                Ok((input, output))
299            }
300            Err(err) => {
301                self.tcp_state = TcpState::Closed;
302                Err(err.into())
303            }
304        }
305    }
306
307    pub fn start_listen(&mut self) -> SocketResult<()> {
308        match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
309            TcpState::Bound(tokio_socket) => {
310                self.tcp_state = TcpState::ListenStarted(tokio_socket);
311                Ok(())
312            }
313            TcpState::ListenStarted(tokio_socket) => {
314                self.tcp_state = TcpState::ListenStarted(tokio_socket);
315                Err(ErrorCode::ConcurrencyConflict.into())
316            }
317            previous_state => {
318                self.tcp_state = previous_state;
319                Err(ErrorCode::InvalidState.into())
320            }
321        }
322    }
323
324    pub fn finish_listen(&mut self) -> SocketResult<()> {
325        let tokio_socket = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
326            TcpState::ListenStarted(tokio_socket) => tokio_socket,
327            previous_state => {
328                self.tcp_state = previous_state;
329                return Err(ErrorCode::NotInProgress.into());
330            }
331        };
332
333        match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
334            Ok(listener) => {
335                self.tcp_state = TcpState::Listening {
336                    listener,
337                    pending_accept: None,
338                };
339                Ok(())
340            }
341            Err(err) => {
342                self.tcp_state = TcpState::Closed;
343
344                Err(match Errno::from_io_error(&err) {
345                    // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE
346                    // According to the docs, `listen` can return EMFILE on Windows.
347                    // This is odd, because we're not trying to create a new socket
348                    // or file descriptor of any kind. So we rewrite it to less
349                    // surprising error code.
350                    //
351                    // At the time of writing, this behavior has never been experimentally
352                    // observed by any of the wasmtime authors, so we're relying fully
353                    // on Microsoft's documentation here.
354                    #[cfg(windows)]
355                    Some(Errno::MFILE) => Errno::NOBUFS.into(),
356
357                    _ => err.into(),
358                })
359            }
360        }
361    }
362
    /// Accepts one incoming connection on a `Listening` socket.
    ///
    /// A connection already buffered by `Pollable::ready` is used first.
    /// Otherwise the listener is polled once with a no-op waker. The accepted
    /// stream is wrapped in a new `Connected` `TcpSocket` of the same address
    /// family, and its input/output streams are returned alongside it.
    ///
    /// # Errors
    /// - `InvalidState` if the socket is not `Listening`.
    /// - `EWOULDBLOCK` if no connection is pending.
    /// - Accept failures, with platform differences normalized. On Windows
    ///   `WSAEINPROGRESS` becomes `EINTR`. On Linux, errors that Linux reports
    ///   for already-pending network failures become `ECONNABORTED`.
    pub fn accept(&mut self) -> SocketResult<(Self, DynInputStream, DynOutputStream)> {
        let TcpState::Listening {
            listener,
            pending_accept,
        } = &mut self.tcp_state
        else {
            return Err(ErrorCode::InvalidState.into());
        };

        // Prefer a result already produced by `Pollable::ready`. Otherwise try
        // a single non-blocking accept.
        let result = match pending_accept.take() {
            Some(result) => result,
            None => {
                let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
                match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
                    .map_ok(|(stream, _)| stream)
                {
                    Poll::Ready(result) => result,
                    Poll::Pending => Err(Errno::WOULDBLOCK.into()),
                }
            }
        };

        let client = result.map_err(|err| match Errno::from_io_error(&err) {
            // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS
            // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress,
            // > or the service provider is still processing a callback function.
            //
            // wasi-sockets doesn't have an equivalent to the EINPROGRESS error,
            // because in POSIX this error is only returned by a non-blocking
            // `connect` and wasi-sockets has a different solution for that.
            #[cfg(windows)]
            Some(Errno::INPROGRESS) => Errno::INTR.into(),

            // Normalize Linux' non-standard behavior.
            //
            // From https://man7.org/linux/man-pages/man2/accept.2.html:
            // > Linux accept() passes already-pending network errors on the
            // > new socket as an error code from accept(). This behavior
            // > differs from other BSD socket implementations. (...)
            #[cfg(target_os = "linux")]
            Some(
                Errno::CONNRESET
                | Errno::NETRESET
                | Errno::HOSTUNREACH
                | Errno::HOSTDOWN
                | Errno::NETDOWN
                | Errno::NETUNREACH
                | Errno::PROTO
                | Errno::NOPROTOOPT
                | Errno::NONET
                | Errno::OPNOTSUPP,
            ) => Errno::CONNABORTED.into(),

            _ => err,
        })?;

        #[cfg(target_os = "macos")]
        {
            // Manually inherit socket options from listener. We only have to
            // do this on platforms that don't already do this automatically
            // and only if a specific value was explicitly set on the listener.

            if let Some(size) = self.receive_buffer_size {
                _ = network::util::set_socket_recv_buffer_size(&client, size); // Ignore potential error.
            }

            if let Some(size) = self.send_buffer_size {
                _ = network::util::set_socket_send_buffer_size(&client, size); // Ignore potential error.
            }

            // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't.
            if let (SocketAddressFamily::Ipv6, Some(ttl)) = (self.family, self.hop_limit) {
                _ = network::util::set_ipv6_unicast_hops(&client, ttl); // Ignore potential error.
            }

            if let Some(value) = self.keep_alive_idle_time {
                _ = network::util::set_tcp_keepidle(&client, value); // Ignore potential error.
            }
        }

        // The stream is shared by the new socket's state and its reader/writer halves.
        let client = Arc::new(client);

        let reader = Arc::new(Mutex::new(TcpReader::new(client.clone())));
        let writer = Arc::new(Mutex::new(TcpWriter::new(client.clone())));

        let input: DynInputStream = Box::new(TcpReadStream(reader.clone()));
        let output: DynOutputStream = Box::new(TcpWriteStream(writer.clone()));
        let tcp_socket = TcpSocket::from_state(
            TcpState::Connected {
                stream: client,
                reader,
                writer,
            },
            self.family,
        )?;

        Ok((tcp_socket, input, output))
    }
461
462    pub fn local_address(&self) -> SocketResult<SocketAddr> {
463        let view = match self.tcp_state {
464            TcpState::Default(..) => return Err(ErrorCode::InvalidState.into()),
465            TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()),
466            _ => self.as_std_view()?,
467        };
468
469        Ok(view.local_addr()?)
470    }
471
472    pub fn remote_address(&self) -> SocketResult<SocketAddr> {
473        let view = match self.tcp_state {
474            TcpState::Connected { .. } => self.as_std_view()?,
475            TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
476                return Err(ErrorCode::ConcurrencyConflict.into())
477            }
478            _ => return Err(ErrorCode::InvalidState.into()),
479        };
480
481        Ok(view.peer_addr()?)
482    }
483
484    pub fn is_listening(&self) -> bool {
485        matches!(self.tcp_state, TcpState::Listening { .. })
486    }
487
488    pub fn address_family(&self) -> SocketAddressFamily {
489        self.family
490    }
491
492    pub fn set_listen_backlog_size(&mut self, value: u32) -> SocketResult<()> {
493        const MIN_BACKLOG: u32 = 1;
494        const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further.
495
496        if value == 0 {
497            return Err(ErrorCode::InvalidArgument.into());
498        }
499
500        // Silently clamp backlog size. This is OK for us to do, because operating systems do this too.
501        let value = value.clamp(MIN_BACKLOG, MAX_BACKLOG);
502
503        match &self.tcp_state {
504            TcpState::Default(..) | TcpState::Bound(..) => {
505                // Socket not listening yet. Stash value for first invocation to `listen`.
506            }
507            TcpState::Listening { listener, .. } => {
508                // Try to update the backlog by calling `listen` again.
509                // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact.
510
511                rustix::net::listen(&listener, value.try_into().unwrap())
512                    .map_err(|_| ErrorCode::NotSupported)?;
513            }
514            _ => return Err(ErrorCode::InvalidState.into()),
515        }
516        self.listen_backlog_size = value;
517
518        Ok(())
519    }
520
521    pub fn keep_alive_enabled(&self) -> SocketResult<bool> {
522        let view = &*self.as_std_view()?;
523        Ok(sockopt::get_socket_keepalive(view)?)
524    }
525
526    pub fn set_keep_alive_enabled(&self, value: bool) -> SocketResult<()> {
527        let view = &*self.as_std_view()?;
528        Ok(sockopt::set_socket_keepalive(view, value)?)
529    }
530
531    pub fn keep_alive_idle_time(&self) -> SocketResult<std::time::Duration> {
532        let view = &*self.as_std_view()?;
533        Ok(sockopt::get_tcp_keepidle(view)?)
534    }
535
536    pub fn set_keep_alive_idle_time(&mut self, duration: std::time::Duration) -> SocketResult<()> {
537        {
538            let view = &*self.as_std_view()?;
539            network::util::set_tcp_keepidle(view, duration)?;
540        }
541
542        #[cfg(target_os = "macos")]
543        {
544            self.keep_alive_idle_time = Some(duration);
545        }
546
547        Ok(())
548    }
549
550    pub fn keep_alive_interval(&self) -> SocketResult<std::time::Duration> {
551        let view = &*self.as_std_view()?;
552        Ok(sockopt::get_tcp_keepintvl(view)?)
553    }
554
555    pub fn set_keep_alive_interval(&self, duration: std::time::Duration) -> SocketResult<()> {
556        let view = &*self.as_std_view()?;
557        Ok(network::util::set_tcp_keepintvl(view, duration)?)
558    }
559
560    pub fn keep_alive_count(&self) -> SocketResult<u32> {
561        let view = &*self.as_std_view()?;
562        Ok(sockopt::get_tcp_keepcnt(view)?)
563    }
564
565    pub fn set_keep_alive_count(&self, value: u32) -> SocketResult<()> {
566        let view = &*self.as_std_view()?;
567        Ok(network::util::set_tcp_keepcnt(view, value)?)
568    }
569
570    pub fn hop_limit(&self) -> SocketResult<u8> {
571        let view = &*self.as_std_view()?;
572
573        let ttl = match self.family {
574            SocketAddressFamily::Ipv4 => network::util::get_ip_ttl(view)?,
575            SocketAddressFamily::Ipv6 => network::util::get_ipv6_unicast_hops(view)?,
576        };
577
578        Ok(ttl)
579    }
580
581    pub fn set_hop_limit(&mut self, value: u8) -> SocketResult<()> {
582        {
583            let view = &*self.as_std_view()?;
584
585            match self.family {
586                SocketAddressFamily::Ipv4 => network::util::set_ip_ttl(view, value)?,
587                SocketAddressFamily::Ipv6 => network::util::set_ipv6_unicast_hops(view, value)?,
588            }
589        }
590
591        #[cfg(target_os = "macos")]
592        {
593            self.hop_limit = Some(value);
594        }
595
596        Ok(())
597    }
598
599    pub fn receive_buffer_size(&self) -> SocketResult<usize> {
600        let view = &*self.as_std_view()?;
601
602        Ok(network::util::get_socket_recv_buffer_size(view)?)
603    }
604
605    pub fn set_receive_buffer_size(&mut self, value: usize) -> SocketResult<()> {
606        {
607            let view = &*self.as_std_view()?;
608
609            network::util::set_socket_recv_buffer_size(view, value)?;
610        }
611
612        #[cfg(target_os = "macos")]
613        {
614            self.receive_buffer_size = Some(value);
615        }
616
617        Ok(())
618    }
619
620    pub fn send_buffer_size(&self) -> SocketResult<usize> {
621        let view = &*self.as_std_view()?;
622
623        Ok(network::util::get_socket_send_buffer_size(view)?)
624    }
625
626    pub fn set_send_buffer_size(&mut self, value: usize) -> SocketResult<()> {
627        {
628            let view = &*self.as_std_view()?;
629
630            network::util::set_socket_send_buffer_size(view, value)?;
631        }
632
633        #[cfg(target_os = "macos")]
634        {
635            self.send_buffer_size = Some(value);
636        }
637
638        Ok(())
639    }
640
641    pub fn shutdown(&self, how: Shutdown) -> SocketResult<()> {
642        let TcpState::Connected { reader, writer, .. } = &self.tcp_state else {
643            return Err(ErrorCode::InvalidState.into());
644        };
645
646        if let Shutdown::Both | Shutdown::Read = how {
647            try_lock_for_socket(reader)?.shutdown();
648        }
649
650        if let Shutdown::Both | Shutdown::Write = how {
651            try_lock_for_socket(writer)?.shutdown();
652        }
653
654        Ok(())
655    }
656}
657
658#[async_trait::async_trait]
659impl Pollable for TcpSocket {
660    async fn ready(&mut self) {
661        match &mut self.tcp_state {
662            TcpState::Default(..)
663            | TcpState::BindStarted(..)
664            | TcpState::Bound(..)
665            | TcpState::ListenStarted(..)
666            | TcpState::ConnectReady(..)
667            | TcpState::Closed
668            | TcpState::Connected { .. } => {
669                // No async operation in progress.
670            }
671            TcpState::Connecting(future) => {
672                self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
673            }
674            TcpState::Listening {
675                listener,
676                pending_accept,
677            } => match pending_accept {
678                Some(_) => {}
679                None => {
680                    let result = futures::future::poll_fn(|cx| {
681                        listener.poll_accept(cx).map_ok(|(stream, _)| stream)
682                    })
683                    .await;
684                    *pending_accept = Some(result);
685                }
686            },
687        }
688    }
689}
690
/// The read half of a TCP connection.
struct TcpReader {
    /// Socket handle shared with the write half and the owning `TcpSocket`.
    stream: Arc<tokio::net::TcpStream>,
    /// Set once EOF, a read error, or a read-side shutdown has been observed.
    /// After that, `read` always returns `StreamError::Closed`.
    closed: bool,
}
695
696impl TcpReader {
697    fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
698        Self {
699            stream,
700            closed: false,
701        }
702    }
703    fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
704        if self.closed {
705            return Err(StreamError::Closed);
706        }
707        if size == 0 {
708            return Ok(bytes::Bytes::new());
709        }
710
711        let mut buf = bytes::BytesMut::with_capacity(size);
712        let n = match self.stream.try_read_buf(&mut buf) {
713            // A 0-byte read indicates that the stream has closed.
714            Ok(0) => {
715                self.closed = true;
716                return Err(StreamError::Closed);
717            }
718            Ok(n) => n,
719
720            // Failing with `EWOULDBLOCK` is how we differentiate between a closed channel and no
721            // data to read right now.
722            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
723
724            Err(e) => {
725                self.closed = true;
726                return Err(StreamError::LastOperationFailed(e.into()));
727            }
728        };
729
730        buf.truncate(n);
731        Ok(buf.freeze())
732    }
733
734    fn shutdown(&mut self) {
735        native_shutdown(&self.stream, Shutdown::Read);
736        self.closed = true;
737    }
738
739    async fn ready(&mut self) {
740        if self.closed {
741            return;
742        }
743
744        self.stream.readable().await.unwrap();
745    }
746}
747
748struct TcpReadStream(Arc<Mutex<TcpReader>>);
749
750#[async_trait::async_trait]
751impl InputStream for TcpReadStream {
752    fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
753        try_lock_for_stream(&self.0)?.read(size)
754    }
755}
756
757#[async_trait::async_trait]
758impl Pollable for TcpReadStream {
759    async fn ready(&mut self) {
760        self.0.lock().await.ready().await
761    }
762}
763
/// 1 GiB. Not referenced in the visible part of this file.
/// NOTE(review): presumably the write budget reported to the caller when the
/// writer is ready; confirm against `check_write`.
const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024;

/// The write half of a TCP connection.
///
/// It is wrapped in `Arc<Mutex<_>>` and shared between `TcpState::Connected`
/// and the output stream returned by `finish_connect` / `accept`.
struct TcpWriter {
    /// Socket handle shared with the read half and the owning `TcpSocket`.
    stream: Arc<tokio::net::TcpStream>,
    /// Progress of the current (possibly background) write.
    state: WriteState,
}

/// Lifecycle of a `TcpWriter`.
enum WriteState {
    /// No operation is in flight, so `write` may be called.
    Ready,
    /// A task spawned by `background_write` is finishing a write that would
    /// have blocked.
    Writing(AbortOnDropJoinHandle<io::Result<()>>),
    /// NOTE(review): a background task is in flight for closing the stream.
    /// This state is set outside the visible code; confirm.
    Closing(AbortOnDropJoinHandle<io::Result<()>>),
    /// The stream is closed, and `write` reports `StreamError::Closed`.
    Closed,
    /// NOTE(review): holds an error from a background operation. It is set
    /// outside the visible code; confirm how it is surfaced.
    Error(io::Error),
}
778
779impl TcpWriter {
780    fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
781        Self {
782            stream,
783            state: WriteState::Ready,
784        }
785    }
786
787    fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result<usize> {
788        stream.try_write(buf).map_err(|error| {
789            match Errno::from_io_error(&error) {
790                // Windows returns `WSAESHUTDOWN` when writing to a shut down socket.
791                // We normalize this to EPIPE, because that is what the other platforms return.
792                // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-send#:~:text=WSAESHUTDOWN
793                #[cfg(windows)]
794                Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error),
795
796                _ => error,
797            }
798        })
799    }
800
801    /// Write `bytes` in a background task, remembering the task handle for use in a future call to
802    /// `write_ready`
803    fn background_write(&mut self, mut bytes: bytes::Bytes) {
804        assert!(matches!(self.state, WriteState::Ready));
805
806        let stream = self.stream.clone();
807        self.state = WriteState::Writing(crate::runtime::spawn(async move {
808            // Note: we are not using the AsyncWrite impl here, and instead using the TcpStream
809            // primitive try_write, which goes directly to attempt a write with mio. This has
810            // two advantages: 1. this operation takes a &TcpStream instead of a &mut TcpStream
811            // required to AsyncWrite, and 2. it eliminates any buffering in tokio we may need
812            // to flush.
813            while !bytes.is_empty() {
814                stream.writable().await?;
815                match Self::try_write_portable(&stream, &bytes) {
816                    Ok(n) => {
817                        let _ = bytes.split_to(n);
818                    }
819                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
820                    Err(e) => return Err(e.into()),
821                }
822            }
823
824            Ok(())
825        }));
826    }
827
828    fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
829        match self.state {
830            WriteState::Ready => {}
831            WriteState::Closed => return Err(StreamError::Closed),
832            WriteState::Writing(_) | WriteState::Closing(_) | WriteState::Error(_) => {
833                return Err(StreamError::Trap(anyhow::anyhow!(
834                    "unpermitted: must call check_write first"
835                )));
836            }
837        }
838        while !bytes.is_empty() {
839            match Self::try_write_portable(&self.stream, &bytes) {
840                Ok(n) => {
841                    let _ = bytes.split_to(n);
842                }
843
844                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
845                    // As `try_write` indicated that it would have blocked, we'll perform the write
846                    // in the background to allow us to return immediately.
847                    self.background_write(bytes);
848
849                    return Ok(());
850                }
851
852                Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
853                    self.state = WriteState::Closed;
854                    return Err(StreamError::Closed);
855                }
856
857                Err(e) => return Err(StreamError::LastOperationFailed(e.into())),
858            }
859        }
860
861        Ok(())
862    }
863
864    fn flush(&mut self) -> Result<(), StreamError> {
865        // `flush` is a no-op here, as we're not managing any internal buffer. Additionally,
866        // `write_ready` will join the background write task if it's active, so following `flush`
867        // with `write_ready` will have the desired effect.
868        match self.state {
869            WriteState::Ready
870            | WriteState::Writing(_)
871            | WriteState::Closing(_)
872            | WriteState::Error(_) => Ok(()),
873            WriteState::Closed => Err(StreamError::Closed),
874        }
875    }
876
877    fn check_write(&mut self) -> Result<usize, StreamError> {
878        match mem::replace(&mut self.state, WriteState::Closed) {
879            WriteState::Writing(task) => {
880                self.state = WriteState::Writing(task);
881                return Ok(0);
882            }
883            WriteState::Closing(task) => {
884                self.state = WriteState::Closing(task);
885                return Ok(0);
886            }
887            WriteState::Ready => {
888                self.state = WriteState::Ready;
889            }
890            WriteState::Closed => return Err(StreamError::Closed),
891            WriteState::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
892        }
893
894        let writable = self.stream.writable();
895        futures::pin_mut!(writable);
896        if crate::runtime::poll_noop(writable).is_none() {
897            return Ok(0);
898        }
899        Ok(SOCKET_READY_SIZE)
900    }
901
902    fn shutdown(&mut self) {
903        self.state = match mem::replace(&mut self.state, WriteState::Closed) {
904            // No write in progress, immediately shut down:
905            WriteState::Ready => {
906                native_shutdown(&self.stream, Shutdown::Write);
907                WriteState::Closed
908            }
909
910            // Schedule the shutdown after the current write has finished:
911            WriteState::Writing(write) => {
912                let stream = self.stream.clone();
913                WriteState::Closing(crate::runtime::spawn(async move {
914                    let result = write.await;
915                    native_shutdown(&stream, Shutdown::Write);
916                    result
917                }))
918            }
919
920            s => s,
921        };
922    }
923
924    async fn cancel(&mut self) {
925        match mem::replace(&mut self.state, WriteState::Closed) {
926            WriteState::Writing(task) | WriteState::Closing(task) => _ = task.cancel().await,
927            _ => {}
928        }
929    }
930
931    async fn ready(&mut self) {
932        match &mut self.state {
933            WriteState::Writing(task) => {
934                self.state = match task.await {
935                    Ok(()) => WriteState::Ready,
936                    Err(e) => WriteState::Error(e),
937                }
938            }
939            WriteState::Closing(task) => {
940                self.state = match task.await {
941                    Ok(()) => WriteState::Closed,
942                    Err(e) => WriteState::Error(e),
943                }
944            }
945            _ => {}
946        }
947
948        if let WriteState::Ready = self.state {
949            self.stream.writable().await.unwrap();
950        }
951    }
952}
953
/// Output-stream handle for the write side of a TCP socket.
///
/// Holds the `TcpWriter` behind an `Arc<tokio::sync::Mutex<..>>` so the writer can be shared.
/// The synchronous `OutputStream` methods acquire the lock with `try_lock_for_stream`, so they
/// return an error instead of waiting while the lock is held. The async `cancel` and `ready`
/// wait for the lock.
struct TcpWriteStream(Arc<Mutex<TcpWriter>>);
955
956#[async_trait::async_trait]
957impl OutputStream for TcpWriteStream {
958    fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
959        try_lock_for_stream(&self.0)?.write(bytes)
960    }
961
962    fn flush(&mut self) -> Result<(), StreamError> {
963        try_lock_for_stream(&self.0)?.flush()
964    }
965
966    fn check_write(&mut self) -> Result<usize, StreamError> {
967        try_lock_for_stream(&self.0)?.check_write()
968    }
969
970    async fn cancel(&mut self) {
971        self.0.lock().await.cancel().await
972    }
973}
974
975#[async_trait::async_trait]
976impl Pollable for TcpWriteStream {
977    async fn ready(&mut self) {
978        self.0.lock().await.ready().await
979    }
980}
981
982fn native_shutdown(stream: &tokio::net::TcpStream, how: Shutdown) {
983    _ = stream
984        .as_socketlike_view::<std::net::TcpStream>()
985        .shutdown(how);
986}
987
988fn try_lock_for_stream<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, StreamError> {
989    mutex
990        .try_lock()
991        .map_err(|_| StreamError::trap("concurrent access to resource not supported"))
992}
993
994fn try_lock_for_socket<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, SocketError> {
995    mutex.try_lock().map_err(|_| {
996        SocketError::trap(anyhow::anyhow!(
997            "concurrent access to resource not supported"
998        ))
999    })
1000}