use super::network::SocketAddressFamily;
use super::{
with_ambient_tokio_runtime, HostInputStream, HostOutputStream, SocketResult, StreamError,
};
use crate::{AbortOnDropJoinHandle, Subscribe};
use anyhow::{Error, Result};
use cap_net_ext::AddressFamily;
use futures::Future;
use io_lifetimes::views::SocketlikeView;
use io_lifetimes::AsSocketlike;
use rustix::net::sockopt;
use std::io;
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
/// Initial value of `TcpSocket::listen_backlog_size` for every newly constructed socket.
const DEFAULT_BACKLOG: u32 = 0x80;
/// Lifecycle state of a `TcpSocket`.
///
/// The early states own a not-yet-listening/not-yet-connected
/// `tokio::net::TcpSocket`; later states own the tokio listener or stream
/// derived from it.
pub(crate) enum TcpState {
    /// Initial state, assigned by `TcpSocket::new`.
    Default(tokio::net::TcpSocket),
    /// Presumably: a bind has been started but not finished (TODO confirm with the bind code).
    BindStarted(tokio::net::TcpSocket),
    /// Presumably: the socket has been bound to a local address.
    Bound(tokio::net::TcpSocket),
    /// Presumably: a listen has been started but not finished.
    ListenStarted(tokio::net::TcpSocket),
    /// Accepting connections through `listener`.
    Listening {
        listener: tokio::net::TcpListener,
        /// Result of an accept completed by `Subscribe::ready`, held here until taken.
        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
    },
    /// Outgoing connection in progress; `Subscribe::ready` drives this future to completion.
    Connecting(
        Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send + 'static>>,
    ),
    /// Outcome of the `Connecting` future, stored by `Subscribe::ready`.
    ConnectReady(io::Result<tokio::net::TcpStream>),
    /// Established connection. It is reference-counted so it can be shared; the
    /// read/write stream types in this file take an `Arc<TcpStream>`.
    Connected(Arc<tokio::net::TcpStream>),
    /// Closed socket; `as_std_view` rejects this state with `InvalidState`.
    Closed,
}
/// Host-side representation of a WASI TCP socket.
pub struct TcpSocket {
    /// Current lifecycle state, owning the underlying tokio socket/listener/stream.
    pub(crate) tcp_state: TcpState,
    /// Listen backlog; starts at `DEFAULT_BACKLOG` (see `TcpSocket::from_state`).
    /// Presumably consumed when the socket starts listening — confirm at the call site.
    pub(crate) listen_backlog_size: u32,
    /// Address family chosen when the socket was created.
    pub(crate) family: SocketAddressFamily,
    // NOTE(review): the macOS-only fields below all start as `None` in `from_state`;
    // presumably they remember option values that must be re-applied on that
    // platform — confirm against the code that reads them.
    #[cfg(target_os = "macos")]
    pub(crate) receive_buffer_size: Option<usize>,
    #[cfg(target_os = "macos")]
    pub(crate) send_buffer_size: Option<usize>,
    #[cfg(target_os = "macos")]
    pub(crate) hop_limit: Option<u8>,
    #[cfg(target_os = "macos")]
    pub(crate) keep_alive_idle_time: Option<std::time::Duration>,
}
/// Readable half of a connected TCP socket.
pub(crate) struct TcpReadStream {
    /// Shared handle to the connected socket.
    stream: Arc<tokio::net::TcpStream>,
    /// Set once a read hits end-of-stream or fails; afterwards reads report `Closed`.
    closed: bool,
}

impl TcpReadStream {
    /// Wraps `stream` in a read stream that starts out open.
    pub(crate) fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
        TcpReadStream {
            closed: false,
            stream,
        }
    }
}
#[async_trait::async_trait]
impl HostInputStream for TcpReadStream {
    /// Reads at most `size` bytes from the socket without blocking.
    ///
    /// # Returns
    /// - `Ok(empty)` if `size == 0` or no data is available right now
    ///   (`WouldBlock`).
    /// - `Ok(bytes)` with between 1 and `min(size, MAX_READ_SIZE_ALLOC)` bytes
    ///   otherwise.
    ///
    /// # Errors
    /// - `StreamError::Closed` as soon as a read observes end-of-stream (a
    ///   0-byte read), and on every call after the stream has been marked closed.
    /// - `StreamError::LastOperationFailed` for any other I/O error. This also
    ///   marks the stream closed.
    fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
        // `size` is caller (guest) controlled and may be enormous, even `usize::MAX`.
        // Allocating it up front could panic or abort the host. A read may return
        // fewer bytes than requested, so cap the buffer instead.
        const MAX_READ_SIZE_ALLOC: usize = 64 * 1024;

