// compio_driver/iocp/mod.rs
use std::{
collections::HashMap,
io,
os::windows::{
io::{OwnedHandle, OwnedSocket},
prelude::{AsRawHandle, AsRawSocket},
},
pin::Pin,
sync::Arc,
task::Poll,
time::Duration,
};
use compio_log::{instrument, trace};
use windows_sys::Win32::{Foundation::ERROR_CANCELLED, System::IO::OVERLAPPED};
use crate::{AsyncifyPool, Entry, Key, ProactorBuilder};
pub(crate) mod op;
mod cp;
mod wait;
pub(crate) use windows_sys::Win32::Networking::WinSock::{
SOCKADDR_STORAGE as sockaddr_storage, socklen_t,
};
/// Raw descriptor type used by this driver: an `isize` that holds either a
/// raw handle value or a raw socket value (the `AsRawFd` impls in this file
/// cast both kinds into it).
pub type RawFd = isize;
/// Extracts the raw descriptor ([`RawFd`]) from an object.
pub trait AsRawFd {
/// Returns the raw descriptor value of `self`.
fn as_raw_fd(&self) -> RawFd;
}
/// An owned descriptor: either an owned handle or an owned socket.
#[derive(Debug)]
pub enum OwnedFd {
/// An owned handle ([`OwnedHandle`]); files and child-process pipes are
/// converted into this variant by the `From` impls in this file.
File(OwnedHandle),
/// An owned socket ([`OwnedSocket`]).
Socket(OwnedSocket),
}
impl AsRawFd for OwnedFd {
fn as_raw_fd(&self) -> RawFd {
match self {
Self::File(fd) => fd.as_raw_handle() as _,
Self::Socket(s) => s.as_raw_socket() as _,
}
}
}
impl AsRawFd for RawFd {
    /// A raw descriptor is its own raw value; returns a copy of it.
    fn as_raw_fd(&self) -> RawFd {
        let &raw = self;
        raw
    }
}
impl AsRawFd for std::fs::File {
    /// Returns the file's raw handle converted to [`RawFd`].
    fn as_raw_fd(&self) -> RawFd {
        AsRawHandle::as_raw_handle(self) as RawFd
    }
}
impl AsRawFd for OwnedHandle {
    /// Returns the owned handle's raw value converted to [`RawFd`].
    fn as_raw_fd(&self) -> RawFd {
        AsRawHandle::as_raw_handle(self) as RawFd
    }
}
impl AsRawFd for socket2::Socket {
    /// Returns the socket's raw value converted to [`RawFd`].
    fn as_raw_fd(&self) -> RawFd {
        AsRawSocket::as_raw_socket(self) as RawFd
    }
}
impl AsRawFd for OwnedSocket {
    /// Returns the owned socket's raw value converted to [`RawFd`].
    fn as_raw_fd(&self) -> RawFd {
        AsRawSocket::as_raw_socket(self) as RawFd
    }
}
impl AsRawFd for std::process::ChildStdin {
    /// Returns the raw handle of the child's stdin pipe converted to [`RawFd`].
    fn as_raw_fd(&self) -> RawFd {
        AsRawHandle::as_raw_handle(self) as RawFd
    }
}
impl AsRawFd for std::process::ChildStdout {
    /// Returns the raw handle of the child's stdout pipe converted to [`RawFd`].
    fn as_raw_fd(&self) -> RawFd {
        AsRawHandle::as_raw_handle(self) as RawFd
    }
}
impl AsRawFd for std::process::ChildStderr {
    /// Returns the raw handle of the child's stderr pipe converted to [`RawFd`].
    fn as_raw_fd(&self) -> RawFd {
        AsRawHandle::as_raw_handle(self) as RawFd
    }
}
impl From<OwnedHandle> for OwnedFd {
fn from(value: OwnedHandle) -> Self {
Self::File(value)
}
}
impl From<std::fs::File> for OwnedFd {
fn from(value: std::fs::File) -> Self {
Self::File(OwnedHandle::from(value))
}
}
impl From<std::process::ChildStdin> for OwnedFd {
fn from(value: std::process::ChildStdin) -> Self {
Self::File(OwnedHandle::from(value))
}
}
impl From<std::process::ChildStdout> for OwnedFd {
fn from(value: std::process::ChildStdout) -> Self {
Self::File(OwnedHandle::from(value))
}
}
impl From<std::process::ChildStderr> for OwnedFd {
fn from(value: std::process::ChildStderr) -> Self {
Self::File(OwnedHandle::from(value))
}
}
impl From<OwnedSocket> for OwnedFd {
fn from(value: OwnedSocket) -> Self {
Self::Socket(value)
}
}
impl From<socket2::Socket> for OwnedFd {
fn from(value: socket2::Socket) -> Self {
Self::Socket(OwnedSocket::from(value))
}
}
/// How the driver must execute an operation (see `Driver::push`).
pub enum OpType {
/// The driver calls `OpCode::operate` directly with the operation's
/// `OVERLAPPED` pointer and returns its result.
Overlapped,
/// The driver dispatches the operation to its thread pool and posts the
/// result back to the completion port.
Blocking,
/// The driver registers a wait on the given descriptor; once that wait
/// completes without error or cancellation, the operation is run.
Event(RawFd),
}
/// An operation that can be submitted to the IOCP driver.
pub trait OpCode {
    /// Tells the driver how this operation has to be executed.
    ///
    /// The default is [`OpType::Overlapped`].
    fn op_type(&self) -> OpType {
        OpType::Overlapped
    }

