1use crate::bindings::sockets::tcp::ErrorCode;
2use crate::host::network;
3use crate::network::SocketAddressFamily;
4use crate::runtime::{with_ambient_tokio_runtime, AbortOnDropJoinHandle};
5use crate::{
6 DynInputStream, DynOutputStream, InputStream, OutputStream, Pollable, SocketError,
7 SocketResult, StreamError,
8};
9use anyhow::Result;
10use cap_net_ext::AddressFamily;
11use futures::Future;
12use io_lifetimes::views::SocketlikeView;
13use io_lifetimes::AsSocketlike;
14use rustix::io::Errno;
15use rustix::net::sockopt;
16use std::io;
17use std::mem;
18use std::net::{Shutdown, SocketAddr};
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::Poll;
22use tokio::sync::Mutex;
23
/// Backlog passed to `listen` for a new socket until `set_listen_backlog_size` changes it.
const DEFAULT_BACKLOG: u32 = 128;
26
/// State machine of a TCP socket resource.
///
/// The `start_*` / `finish_*` methods on `TcpSocket` move the underlying OS socket from
/// variant to variant.
enum TcpState {
    /// Freshly created socket: not bound, listening, or connecting.
    Default(tokio::net::TcpSocket),

    /// `start_bind` succeeded (the OS bind has already been performed); awaiting `finish_bind`.
    BindStarted(tokio::net::TcpSocket),

    /// `finish_bind` completed.
    Bound(tokio::net::TcpSocket),

    /// `start_listen` was called; the OS `listen` call is made in `finish_listen`.
    ListenStarted(tokio::net::TcpSocket),

    /// The socket is listening for incoming connections.
    Listening {
        /// Listener that accepts new connections.
        listener: tokio::net::TcpListener,
        /// Accept result produced by `Pollable::ready`, consumed by the next `accept` call.
        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
    },

    /// An outbound connect is in flight.
    Connecting(Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>),

    /// The connect future was resolved by `Pollable::ready`; `finish_connect` consumes it.
    ConnectReady(io::Result<tokio::net::TcpStream>),

    /// A connected stream, produced by `finish_connect` or `accept`.
    Connected {
        /// Shared handle to the connected stream.
        stream: Arc<tokio::net::TcpStream>,

        /// Read-side state, shared with the input stream returned to the caller.
        reader: Arc<Mutex<TcpReader>>,
        /// Write-side state, shared with the output stream returned to the caller.
        writer: Arc<Mutex<TcpWriter>>,
    },

    /// Set after a failed connect or listen; also used as a temporary placeholder while a
    /// method moves the socket out with `mem::replace`.
    Closed,
}
68
69impl std::fmt::Debug for TcpState {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 match self {
72 Self::Default(_) => f.debug_tuple("Default").finish(),
73 Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(),
74 Self::Bound(_) => f.debug_tuple("Bound").finish(),
75 Self::ListenStarted(_) => f.debug_tuple("ListenStarted").finish(),
76 Self::Listening { pending_accept, .. } => f
77 .debug_struct("Listening")
78 .field("pending_accept", pending_accept)
79 .finish(),
80 Self::Connecting(_) => f.debug_tuple("Connecting").finish(),
81 Self::ConnectReady(_) => f.debug_tuple("ConnectReady").finish(),
82 Self::Connected { .. } => f.debug_tuple("Connected").finish(),
83 Self::Closed => write!(f, "Closed"),
84 }
85 }
86}
87
/// A TCP socket resource: a tokio socket plus its state machine and recorded options.
pub struct TcpSocket {
    /// Current position in the socket state machine.
    tcp_state: TcpState,

    /// Backlog passed to `listen` by `finish_listen`; updated by `set_listen_backlog_size`.
    listen_backlog_size: u32,

    /// Address family fixed at creation; bind/connect addresses are validated against it.
    family: SocketAddressFamily,