        if self.closed {
            return Err(StreamError::Closed);
        }
        if size == 0 {
            return Ok(bytes::Bytes::new());
        }

        let mut buf = bytes::BytesMut::with_capacity(size.min(MAX_READ_SIZE_ALLOC));
        let n = match self.stream.try_read_buf(&mut buf) {
            // A 0-byte read means the peer closed its write side. Report it now
            // rather than handing back an empty buffer. Otherwise a `ready` + `read`
            // sequence at EOF would yield "no data" instead of `Closed`.
            Ok(0) => {
                self.closed = true;
                return Err(StreamError::Closed);
            }
            Ok(n) => n,

            // `WouldBlock` distinguishes "no data yet" from a closed stream.
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,

            Err(e) => {
                self.closed = true;
                return Err(StreamError::LastOperationFailed(e.into()));
            }
        };

        buf.truncate(n);
        Ok(buf.freeze())
    }
}
#[async_trait::async_trait]
impl Subscribe for TcpReadStream {
    /// Waits for tokio's read readiness on the socket.
    ///
    /// Returns immediately once the stream is marked closed, so that the next
    /// `read` can report `Closed` without waiting.
    ///
    /// # Panics
    /// Panics if tokio's `readable()` returns an error.
    async fn ready(&mut self) {
        if !self.closed {
            self.stream.readable().await.unwrap();
        }
    }
}
/// Write budget (1 GiB) reported by `check_write` when the socket is writable
/// and no background write is outstanding.
const SOCKET_READY_SIZE: usize = 1 << 30;
/// Writable half of a connected TCP socket.
pub(crate) struct TcpWriteStream {
    /// Shared handle to the connected socket.
    stream: Arc<tokio::net::TcpStream>,
    /// Status of the most recent write that could not complete synchronously.
    last_write: LastWrite,
}
/// Tracks the write that `TcpWriteStream` pushed into a background task.
enum LastWrite {
    /// A spawned task is still flushing the remaining bytes.
    Waiting(AbortOnDropJoinHandle<Result<()>>),
    /// The background task failed. `check_write` reports this error once.
    Error(Error),
    /// No write is pending; a new write may be issued.
    Done,
}
impl TcpWriteStream {
pub(crate) fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
Self {
stream,
last_write: LastWrite::Done,
}
}
fn background_write(&mut self, mut bytes: bytes::Bytes) {
assert!(matches!(self.last_write, LastWrite::Done));
let stream = self.stream.clone();
self.last_write = LastWrite::Waiting(crate::spawn(async move {
while !bytes.is_empty() {
stream.writable().await?;
match stream.try_write(&bytes) {
Ok(n) => {
let _ = bytes.split_to(n);
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
Err(e) => return Err(e.into()),
}
}
Ok(())
}));
}
}
impl HostOutputStream for TcpWriteStream {
    /// Writes as much of `bytes` as the socket accepts right now. On
    /// `WouldBlock`, the remainder is handed to a background task.
    ///
    /// # Errors
    /// - `StreamError::Trap` if a previous background write (or its error) has
    ///   not yet been collected by `check_write`.
    /// - `StreamError::LastOperationFailed` for a non-`WouldBlock` I/O error.
    fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
        if !matches!(self.last_write, LastWrite::Done) {
            return Err(StreamError::Trap(anyhow::anyhow!(
                "unpermitted: must call check_write first"
            )));
        }
        loop {
            if bytes.is_empty() {
                return Ok(());
            }
            match self.stream.try_write(&bytes) {
                Ok(written) => bytes = bytes.slice(written..),
                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
                    self.background_write(bytes);
                    return Ok(());
                }
                Err(err) => return Err(StreamError::LastOperationFailed(err.into())),
            }
        }
    }

    /// No-op. Nothing is buffered here, and any background write is joined by
    /// `check_write`/`ready`.
    fn flush(&mut self) -> Result<(), StreamError> {
        Ok(())
    }

    /// Reports how many bytes may be written now.
    ///
    /// Returns `0` while a background write is in flight or while the socket is
    /// not immediately writable. Otherwise returns `SOCKET_READY_SIZE`. A stored
    /// background-write error is returned (once) as `LastOperationFailed`, and
    /// the state resets to `Done`.
    fn check_write(&mut self) -> Result<usize, StreamError> {
        match mem::replace(&mut self.last_write, LastWrite::Done) {
            LastWrite::Done => {}
            LastWrite::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
            pending @ LastWrite::Waiting(_) => {
                // Still running: put the handle back and report no capacity.
                self.last_write = pending;
                return Ok(0);
            }
        }

        let writable = self.stream.writable();
        futures::pin_mut!(writable);
        let budget = if super::poll_noop(writable).is_some() {
            SOCKET_READY_SIZE
        } else {
            0
        };
        Ok(budget)
    }
}
#[async_trait::async_trait]
impl Subscribe for TcpWriteStream {
async fn ready(&mut self) {
if let LastWrite::Waiting(task) = &mut self.last_write {
self.last_write = match task.await {
Ok(()) => LastWrite::Done,
Err(e) => LastWrite::Error(e),
};
}
if let LastWrite::Done = self.last_write {
self.stream.writable().await.unwrap();
}
}
}
impl TcpSocket {
    /// Creates an unbound TCP socket of the given address family, in the
    /// ambient tokio runtime.
    ///
    /// IPv6 sockets have `IPV6_V6ONLY` set to `true`.
    ///
    /// # Errors
    /// Returns any I/O error from creating the socket or setting `IPV6_V6ONLY`.
    pub fn new(family: AddressFamily) -> io::Result<Self> {
        with_ambient_tokio_runtime(|| {
            let (socket, family) = match family {
                AddressFamily::Ipv4 => (tokio::net::TcpSocket::new_v4()?, SocketAddressFamily::Ipv4),
                AddressFamily::Ipv6 => {
                    let v6 = tokio::net::TcpSocket::new_v6()?;
                    sockopt::set_ipv6_v6only(&v6, true)?;
                    (v6, SocketAddressFamily::Ipv6)
                }
            };
            Self::from_state(TcpState::Default(socket), family)
        })
    }