    /// Performs the operation, given the `OVERLAPPED` pointer associated with
    /// it, and reports whether it is ready or still pending.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not spelled out in this file;
    /// presumably `optr` must point to the live overlapped header of this
    /// very operation, and the operation must stay pinned until completion —
    /// confirm against the callers.
    unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;

    /// Requests cancellation of the operation.
    ///
    /// The default implementation ignores `optr`, does nothing and returns
    /// `Ok(())`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably the same pointer requirements as
    /// [`OpCode::operate`] apply — confirm.
    unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
        // The pointer is only relevant to implementors that override this.
        let _ = optr;
        Ok(())
    }
}
/// Low-level driver built around a completion port.
pub(crate) struct Driver {
/// The completion port that operations are attached, posted and polled on.
port: cp::Port,
/// Pending waits for `OpType::Event` operations, keyed by the operation's
/// user data.
waits: HashMap<usize, wait::Wait>,
/// Thread pool that `OpType::Blocking` operations are dispatched to.
pool: AsyncifyPool,
/// Overlapped value posted by `NotifyHandle::notify`; its address is the
/// user data that `poll` filters out as a pure wake-up.
notify_overlapped: Arc<Overlapped>,
}
/// Driver operations: construction, attaching descriptors, and submitting,
/// cancelling and polling operations on the completion port.
impl Driver {
/// Creates a driver: opens a new completion port, obtains the thread pool
/// from `builder`, and allocates the `Overlapped` used for notifications,
/// tagged with the port's raw handle.
///
/// # Errors
///
/// Propagates the error returned by `cp::Port::new`.
pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
instrument!(compio_log::Level::TRACE, "new", ?builder);
let port = cp::Port::new()?;
// Raw handle of the port, stored in the notification overlapped below.
let driver = port.as_raw_handle() as _;
Ok(Self {
port,
waits: HashMap::default(),
pool: builder.create_or_get_thread_pool(),
notify_overlapped: Arc::new(Overlapped::new(driver)),
})
}
/// Wraps `op` in a `Key`, passing along the completion port's raw handle.
pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
Key::new(self.port.as_raw_handle() as _, op)
}
/// Attaches `fd` to the completion port; delegates to `cp::Port::attach`
/// and returns its result.
pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
self.port.attach(fd)
}
/// Cancels `op` on a best-effort basis; every failure is ignored.
///
/// If a wait is registered for the operation and cancelling that wait
/// succeeds, the operation's overlapped pointer is posted to the port so
/// that a later `poll` picks it up (where a cancelled wait is reported as
/// `ERROR_CANCELLED`). `OpCode::cancel` is then called in every case.
pub fn cancel(&mut self, op: &mut Key<dyn OpCode>) {
instrument!(compio_log::Level::TRACE, "cancel", ?op);
trace!("cancel RawOp");
// Taken before `as_op_pin` so the pointer is available to both steps below.
let overlapped_ptr = op.as_mut_ptr();
if let Some(w) = self.waits.get_mut(&op.user_data()) {
if w.cancel().is_ok() {
// Post failures are deliberately ignored.
self.port.post_raw(overlapped_ptr).ok();
}
}
let op = op.as_op_pin();
trace!("call OpCode::cancel");
// NOTE(review): relies on `overlapped_ptr` being the overlapped header of
// `op` — confirm against `Key::as_mut_ptr`.
unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
}
/// Submits `op` according to its `OpType`.
///
/// * `Overlapped`: calls `OpCode::operate` with the operation's overlapped
///   pointer and returns its result directly.
/// * `Blocking`: tries to dispatch the operation to the thread pool; while
///   the pool rejects it, polls the port with no timeout and retries.
///   Returns `Poll::Pending` once dispatched.
/// * `Event(e)`: registers a `wait::Wait` for `e` under the operation's
///   user data and returns `Poll::Pending`.
///
/// # Errors
///
/// In the `Blocking` and `Event` cases, an error from `poll` or from
/// `wait::Wait::new` is returned as `Poll::Ready(Err(..))`.
pub fn push(&mut self, op: &mut Key<dyn OpCode>) -> Poll<io::Result<usize>> {
instrument!(compio_log::Level::TRACE, "push", ?op);
let user_data = op.user_data();
trace!("push RawOp");
let optr = op.as_mut_ptr();
let op_pin = op.as_op_pin();
match op_pin.op_type() {
// NOTE(review): relies on `optr` being the overlapped header of this
// operation — confirm against `Key::as_mut_ptr`.
OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
OpType::Blocking => loop {
if self.push_blocking(user_data)? {
break Poll::Pending;
} else {
// The pool did not accept the task: process completions from the port
// (no timeout) before trying again.
unsafe {
self.poll(None)?;
}
}
},
OpType::Event(e) => {
self.waits
.insert(user_data, wait::Wait::new(&self.port, e, op)?);
Poll::Pending
}
}
}
/// Dispatches the blocking operation identified by `user_data` to the
/// thread pool.
///
/// Returns `Ok(true)` if the pool accepted the task and `Ok(false)` if it
/// did not; as written, this function never returns `Err`.
fn push_blocking(&mut self, user_data: usize) -> io::Result<bool> {
let port = self.port.handle();
Ok(self
.pool
.dispatch(move || {
// NOTE(review): presumably sound only while `user_data` still identifies
// a live operation — confirm against `Key::new_unchecked`'s contract.
let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
let optr = op.as_mut_ptr();
let res = op.operate_blocking();
// Report the result through the port; a post failure is ignored.
port.post(res, optr).ok();
})
.is_ok())
}
/// Turns an entry polled from the port into the entry to report, if any.
///
/// * An entry whose user data equals `notify_user_data` is a wake-up
///   notification and yields `None`.
/// * If a wait is registered for the entry's user data, the wait is
///   removed and:
///   - a cancelled wait yields an `ERROR_CANCELLED` error entry;
///   - an entry that already carries an error is passed through;
///   - otherwise the operation is run via `operate_blocking` and its
///     result is reported.
/// * Any other entry is passed through unchanged.
fn create_entry(
notify_user_data: usize,
waits: &mut HashMap<usize, wait::Wait>,
entry: Entry,
) -> Option<Entry> {
let user_data = entry.user_data();
if user_data != notify_user_data {
if let Some(w) = waits.remove(&user_data) {
if w.is_cancelled() {
Some(Entry::new(
user_data,
Err(io::Error::from_raw_os_error(ERROR_CANCELLED as _)),
))
} else if entry.result.is_err() {
Some(entry)
} else {
// NOTE(review): presumably sound only while `user_data` still identifies
// a live operation — confirm against `Key::new_unchecked`'s contract.
let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
let result = op.operate_blocking();
Some(Entry::new(user_data, result))
}
} else {
Some(entry)
}
} else {
None
}
}
/// Polls the completion port with the given `timeout` and notifies every
/// resulting entry, after filtering and translating it with
/// `create_entry`.
///
/// # Errors
///
/// Propagates the error returned by `cp::Port::poll`.
///
/// # Safety
///
/// NOTE(review): the contract is not stated in this file. Processing an
/// entry may rebuild a `Key` from raw user data, so presumably every
/// operation with a pending completion must still be alive — confirm
/// against the callers.
pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
instrument!(compio_log::Level::TRACE, "poll", ?timeout);
// The address of the notification overlapped identifies wake-up entries.
let notify_user_data = self.notify_overlapped.as_ref() as *const Overlapped as usize;
for e in self.port.poll(timeout)? {
if let Some(e) = Self::create_entry(notify_user_data, &mut self.waits, e) {
e.notify();
}
}
Ok(())
}
/// Creates a `NotifyHandle` sharing the port handle and the notification
/// overlapped of this driver. As written, this always returns `Ok`.
pub fn handle(&self) -> io::Result<NotifyHandle> {
Ok(NotifyHandle::new(
self.port.handle(),
self.notify_overlapped.clone(),
))
}
}
impl AsRawFd for Driver {
    /// Exposes the raw handle of the driver's completion port as a [`RawFd`].
    fn as_raw_fd(&self) -> RawFd {
        let raw = self.port.as_raw_handle();
        raw as RawFd
    }
}
/// A handle that posts a notification to a driver's completion port.
pub struct NotifyHandle {
/// Handle of the completion port to post to.
port: cp::PortHandle,
/// The overlapped value that is posted; shared with the driver, which
/// recognises it by address when polling.
overlapped: Arc<Overlapped>,
}
impl NotifyHandle {
fn new(port: cp::PortHandle, overlapped: Arc<Overlapped>) -> Self {
Self { port, overlapped }
}
pub fn notify(&self) -> io::Result<()> {
self.port.post_raw(self.overlapped.as_ref())
}
}
/// An `OVERLAPPED` header extended with the raw handle of the owning driver.
///
/// The struct is `#[repr(C)]` with `base` as its first field, so a pointer to
/// an `Overlapped` has the same address as a pointer to its `OVERLAPPED`.
#[repr(C)]
pub struct Overlapped {
/// The system `OVERLAPPED` structure; must remain the first field.
pub base: OVERLAPPED,
/// Raw handle of the driver (its completion port) this value belongs to.
pub driver: RawFd,
}
impl Overlapped {
pub(crate) fn new(driver: RawFd) -> Self {
Self {
base: unsafe { std::mem::zeroed() },
driver,
}
}
}
// These impls allow an `Arc<Overlapped>` to be shared through `NotifyHandle`.
// SAFETY: NOTE(review) — within this file the shared value is never mutated
// after construction and only its address is used (posted to the port and
// compared while polling). Confirm that nothing else, including the OS,
// writes to it concurrently.
unsafe impl Send for Overlapped {}
unsafe impl Sync for Overlapped {}