    // macOS only: each setter records the value it applied, and `accept` re-applies the
    // recorded values to every accepted socket.
    // NOTE(review): presumably because accepted sockets on macOS do not inherit these
    // options from the listener — confirm.
    #[cfg(target_os = "macos")]
    receive_buffer_size: Option<usize>,
    #[cfg(target_os = "macos")]
    send_buffer_size: Option<usize>,
    #[cfg(target_os = "macos")]
    hop_limit: Option<u8>,
    #[cfg(target_os = "macos")]
    keep_alive_idle_time: Option<std::time::Duration>,
}
110
111impl TcpSocket {
112 pub fn new(family: AddressFamily) -> io::Result<Self> {
114 with_ambient_tokio_runtime(|| {
115 let (socket, family) = match family {
116 AddressFamily::Ipv4 => {
117 let socket = tokio::net::TcpSocket::new_v4()?;
118 (socket, SocketAddressFamily::Ipv4)
119 }
120 AddressFamily::Ipv6 => {
121 let socket = tokio::net::TcpSocket::new_v6()?;
122 sockopt::set_ipv6_v6only(&socket, true)?;
123 (socket, SocketAddressFamily::Ipv6)
124 }
125 };
126
127 Self::from_state(TcpState::Default(socket), family)
128 })
129 }
130
131 fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result<Self> {
133 Ok(Self {
134 tcp_state: state,
135 listen_backlog_size: DEFAULT_BACKLOG,
136 family,
137 #[cfg(target_os = "macos")]
138 receive_buffer_size: None,
139 #[cfg(target_os = "macos")]
140 send_buffer_size: None,
141 #[cfg(target_os = "macos")]
142 hop_limit: None,
143 #[cfg(target_os = "macos")]
144 keep_alive_idle_time: None,
145 })
146 }
147
148 fn as_std_view(&self) -> SocketResult<SocketlikeView<'_, std::net::TcpStream>> {
149 use crate::bindings::sockets::network::ErrorCode;
150
151 match &self.tcp_state {
152 TcpState::Default(socket) | TcpState::Bound(socket) => {
153 Ok(socket.as_socketlike_view::<std::net::TcpStream>())
154 }
155 TcpState::Connected { stream, .. } => {
156 Ok(stream.as_socketlike_view::<std::net::TcpStream>())
157 }
158 TcpState::Listening { listener, .. } => {
159 Ok(listener.as_socketlike_view::<std::net::TcpStream>())
160 }
161
162 TcpState::BindStarted(..)
163 | TcpState::ListenStarted(..)
164 | TcpState::Connecting(..)
165 | TcpState::ConnectReady(..)
166 | TcpState::Closed => Err(ErrorCode::InvalidState.into()),
167 }
168 }
169}
170
171impl TcpSocket {
172 pub fn start_bind(&mut self, local_address: SocketAddr) -> io::Result<()> {
173 let tokio_socket = match &self.tcp_state {
174 TcpState::Default(socket) => socket,
175 TcpState::BindStarted(..) => return Err(Errno::ALREADY.into()),
176 _ => return Err(Errno::ISCONN.into()),
177 };
178
179 network::util::validate_unicast(&local_address)?;
180 network::util::validate_address_family(&local_address, &self.family)?;
181
182 {
183 let reuse_addr = local_address.port() > 0;
186
187 network::util::set_tcp_reuseaddr(&tokio_socket, reuse_addr)?;
191
192 tokio_socket.bind(local_address).map_err(|error| {
194 match Errno::from_io_error(&error) {
195 Some(Errno::AFNOSUPPORT) => io::Error::new(
203 io::ErrorKind::InvalidInput,
204 "The specified address is not a valid address for the address family of the specified socket",
205 ),
206
207 #[cfg(windows)]
210 Some(Errno::NOBUFS) => io::Error::new(io::ErrorKind::AddrInUse, "no more free local ports"),
211
212 _ => error,
213 }
214 })?;
215
216 self.tcp_state = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
217 TcpState::Default(socket) => TcpState::BindStarted(socket),
218 _ => unreachable!(),
219 };
220
221 Ok(())
222 }
223 }
224
225 pub fn finish_bind(&mut self) -> SocketResult<()> {
226 match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
227 TcpState::BindStarted(socket) => {
228 self.tcp_state = TcpState::Bound(socket);
229 Ok(())
230 }
231 current_state => {
232 self.tcp_state = current_state;
234 Err(ErrorCode::NotInProgress.into())
235 }
236 }
237 }
238
239 pub fn start_connect(&mut self, remote_address: SocketAddr) -> SocketResult<()> {
240 match self.tcp_state {
241 TcpState::Default(..) | TcpState::Bound(..) => {}
242
243 TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
244 return Err(ErrorCode::ConcurrencyConflict.into())
245 }
246
247 _ => return Err(ErrorCode::InvalidState.into()),
248 };
249
250 network::util::validate_unicast(&remote_address)?;
251 network::util::validate_remote_address(&remote_address)?;
252 network::util::validate_address_family(&remote_address, &self.family)?;
253
254 let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
255 std::mem::replace(&mut self.tcp_state, TcpState::Closed)
256 else {
257 unreachable!();
258 };
259
260 let future = tokio_socket.connect(remote_address);
261
262 self.tcp_state = TcpState::Connecting(Box::pin(future));
263 Ok(())
264 }
265
266 pub fn finish_connect(&mut self) -> SocketResult<(DynInputStream, DynOutputStream)> {
267 let previous_state = std::mem::replace(&mut self.tcp_state, TcpState::Closed);
268 let result = match previous_state {
269 TcpState::ConnectReady(result) => result,
270 TcpState::Connecting(mut future) => {
271 let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
272 match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
273 Poll::Ready(result) => result,
274 Poll::Pending => {
275 self.tcp_state = TcpState::Connecting(future);
276 return Err(ErrorCode::WouldBlock.into());
277 }
278 }
279 }
280 previous_state => {
281 self.tcp_state = previous_state;
282 return Err(ErrorCode::NotInProgress.into());
283 }
284 };
285
286 match result {
287 Ok(stream) => {
288 let stream = Arc::new(stream);
289 let reader = Arc::new(Mutex::new(TcpReader::new(stream.clone())));
290 let writer = Arc::new(Mutex::new(TcpWriter::new(stream.clone())));
291 self.tcp_state = TcpState::Connected {
292 stream,
293 reader: reader.clone(),
294 writer: writer.clone(),
295 };
296 let input: DynInputStream = Box::new(TcpReadStream(reader));
297 let output: DynOutputStream = Box::new(TcpWriteStream(writer));
298 Ok((input, output))
299 }
300 Err(err) => {
301 self.tcp_state = TcpState::Closed;
302 Err(err.into())
303 }
304 }
305 }
306
307 pub fn start_listen(&mut self) -> SocketResult<()> {
308 match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
309 TcpState::Bound(tokio_socket) => {
310 self.tcp_state = TcpState::ListenStarted(tokio_socket);
311 Ok(())
312 }
313 TcpState::ListenStarted(tokio_socket) => {
314 self.tcp_state = TcpState::ListenStarted(tokio_socket);
315 Err(ErrorCode::ConcurrencyConflict.into())
316 }
317 previous_state => {
318 self.tcp_state = previous_state;
319 Err(ErrorCode::InvalidState.into())
320 }
321 }
322 }
323
324 pub fn finish_listen(&mut self) -> SocketResult<()> {
325 let tokio_socket = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
326 TcpState::ListenStarted(tokio_socket) => tokio_socket,
327 previous_state => {
328 self.tcp_state = previous_state;
329 return Err(ErrorCode::NotInProgress.into());
330 }
331 };
332
333 match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
334 Ok(listener) => {
335 self.tcp_state = TcpState::Listening {
336 listener,
337 pending_accept: None,
338 };
339 Ok(())
340 }
341 Err(err) => {
342 self.tcp_state = TcpState::Closed;
343
344 Err(match Errno::from_io_error(&err) {
345 #[cfg(windows)]
355 Some(Errno::MFILE) => Errno::NOBUFS.into(),
356
357 _ => err.into(),
358 })
359 }
360 }
361 }
362
/// Accepts one incoming connection on a listening socket.
///
/// If `Pollable::ready` already stored an accept result in `pending_accept`, that result is
/// used. Otherwise `poll_accept` is polled once, inside the ambient tokio runtime, with a
/// no-op waker, and `WOULDBLOCK` is returned if no connection is ready.
///
/// # Returns
/// A new `TcpSocket` in the `Connected` state (same address family as this listener), plus
/// the input and output streams of the accepted connection.
///
/// # Errors
/// - `InvalidState` if the socket is not `Listening`.
/// - The accept error, with some errno values remapped per platform (see below).
pub fn accept(&mut self) -> SocketResult<(Self, DynInputStream, DynOutputStream)> {
    let TcpState::Listening {
        listener,
        pending_accept,
    } = &mut self.tcp_state
    else {
        return Err(ErrorCode::InvalidState.into());
    };

    // Prefer a result already produced by `Pollable::ready`; `take()` empties the slot.
    let result = match pending_accept.take() {
        Some(result) => result,
        None => {
            // One non-blocking poll; the no-op waker discards any wakeup.
            let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
            match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
                .map_ok(|(stream, _)| stream)
            {
                Poll::Ready(result) => result,
                Poll::Pending => Err(Errno::WOULDBLOCK.into()),
            }
        }
    };

    let client = result.map_err(|err| match Errno::from_io_error(&err) {
        // Windows: INPROGRESS is reported as INTR.
        // NOTE(review): presumably to align with the errno other platforms use — confirm.
        #[cfg(windows)]
        Some(Errno::INPROGRESS) => Errno::INTR.into(),

        // Linux: these network/protocol errors on the accepted connection are all
        // reported as CONNABORTED.
        #[cfg(target_os = "linux")]
        Some(
            Errno::CONNRESET
            | Errno::NETRESET
            | Errno::HOSTUNREACH
            | Errno::HOSTDOWN
            | Errno::NETDOWN
            | Errno::NETUNREACH
            | Errno::PROTO
            | Errno::NOPROTOOPT
            | Errno::NONET
            | Errno::OPNOTSUPP,
        ) => Errno::CONNABORTED.into(),

        _ => err,
    })?;

    // macOS: re-apply the option values recorded by this socket's setters to the accepted
    // socket. Failures are deliberately ignored.
    // NOTE(review): presumably because macOS does not propagate these options from the
    // listener to accepted sockets — confirm.
    #[cfg(target_os = "macos")]
    {
        if let Some(size) = self.receive_buffer_size {
            _ = network::util::set_socket_recv_buffer_size(&client, size);
        }

        if let Some(size) = self.send_buffer_size {
            _ = network::util::set_socket_send_buffer_size(&client, size);
        }

        if let (SocketAddressFamily::Ipv6, Some(ttl)) = (self.family, self.hop_limit) {
            _ = network::util::set_ipv6_unicast_hops(&client, ttl);
        }

        if let Some(value) = self.keep_alive_idle_time {
            _ = network::util::set_tcp_keepidle(&client, value);
        }
    }

    let client = Arc::new(client);

    // Reader and writer share the stream; each is also shared with its returned stream.
    let reader = Arc::new(Mutex::new(TcpReader::new(client.clone())));
    let writer = Arc::new(Mutex::new(TcpWriter::new(client.clone())));

    let input: DynInputStream = Box::new(TcpReadStream(reader.clone()));
    let output: DynOutputStream = Box::new(TcpWriteStream(writer.clone()));
    let tcp_socket = TcpSocket::from_state(
        TcpState::Connected {
            stream: client,
            reader,
            writer,
        },
        self.family,
    )?;

    Ok((tcp_socket, input, output))
}
461
462 pub fn local_address(&self) -> SocketResult<SocketAddr> {
463 let view = match self.tcp_state {
464 TcpState::Default(..) => return Err(ErrorCode::InvalidState.into()),
465 TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()),
466 _ => self.as_std_view()?,
467 };
468
469 Ok(view.local_addr()?)
470 }
471
472 pub fn remote_address(&self) -> SocketResult<SocketAddr> {
473 let view = match self.tcp_state {
474 TcpState::Connected { .. } => self.as_std_view()?,
475 TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
476 return Err(ErrorCode::ConcurrencyConflict.into())
477 }
478 _ => return Err(ErrorCode::InvalidState.into()),
479 };
480
481 Ok(view.peer_addr()?)
482 }
483
484 pub fn is_listening(&self) -> bool {
485 matches!(self.tcp_state, TcpState::Listening { .. })
486 }
487
488 pub fn address_family(&self) -> SocketAddressFamily {
489 self.family
490 }
491
492 pub fn set_listen_backlog_size(&mut self, value: u32) -> SocketResult<()> {
493 const MIN_BACKLOG: u32 = 1;
494 const MAX_BACKLOG: u32 = i32::MAX as u32; if value == 0 {
497 return Err(ErrorCode::InvalidArgument.into());
498 }
499
500 let value = value.clamp(MIN_BACKLOG, MAX_BACKLOG);
502
503 match &self.tcp_state {
504 TcpState::Default(..) | TcpState::Bound(..) => {
505 }
507 TcpState::Listening { listener, .. } => {
508 rustix::net::listen(&listener, value.try_into().unwrap())
512 .map_err(|_| ErrorCode::NotSupported)?;
513 }
514 _ => return Err(ErrorCode::InvalidState.into()),
515 }
516 self.listen_backlog_size = value;
517
518 Ok(())
519 }
520
521 pub fn keep_alive_enabled(&self) -> SocketResult<bool> {
522 let view = &*self.as_std_view()?;
523 Ok(sockopt::get_socket_keepalive(view)?)
524 }
525
526 pub fn set_keep_alive_enabled(&self, value: bool) -> SocketResult<()> {
527 let view = &*self.as_std_view()?;
528 Ok(sockopt::set_socket_keepalive(view, value)?)
529 }
530
531 pub fn keep_alive_idle_time(&self) -> SocketResult<std::time::Duration> {
532 let view = &*self.as_std_view()?;
533 Ok(sockopt::get_tcp_keepidle(view)?)
534 }
535
536 pub fn set_keep_alive_idle_time(&mut self, duration: std::time::Duration) -> SocketResult<()> {
537 {
538 let view = &*self.as_std_view()?;
539 network::util::set_tcp_keepidle(view, duration)?;
540 }
541
542 #[cfg(target_os = "macos")]
543 {
544 self.keep_alive_idle_time = Some(duration);
545 }
546
547 Ok(())
548 }
549
550 pub fn keep_alive_interval(&self) -> SocketResult<std::time::Duration> {
551 let view = &*self.as_std_view()?;
552 Ok(sockopt::get_tcp_keepintvl(view)?)
553 }
554
555 pub fn set_keep_alive_interval(&self, duration: std::time::Duration) -> SocketResult<()> {
556 let view = &*self.as_std_view()?;
557 Ok(network::util::set_tcp_keepintvl(view, duration)?)
558 }
559
560 pub fn keep_alive_count(&self) -> SocketResult<u32> {
561 let view = &*self.as_std_view()?;
562 Ok(sockopt::get_tcp_keepcnt(view)?)
563 }
564
565 pub fn set_keep_alive_count(&self, value: u32) -> SocketResult<()> {
566 let view = &*self.as_std_view()?;
567 Ok(network::util::set_tcp_keepcnt(view, value)?)
568 }
569
570 pub fn hop_limit(&self) -> SocketResult<u8> {
571 let view = &*self.as_std_view()?;
572
573 let ttl = match self.family {
574 SocketAddressFamily::Ipv4 => network::util::get_ip_ttl(view)?,
575 SocketAddressFamily::Ipv6 => network::util::get_ipv6_unicast_hops(view)?,
576 };
577
578 Ok(ttl)
579 }
580
581 pub fn set_hop_limit(&mut self, value: u8) -> SocketResult<()> {
582 {
583 let view = &*self.as_std_view()?;
584
585 match self.family {
586 SocketAddressFamily::Ipv4 => network::util::set_ip_ttl(view, value)?,
587 SocketAddressFamily::Ipv6 => network::util::set_ipv6_unicast_hops(view, value)?,
588 }
589 }
590
591 #[cfg(target_os = "macos")]
592 {
593 self.hop_limit = Some(value);
594 }
595
596 Ok(())
597 }
598
599 pub fn receive_buffer_size(&self) -> SocketResult<usize> {
600 let view = &*self.as_std_view()?;
601
602 Ok(network::util::get_socket_recv_buffer_size(view)?)
603 }
604
605 pub fn set_receive_buffer_size(&mut self, value: usize) -> SocketResult<()> {
606 {
607 let view = &*self.as_std_view()?;
608
609 network::util::set_socket_recv_buffer_size(view, value)?;
610 }
611
612 #[cfg(target_os = "macos")]
613 {
614 self.receive_buffer_size = Some(value);
615 }
616
617 Ok(())
618 }
619
620 pub fn send_buffer_size(&self) -> SocketResult<usize> {
621 let view = &*self.as_std_view()?;
622
623 Ok(network::util::get_socket_send_buffer_size(view)?)
624 }
625
626 pub fn set_send_buffer_size(&mut self, value: usize) -> SocketResult<()> {
627 {
628 let view = &*self.as_std_view()?;
629
630 network::util::set_socket_send_buffer_size(view, value)?;
631 }
632
633 #[cfg(target_os = "macos")]
634 {
635 self.send_buffer_size = Some(value);
636 }
637
638 Ok(())
639 }
640
641 pub fn shutdown(&self, how: Shutdown) -> SocketResult<()> {
642 let TcpState::Connected { reader, writer, .. } = &self.tcp_state else {
643 return Err(ErrorCode::InvalidState.into());
644 };
645
646 if let Shutdown::Both | Shutdown::Read = how {
647 try_lock_for_socket(reader)?.shutdown();
648 }
649
650 if let Shutdown::Both | Shutdown::Write = how {
651 try_lock_for_socket(writer)?.shutdown();
652 }
653
654 Ok(())
655 }
656}
657
658#[async_trait::async_trait]
659impl Pollable for TcpSocket {
660 async fn ready(&mut self) {
661 match &mut self.tcp_state {
662 TcpState::Default(..)
663 | TcpState::BindStarted(..)
664 | TcpState::Bound(..)
665 | TcpState::ListenStarted(..)
666 | TcpState::ConnectReady(..)
667 | TcpState::Closed
668 | TcpState::Connected { .. } => {
669 }
671 TcpState::Connecting(future) => {
672 self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
673 }
674 TcpState::Listening {
675 listener,
676 pending_accept,
677 } => match pending_accept {
678 Some(_) => {}
679 None => {
680 let result = futures::future::poll_fn(|cx| {
681 listener.poll_accept(cx).map_ok(|(stream, _)| stream)
682 })
683 .await;
684 *pending_accept = Some(result);
685 }
686 },
687 }
688 }
689}
690
/// Read half of a connected TCP stream.
struct TcpReader {
    /// Connected stream; the corresponding `TcpWriter` holds another reference to it.
    stream: Arc<tokio::net::TcpStream>,
    /// Set once reading is over: after `shutdown`, end-of-stream, or a read error.
    closed: bool,
}
695
696impl TcpReader {
697 fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
698 Self {
699 stream,
700 closed: false,
701 }
702 }
703 fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
704 if self.closed {
705 return Err(StreamError::Closed);
706 }
707 if size == 0 {
708 return Ok(bytes::Bytes::new());
709 }
710
711 let mut buf = bytes::BytesMut::with_capacity(size);
712 let n = match self.stream.try_read_buf(&mut buf) {
713 Ok(0) => {
715 self.closed = true;
716 return Err(StreamError::Closed);
717 }
718 Ok(n) => n,
719
720 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
723
724 Err(e) => {
725 self.closed = true;
726 return Err(StreamError::LastOperationFailed(e.into()));
727 }
728 };
729
730 buf.truncate(n);
731 Ok(buf.freeze())
732 }
733
734 fn shutdown(&mut self) {
735 native_shutdown(&self.stream, Shutdown::Read);
736 self.closed = true;
737 }
738
739 async fn ready(&mut self) {
740 if self.closed {
741 return;
742 }
743
744 self.stream.readable().await.unwrap();
745 }
746}
747
748struct TcpReadStream(Arc<Mutex<TcpReader>>);
749
750#[async_trait::async_trait]
751impl InputStream for TcpReadStream {
752 fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
753 try_lock_for_stream(&self.0)?.read(size)
754 }
755}
756
757#[async_trait::async_trait]
758impl Pollable for TcpReadStream {
759 async fn ready(&mut self) {
760 self.0.lock().await.ready().await
761 }
762}
763
/// Write budget (1 GiB) that `check_write` reports when the socket is currently writable.
/// A `write` that cannot finish immediately hands its remaining bytes to a background task.
const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024;
765
/// Write half of a connected TCP stream.
struct TcpWriter {
    /// Connected stream, shared with the corresponding `TcpReader`.
    stream: Arc<tokio::net::TcpStream>,
    /// Progress of writes and of the write-side shutdown.
    state: WriteState,
}

