use super::network::SocketAddressFamily;
use super::{HostInputStream, HostOutputStream, StreamError};
use crate::preview2::{
with_ambient_tokio_runtime, AbortOnDropJoinHandle, InputStream, OutputStream, Subscribe,
};
use anyhow::{Error, Result};
use cap_net_ext::{AddressFamily, Blocking, TcpListenerExt};
use cap_std::net::TcpListener;
use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike};
use rustix::net::sockopt;
use std::io;
use std::mem;
use std::sync::Arc;
use tokio::io::Interest;
/// Lifecycle state of a [`TcpSocket`].
///
/// Only `Default` (the initial value set by `TcpSocket::from_tcp_listener`) and
/// the three "ready immediately" states checked in `TcpSocket::ready` have their
/// role visible in this file; the other descriptions are inferred from the
/// variant names — NOTE(review): confirm against the state-transition code.
pub(crate) enum TcpState {
    /// Initial state of every newly constructed socket.
    Default,
    /// Presumably a bind is in progress; `ready` resolves immediately here.
    BindStarted,
    /// Presumably the socket has been bound to a local address.
    Bound,
    /// Presumably a listen is in progress; `ready` resolves immediately here.
    ListenStarted,
    /// Presumably the socket is accepting connections.
    Listening,
    /// Presumably an outbound connect is in flight.
    Connecting,
    /// Presumably a connect result is available; `ready` resolves immediately here.
    ConnectReady,
    /// Presumably the connect attempt failed.
    ConnectFailed,
    /// Presumably the socket is connected to a peer.
    Connected,
}
/// A host TCP socket backing a guest socket resource.
pub struct TcpSocket {
    /// The host socket registered with tokio. It is reference-counted so that
    /// `as_split` can hand the same socket to a read stream and a write stream.
    pub(crate) inner: Arc<tokio::net::TcpStream>,
    /// Current lifecycle state; starts as `TcpState::Default`.
    pub(crate) tcp_state: TcpState,
    /// Starts as `None`. NOTE(review): presumably the backlog to use when
    /// listening — confirm at the code that sets it.
    pub(crate) listen_backlog_size: Option<i32>,
    /// Address family of the socket. For IPv6 sockets created with `new`, this
    /// includes the `IPV6_V6ONLY` value read back at creation time.
    pub(crate) family: SocketAddressFamily,
    /// macOS only; starts as `None`. NOTE(review): presumably a cached copy of
    /// the receive-buffer option — confirm.
    #[cfg(target_os = "macos")]
    pub(crate) receive_buffer_size: Option<usize>,
    /// macOS only; starts as `None`. NOTE(review): presumably a cached copy of
    /// the send-buffer option — confirm.
    #[cfg(target_os = "macos")]
    pub(crate) send_buffer_size: Option<usize>,
    /// macOS only; starts as `None`. NOTE(review): presumably a cached hop
    /// limit / TTL value — confirm.
    #[cfg(target_os = "macos")]
    pub(crate) hop_limit: Option<u8>,
    /// macOS only; starts as `None`. NOTE(review): presumably a cached
    /// keep-alive idle time — confirm.
    #[cfg(target_os = "macos")]
    pub(crate) keep_alive_idle_time: Option<std::time::Duration>,
}
/// Input-stream half of a TCP socket, produced by `TcpSocket::as_split`.
pub(crate) struct TcpReadStream {
    /// Socket shared with the owning `TcpSocket` and its write stream.
    stream: Arc<tokio::net::TcpStream>,
    /// Set after the socket returns a zero-byte read or a non-`WouldBlock`
    /// error; once set, `read` returns `StreamError::Closed`.
    closed: bool,
}
impl TcpReadStream {
    /// Builds a read stream over `stream` that starts in the open state.
    fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
        let closed = false;
        Self { closed, stream }
    }
}
#[async_trait::async_trait]
impl HostInputStream for TcpReadStream {
    /// Performs one non-blocking read of at most `size` bytes.
    ///
    /// Returns an empty buffer when `size` is 0 or when the socket has no data
    /// right now (`WouldBlock`).
    ///
    /// A zero-byte read from the socket means the peer closed the stream. That
    /// call still returns an empty buffer, but it marks the stream closed.
    ///
    /// # Errors
    /// - `StreamError::Closed` once the stream has been marked closed.
    /// - `StreamError::LastOperationFailed` for any other I/O error. The stream
    ///   is also marked closed in this case.
    ///
    /// A single call never allocates more than `MAX_READ_SIZE_ALLOC` bytes.
    /// It may therefore return fewer than `size` bytes even when more are
    /// available.
    fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
        // Upper bound on the buffer allocated for one read. `size` is
        // caller-controlled (presumably forwarded from the guest — confirm), so
        // allocating it verbatim would let an enormous request exhaust or abort
        // the host. Short reads are permitted, so capping is safe.
        const MAX_READ_SIZE_ALLOC: usize = 64 * 1024;

