compio_driver/iocp/
mod.rs

1use std::{
2    collections::HashMap,
3    io,
4    os::windows::{
5        io::{OwnedHandle, OwnedSocket},
6        prelude::{AsRawHandle, AsRawSocket},
7    },
8    pin::Pin,
9    sync::Arc,
10    task::Poll,
11    time::Duration,
12};
13
14use compio_log::{instrument, trace};
15use windows_sys::Win32::{Foundation::ERROR_CANCELLED, System::IO::OVERLAPPED};
16
17use crate::{AsyncifyPool, Entry, Key, ProactorBuilder};
18
19pub(crate) mod op;
20
21mod cp;
22mod wait;
23
24pub(crate) use windows_sys::Win32::Networking::WinSock::{
25    SOCKADDR_STORAGE as sockaddr_storage, socklen_t,
26};
27
/// Raw OS resource identifier used throughout this driver.
///
/// On Windows a handle and a socket have the same size, and either kind can
/// be associated with an IOCP, so the driver treats both uniformly as an "fd".
pub type RawFd = isize;

/// Types that can expose their underlying [`RawFd`].
pub trait AsRawFd {
    /// Returns the raw fd, borrowing `self` (ownership is not transferred).
    fn as_raw_fd(&self) -> RawFd;
}
38
39/// Owned handle or socket on Windows.
40#[derive(Debug)]
41pub enum OwnedFd {
42    /// Win32 handle.
43    File(OwnedHandle),
44    /// Windows socket handle.
45    Socket(OwnedSocket),
46}
47
48impl AsRawFd for OwnedFd {
49    fn as_raw_fd(&self) -> RawFd {
50        match self {
51            Self::File(fd) => fd.as_raw_handle() as _,
52            Self::Socket(s) => s.as_raw_socket() as _,
53        }
54    }
55}
56
57impl AsRawFd for RawFd {
58    fn as_raw_fd(&self) -> RawFd {
59        *self
60    }
61}
62
63impl AsRawFd for std::fs::File {
64    fn as_raw_fd(&self) -> RawFd {
65        self.as_raw_handle() as _
66    }
67}
68
69impl AsRawFd for OwnedHandle {
70    fn as_raw_fd(&self) -> RawFd {
71        self.as_raw_handle() as _
72    }
73}
74
75impl AsRawFd for socket2::Socket {
76    fn as_raw_fd(&self) -> RawFd {
77        self.as_raw_socket() as _
78    }
79}
80
81impl AsRawFd for OwnedSocket {
82    fn as_raw_fd(&self) -> RawFd {
83        self.as_raw_socket() as _
84    }
85}
86
87impl AsRawFd for std::process::ChildStdin {
88    fn as_raw_fd(&self) -> RawFd {
89        self.as_raw_handle() as _
90    }
91}
92
93impl AsRawFd for std::process::ChildStdout {
94    fn as_raw_fd(&self) -> RawFd {
95        self.as_raw_handle() as _
96    }
97}
98
99impl AsRawFd for std::process::ChildStderr {
100    fn as_raw_fd(&self) -> RawFd {
101        self.as_raw_handle() as _
102    }
103}
104
105impl From<OwnedHandle> for OwnedFd {
106    fn from(value: OwnedHandle) -> Self {
107        Self::File(value)
108    }
109}
110
111impl From<std::fs::File> for OwnedFd {
112    fn from(value: std::fs::File) -> Self {
113        Self::File(OwnedHandle::from(value))
114    }
115}
116
117impl From<std::process::ChildStdin> for OwnedFd {
118    fn from(value: std::process::ChildStdin) -> Self {
119        Self::File(OwnedHandle::from(value))
120    }
121}
122
123impl From<std::process::ChildStdout> for OwnedFd {
124    fn from(value: std::process::ChildStdout) -> Self {
125        Self::File(OwnedHandle::from(value))
126    }
127}
128
129impl From<std::process::ChildStderr> for OwnedFd {
130    fn from(value: std::process::ChildStderr) -> Self {
131        Self::File(OwnedHandle::from(value))
132    }
133}
134
135impl From<OwnedSocket> for OwnedFd {
136    fn from(value: OwnedSocket) -> Self {
137        Self::Socket(value)
138    }
139}
140
141impl From<socket2::Socket> for OwnedFd {
142    fn from(value: socket2::Socket) -> Self {
143        Self::Socket(OwnedSocket::from(value))
144    }
145}
146
147/// Operation type.
148pub enum OpType {
149    /// An overlapped operation.
150    Overlapped,
151    /// A blocking operation, needs a thread to spawn. The `operate` method
152    /// should be thread safe.
153    Blocking,
154    /// A Win32 event object to be waited. The user should ensure that the
155    /// handle is valid till operation completes. The `operate` method should be
156    /// thread safe.
157    Event(RawFd),
158}
159
160/// Abstraction of IOCP operations.
161pub trait OpCode {
162    /// Determines that the operation is really overlapped defined by Windows
163    /// API. If not, the driver will try to operate it in another thread.
164    fn op_type(&self) -> OpType {
165        OpType::Overlapped
166    }
167
168    /// Perform Windows API call with given pointer to overlapped struct.
169    ///
170    /// It is always safe to cast `optr` to a pointer to
171    /// [`Overlapped<Self>`].
172    ///
173    /// Don't do heavy work here if [`OpCode::op_type`] returns
174    /// [`OpType::Event`].
175    ///
176    /// # Safety
177    ///
178    /// * `self` must be alive until the operation completes.
179    /// * When [`OpCode::op_type`] returns [`OpType::Blocking`], this method is
180    ///   called in another thread.
181    unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;
182
183    /// Cancel the async IO operation.
184    ///
185    /// Usually it calls `CancelIoEx`.
186    ///
187    /// # Safety
188    ///
189    /// * Should not use [`Overlapped::op`].
190    unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
191        let _optr = optr; // ignore it
192        Ok(())
193    }
194}
195
196/// Low-level driver of IOCP.
197pub(crate) struct Driver {
198    port: cp::Port,
199    waits: HashMap<usize, wait::Wait>,
200    pool: AsyncifyPool,
201    notify_overlapped: Arc<Overlapped>,
202}
203
204impl Driver {
205    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
206        instrument!(compio_log::Level::TRACE, "new", ?builder);
207
208        let port = cp::Port::new()?;
209        let driver = port.as_raw_handle() as _;
210        Ok(Self {
211            port,
212            waits: HashMap::default(),
213            pool: builder.create_or_get_thread_pool(),
214            notify_overlapped: Arc::new(Overlapped::new(driver)),
215        })
216    }
217
218    pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
219        Key::new(self.port.as_raw_handle() as _, op)
220    }
221
222    pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
223        self.port.attach(fd)
224    }
225
226    pub fn cancel(&mut self, op: &mut Key<dyn OpCode>) {
227        instrument!(compio_log::Level::TRACE, "cancel", ?op);
228        trace!("cancel RawOp");
229        let overlapped_ptr = op.as_mut_ptr();
230        if let Some(w) = self.waits.get_mut(&op.user_data()) {
231            if w.cancel().is_ok() {
232                // The pack has been cancelled successfully, which means no packet will be post
233                // to IOCP. Need not set the result because `create_entry` handles it.
234                self.port.post_raw(overlapped_ptr).ok();
235            }
236        }
237        let op = op.as_op_pin();
238        // It's OK to fail to cancel.
239        trace!("call OpCode::cancel");
240        unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
241    }
242
243    pub fn push(&mut self, op: &mut Key<dyn OpCode>) -> Poll<io::Result<usize>> {
244        instrument!(compio_log::Level::TRACE, "push", ?op);
245        let user_data = op.user_data();
246        trace!("push RawOp");
247        let optr = op.as_mut_ptr();
248        let op_pin = op.as_op_pin();
249        match op_pin.op_type() {
250            OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
251            OpType::Blocking => loop {
252                if self.push_blocking(user_data) {
253                    break Poll::Pending;
254                } else {
255                    // It's OK to wait forever, because any blocking task will notify the IOCP after
256                    // it completes.
257                    unsafe {
258                        self.poll(None)?;
259                    }
260                }
261            },
262            OpType::Event(e) => {
263                self.waits
264                    .insert(user_data, wait::Wait::new(&self.port, e, op)?);
265                Poll::Pending
266            }
267        }
268    }
269
270    fn push_blocking(&mut self, user_data: usize) -> bool {
271        let port = self.port.handle();
272        self.pool
273            .dispatch(move || {
274                let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
275                let optr = op.as_mut_ptr();
276                let res = op.operate_blocking();
277                port.post(res, optr).ok();
278            })
279            .is_ok()
280    }
281
282    fn create_entry(
283        notify_user_data: usize,
284        waits: &mut HashMap<usize, wait::Wait>,
285        entry: Entry,
286    ) -> Option<Entry> {
287        let user_data = entry.user_data();
288        if user_data != notify_user_data {
289            if let Some(w) = waits.remove(&user_data) {
290                if w.is_cancelled() {
291                    Some(Entry::new(
292                        user_data,
293                        Err(io::Error::from_raw_os_error(ERROR_CANCELLED as _)),
294                    ))
295                } else if entry.result.is_err() {
296                    Some(entry)
297                } else {
298                    let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
299                    let result = op.operate_blocking();
300                    Some(Entry::new(user_data, result))
301                }
302            } else {
303                Some(entry)
304            }
305        } else {
306            None
307        }
308    }
309
310    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
311        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
312
313        let notify_user_data = self.notify_overlapped.as_ref() as *const Overlapped as usize;
314
315        for e in self.port.poll(timeout)? {
316            if let Some(e) = Self::create_entry(notify_user_data, &mut self.waits, e) {
317                e.notify();
318            }
319        }
320
321        Ok(())
322    }
323
324    pub fn handle(&self) -> NotifyHandle {
325        NotifyHandle::new(self.port.handle(), self.notify_overlapped.clone())
326    }
327}
328
329impl AsRawFd for Driver {
330    fn as_raw_fd(&self) -> RawFd {
331        self.port.as_raw_handle() as _
332    }
333}
334
335/// A notify handle to the inner driver.
336pub struct NotifyHandle {
337    port: cp::PortHandle,
338    overlapped: Arc<Overlapped>,
339}
340
341impl NotifyHandle {
342    fn new(port: cp::PortHandle, overlapped: Arc<Overlapped>) -> Self {
343        Self { port, overlapped }
344    }
345
346    /// Notify the inner driver.
347    pub fn notify(&self) -> io::Result<()> {
348        self.port.post_raw(self.overlapped.as_ref())
349    }
350}
351
352/// The overlapped struct we actually used for IOCP.
353#[repr(C)]
354pub struct Overlapped {
355    /// The base [`OVERLAPPED`].
356    pub base: OVERLAPPED,
357    /// The unique ID of created driver.
358    pub driver: RawFd,
359}
360
361impl Overlapped {
362    pub(crate) fn new(driver: RawFd) -> Self {
363        Self {
364            base: unsafe { std::mem::zeroed() },
365            driver,
366        }
367    }
368}
369
370// SAFETY: neither field of `OVERLAPPED` is used
371unsafe impl Send for Overlapped {}
372unsafe impl Sync for Overlapped {}