/// State of a `TcpWriter`.
enum WriteState {
    /// No write in flight; `write` is permitted.
    Ready,
    /// A background task is finishing a write that would otherwise have blocked.
    Writing(AbortOnDropJoinHandle<io::Result<()>>),
    /// A background task is finishing an in-flight write and then shutting down the write half.
    Closing(AbortOnDropJoinHandle<io::Result<()>>),
    /// The write half is closed; `write` and `check_write` report `StreamError::Closed`.
    Closed,
    /// A background write failed; the error is reported by the next `check_write`.
    Error(io::Error),
}
778
779impl TcpWriter {
780 fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
781 Self {
782 stream,
783 state: WriteState::Ready,
784 }
785 }
786
787 fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result<usize> {
788 stream.try_write(buf).map_err(|error| {
789 match Errno::from_io_error(&error) {
790 #[cfg(windows)]
794 Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error),
795
796 _ => error,
797 }
798 })
799 }
800
801 fn background_write(&mut self, mut bytes: bytes::Bytes) {
804 assert!(matches!(self.state, WriteState::Ready));
805
806 let stream = self.stream.clone();
807 self.state = WriteState::Writing(crate::runtime::spawn(async move {
808 while !bytes.is_empty() {
814 stream.writable().await?;
815 match Self::try_write_portable(&stream, &bytes) {
816 Ok(n) => {
817 let _ = bytes.split_to(n);
818 }
819 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
820 Err(e) => return Err(e.into()),
821 }
822 }
823
824 Ok(())
825 }));
826 }
827
828 fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
829 match self.state {
830 WriteState::Ready => {}
831 WriteState::Closed => return Err(StreamError::Closed),
832 WriteState::Writing(_) | WriteState::Closing(_) | WriteState::Error(_) => {
833 return Err(StreamError::Trap(anyhow::anyhow!(
834 "unpermitted: must call check_write first"
835 )));
836 }
837 }
838 while !bytes.is_empty() {
839 match Self::try_write_portable(&self.stream, &bytes) {
840 Ok(n) => {
841 let _ = bytes.split_to(n);
842 }
843
844 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
845 self.background_write(bytes);
848
849 return Ok(());
850 }
851
852 Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
853 self.state = WriteState::Closed;
854 return Err(StreamError::Closed);
855 }
856
857 Err(e) => return Err(StreamError::LastOperationFailed(e.into())),
858 }
859 }
860
861 Ok(())
862 }
863
864 fn flush(&mut self) -> Result<(), StreamError> {
865 match self.state {
869 WriteState::Ready
870 | WriteState::Writing(_)
871 | WriteState::Closing(_)
872 | WriteState::Error(_) => Ok(()),
873 WriteState::Closed => Err(StreamError::Closed),
874 }
875 }
876
877 fn check_write(&mut self) -> Result<usize, StreamError> {
878 match mem::replace(&mut self.state, WriteState::Closed) {
879 WriteState::Writing(task) => {
880 self.state = WriteState::Writing(task);
881 return Ok(0);
882 }
883 WriteState::Closing(task) => {
884 self.state = WriteState::Closing(task);
885 return Ok(0);
886 }
887 WriteState::Ready => {
888 self.state = WriteState::Ready;
889 }
890 WriteState::Closed => return Err(StreamError::Closed),
891 WriteState::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
892 }
893
894 let writable = self.stream.writable();
895 futures::pin_mut!(writable);
896 if crate::runtime::poll_noop(writable).is_none() {
897 return Ok(0);
898 }
899 Ok(SOCKET_READY_SIZE)
900 }
901
902 fn shutdown(&mut self) {
903 self.state = match mem::replace(&mut self.state, WriteState::Closed) {
904 WriteState::Ready => {
906 native_shutdown(&self.stream, Shutdown::Write);
907 WriteState::Closed
908 }
909
910 WriteState::Writing(write) => {
912 let stream = self.stream.clone();
913 WriteState::Closing(crate::runtime::spawn(async move {
914 let result = write.await;
915 native_shutdown(&stream, Shutdown::Write);
916 result
917 }))
918 }
919
920 s => s,
921 };
922 }
923
924 async fn cancel(&mut self) {
925 match mem::replace(&mut self.state, WriteState::Closed) {
926 WriteState::Writing(task) | WriteState::Closing(task) => _ = task.cancel().await,
927 _ => {}
928 }
929 }
930
931 async fn ready(&mut self) {
932 match &mut self.state {
933 WriteState::Writing(task) => {
934 self.state = match task.await {
935 Ok(()) => WriteState::Ready,
936 Err(e) => WriteState::Error(e),
937 }
938 }
939 WriteState::Closing(task) => {
940 self.state = match task.await {
941 Ok(()) => WriteState::Closed,
942 Err(e) => WriteState::Error(e),
943 }
944 }
945 _ => {}
946 }
947
948 if let WriteState::Ready = self.state {
949 self.stream.writable().await.unwrap();
950 }
951 }
952}
953
954struct TcpWriteStream(Arc<Mutex<TcpWriter>>);
955
956#[async_trait::async_trait]
957impl OutputStream for TcpWriteStream {
958 fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
959 try_lock_for_stream(&self.0)?.write(bytes)
960 }
961
962 fn flush(&mut self) -> Result<(), StreamError> {
963 try_lock_for_stream(&self.0)?.flush()
964 }
965
966 fn check_write(&mut self) -> Result<usize, StreamError> {
967 try_lock_for_stream(&self.0)?.check_write()
968 }
969
970 async fn cancel(&mut self) {
971 self.0.lock().await.cancel().await
972 }
973}
974
975#[async_trait::async_trait]
976impl Pollable for TcpWriteStream {
977 async fn ready(&mut self) {
978 self.0.lock().await.ready().await
979 }
980}
981
982fn native_shutdown(stream: &tokio::net::TcpStream, how: Shutdown) {
983 _ = stream
984 .as_socketlike_view::<std::net::TcpStream>()
985 .shutdown(how);
986}
987
988fn try_lock_for_stream<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, StreamError> {
989 mutex
990 .try_lock()
991 .map_err(|_| StreamError::trap("concurrent access to resource not supported"))
992}
993
994fn try_lock_for_socket<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, SocketError> {
995 mutex.try_lock().map_err(|_| {
996 SocketError::trap(anyhow::anyhow!(
997 "concurrent access to resource not supported"
998 ))
999 })
1000}