        if self.closed {
            return Err(StreamError::Closed);
        }
        if size == 0 {
            return Ok(bytes::Bytes::new());
        }

        let mut buf = bytes::BytesMut::with_capacity(size.min(MAX_READ_SIZE_ALLOC));
        let n = match self.stream.try_read_buf(&mut buf) {
            // A zero-byte read signals end-of-stream.
            Ok(0) => {
                self.closed = true;
                0
            }
            Ok(n) => n,
            // `WouldBlock` means "no data yet", which is distinct from closed.
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
            Err(e) => {
                self.closed = true;
                return Err(StreamError::LastOperationFailed(e.into()));
            }
        };
        buf.truncate(n);
        Ok(buf.freeze())
    }
}
#[async_trait::async_trait]
impl Subscribe for TcpReadStream {
    /// Waits until the socket reports read readiness.
    ///
    /// If the stream is already marked closed, this returns immediately.
    ///
    /// # Panics
    /// Panics if tokio's `readable()` returns an error.
    async fn ready(&mut self) {
        if !self.closed {
            self.stream.readable().await.unwrap();
        }
    }
}
/// Byte budget that `check_write` reports once the socket is writable (1 GiB).
const SOCKET_READY_SIZE: usize = 1 << 30;
/// Output-stream half of a TCP socket, produced by `TcpSocket::as_split`.
pub(crate) struct TcpWriteStream {
    /// Socket shared with the owning `TcpSocket` and its read stream.
    stream: Arc<tokio::net::TcpStream>,
    /// Status of the most recent write that could not finish synchronously.
    last_write: LastWrite,
}
/// Tracks the outcome of a write handed off to a background task.
enum LastWrite {
    /// A spawned task is still writing the remaining bytes.
    Waiting(AbortOnDropJoinHandle<Result<()>>),
    /// The background task failed. The error is reported once by `check_write`.
    Error(Error),
    /// No write is outstanding; a new `write` is allowed.
    Done,
}
impl TcpWriteStream {
pub(crate) fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
Self {
stream,
last_write: LastWrite::Done,
}
}
fn background_write(&mut self, mut bytes: bytes::Bytes) {
assert!(matches!(self.last_write, LastWrite::Done));
let stream = self.stream.clone();
self.last_write = LastWrite::Waiting(crate::preview2::spawn(async move {
while !bytes.is_empty() {
stream.writable().await?;
match stream.try_write(&bytes) {
Ok(n) => {
let _ = bytes.split_to(n);
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
Err(e) => return Err(e.into()),
}
}
Ok(())
}));
}
}
impl HostOutputStream for TcpWriteStream {
    /// Writes as much of `bytes` as the socket accepts immediately.
    ///
    /// If the socket would block, the remainder goes to a background task.
    ///
    /// # Errors
    /// - `StreamError::Trap` if a background write is still pending or its
    ///   error has not been reported yet, meaning `check_write` was skipped.
    /// - `StreamError::LastOperationFailed` for a non-`WouldBlock` I/O error.
    fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
        if !matches!(self.last_write, LastWrite::Done) {
            return Err(StreamError::Trap(anyhow::anyhow!(
                "unpermitted: must call check_write first"
            )));
        }
        while !bytes.is_empty() {
            let written = match self.stream.try_write(&bytes) {
                Ok(count) => count,
                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
                    // Socket buffer is full: finish the rest asynchronously.
                    self.background_write(bytes);
                    return Ok(());
                }
                Err(err) => return Err(StreamError::LastOperationFailed(err.into())),
            };
            drop(bytes.split_to(written));
        }
        Ok(())
    }

    /// Nothing is buffered on the host side, so flushing is a no-op.
    fn flush(&mut self) -> Result<(), StreamError> {
        Ok(())
    }

    /// Reports how many bytes the next `write` may submit.
    ///
    /// - Returns 0 while a background write is pending. Its completion is only
    ///   observed by `ready`.
    /// - Returns the background write's error once, then resets to `Done`.
    /// - Otherwise polls the socket's `writable()` future via
    ///   `super::poll_noop` (presumably a single poll with a no-op waker —
    ///   confirm). Returns `SOCKET_READY_SIZE` if that poll yields a result,
    ///   and 0 if it does not.
    fn check_write(&mut self) -> Result<usize, StreamError> {
        match mem::replace(&mut self.last_write, LastWrite::Done) {
            LastWrite::Done => {}
            LastWrite::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
            pending @ LastWrite::Waiting(_) => {
                // Put the in-flight task back; it has not been observed yet.
                self.last_write = pending;
                return Ok(0);
            }
        }
        let writable = self.stream.writable();
        futures::pin_mut!(writable);
        match super::poll_noop(writable) {
            None => Ok(0),
            Some(_) => Ok(SOCKET_READY_SIZE),
        }
    }
}
#[async_trait::async_trait]
impl Subscribe for TcpWriteStream {
async fn ready(&mut self) {
if let LastWrite::Waiting(task) = &mut self.last_write {
self.last_write = match task.await {
Ok(()) => LastWrite::Done,
Err(e) => LastWrite::Error(e),
};
}
if let LastWrite::Done = self.last_write {
self.stream.writable().await.unwrap();
}
}
}
impl TcpSocket {
    /// Creates a new non-blocking host TCP socket of the given address family.
    ///
    /// For IPv6, the socket's current `IPV6_V6ONLY` setting is read back and
    /// stored in the resulting `SocketAddressFamily`.
    ///
    /// # Errors
    /// Propagates I/O errors from creating the socket, from querying
    /// `IPV6_V6ONLY`, or from registering the socket with tokio.
    pub fn new(family: AddressFamily) -> io::Result<Self> {
        // The async read/write paths rely on the host socket being non-blocking.
        let listener = TcpListener::new(family, Blocking::No)?;
        let recorded_family = match family {
            AddressFamily::Ipv6 => {
                let v6only = sockopt::get_ipv6_v6only(&listener)?;
                SocketAddressFamily::Ipv6 { v6only }
            }
            AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
        };
        Self::from_tcp_listener(listener, recorded_family)
    }