    /// Builds a `TcpSocket` around an existing state.
    ///
    /// The backlog is set to `DEFAULT_BACKLOG` and all macOS-only cached
    /// options to `None`. Always succeeds.
    pub(crate) fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result<Self> {
        let socket = Self {
            family,
            tcp_state: state,
            listen_backlog_size: DEFAULT_BACKLOG,
            #[cfg(target_os = "macos")]
            receive_buffer_size: None,
            #[cfg(target_os = "macos")]
            send_buffer_size: None,
            #[cfg(target_os = "macos")]
            hop_limit: None,
            #[cfg(target_os = "macos")]
            keep_alive_idle_time: None,
        };
        Ok(socket)
    }

    /// Borrows the underlying OS socket as a `std::net::TcpStream` view.
    ///
    /// # Errors
    /// Returns `ErrorCode::InvalidState` for the transitional states
    /// (`BindStarted`, `ListenStarted`, `Connecting`, `ConnectReady`) and for `Closed`.
    pub(crate) fn as_std_view(&self) -> SocketResult<SocketlikeView<'_, std::net::TcpStream>> {
        use crate::bindings::sockets::network::ErrorCode;

        let view = match &self.tcp_state {
            TcpState::Default(socket) | TcpState::Bound(socket) => {
                socket.as_socketlike_view::<std::net::TcpStream>()
            }
            TcpState::Connected(stream) => stream.as_socketlike_view::<std::net::TcpStream>(),
            TcpState::Listening { listener, .. } => {
                listener.as_socketlike_view::<std::net::TcpStream>()
            }
            TcpState::BindStarted(..)
            | TcpState::ListenStarted(..)
            | TcpState::Connecting(..)
            | TcpState::ConnectReady(..)
            | TcpState::Closed => return Err(ErrorCode::InvalidState.into()),
        };
        Ok(view)
    }
}
#[async_trait::async_trait]
impl Subscribe for TcpSocket {
    /// Drives any in-progress asynchronous operation of the socket to completion.
    ///
    /// - `Connecting`: awaits the connect future and moves to `ConnectReady`
    ///   with its result.
    /// - `Listening` with no stashed accept: awaits one incoming connection and
    ///   stores the result in `pending_accept`.
    /// - Every other state returns immediately.
    async fn ready(&mut self) {
        match &mut self.tcp_state {
            TcpState::Connecting(future) => {
                let outcome = future.as_mut().await;
                self.tcp_state = TcpState::ConnectReady(outcome);
            }
            TcpState::Listening {
                listener,
                pending_accept,
            } => {
                if pending_accept.is_none() {
                    let accepted = futures::future::poll_fn(|cx| {
                        listener.poll_accept(cx).map_ok(|(stream, _)| stream)
                    })
                    .await;
                    *pending_accept = Some(accepted);
                }
            }
            // Nothing to wait for in these states.
            TcpState::Default(..)
            | TcpState::BindStarted(..)
            | TcpState::Bound(..)
            | TcpState::ListenStarted(..)
            | TcpState::ConnectReady(..)
            | TcpState::Connected(..)
            | TcpState::Closed => {}
        }
    }
}