    /// Wraps an existing cap-std TCP stream.
    ///
    /// The stream's descriptor is converted into a `TcpListener` and passed to
    /// `from_tcp_listener`.
    pub(crate) fn from_tcp_stream(
        tcp_socket: cap_std::net::TcpStream,
        family: SocketAddressFamily,
    ) -> io::Result<Self> {
        let owned_fd = rustix::fd::OwnedFd::from(tcp_socket);
        Self::from_tcp_listener(TcpListener::from(owned_fd), family)
    }

    /// Takes ownership of `tcp_listener`'s socket and registers it with tokio.
    ///
    /// The new `TcpSocket` starts in `TcpState::Default`, with all optional
    /// settings unset.
    ///
    /// # Errors
    /// Returns the error from `tokio::net::TcpStream::try_from` if registration
    /// with tokio fails.
    pub(crate) fn from_tcp_listener(
        tcp_listener: cap_std::net::TcpListener,
        family: SocketAddressFamily,
    ) -> io::Result<Self> {
        let raw_socket = tcp_listener.into_raw_socketlike();
        // SAFETY: `into_raw_socketlike` consumed `tcp_listener`, so `raw_socket`
        // is an open socket that nothing else owns. Ownership moves into the
        // `std::net::TcpStream` created here, which becomes its sole owner.
        let std_stream = unsafe { std::net::TcpStream::from_raw_socketlike(raw_socket) };
        // Registration with tokio's reactor requires an active runtime context.
        let tokio_stream =
            with_ambient_tokio_runtime(|| tokio::net::TcpStream::try_from(std_stream))?;
        Ok(Self {
            inner: Arc::new(tokio_stream),
            tcp_state: TcpState::Default,
            listen_backlog_size: None,
            family,
            #[cfg(target_os = "macos")]
            receive_buffer_size: None,
            #[cfg(target_os = "macos")]
            send_buffer_size: None,
            #[cfg(target_os = "macos")]
            hop_limit: None,
            #[cfg(target_os = "macos")]
            keep_alive_idle_time: None,
        })
    }

    /// Borrows the underlying tokio socket.
    pub fn tcp_socket(&self) -> &tokio::net::TcpStream {
        &*self.inner
    }

    /// Returns an input stream and an output stream that share this socket.
    pub fn as_split(&self) -> (InputStream, OutputStream) {
        let reader = TcpReadStream::new(Arc::clone(&self.inner));
        let writer = TcpWriteStream::new(Arc::clone(&self.inner));
        (InputStream::Host(Box::new(reader)), Box::new(writer))
    }
}
#[async_trait::async_trait]
impl Subscribe for TcpSocket {
    /// Waits until the socket has something to report.
    ///
    /// The `BindStarted`, `ListenStarted` and `ConnectReady` states resolve
    /// immediately. Every other state waits for the socket to become readable
    /// or writable.
    ///
    /// # Panics
    /// Panics if tokio's `ready()` returns an error.
    async fn ready(&mut self) {
        let immediately_ready = matches!(
            self.tcp_state,
            TcpState::BindStarted | TcpState::ListenStarted | TcpState::ConnectReady
        );
        if immediately_ready {
            return;
        }
        let interest = Interest::READABLE | Interest::WRITABLE;
        self.inner.ready(interest).await.unwrap();
    }
}