compio_driver/iocp/
mod.rs1use std::{
2 collections::HashMap,
3 io,
4 os::windows::{
5 io::{OwnedHandle, OwnedSocket},
6 prelude::{AsRawHandle, AsRawSocket},
7 },
8 pin::Pin,
9 sync::Arc,
10 task::Poll,
11 time::Duration,
12};
13
14use compio_log::{instrument, trace};
15use windows_sys::Win32::{Foundation::ERROR_CANCELLED, System::IO::OVERLAPPED};
16
17use crate::{AsyncifyPool, Entry, Key, ProactorBuilder};
18
19pub(crate) mod op;
20
21mod cp;
22mod wait;
23
24pub(crate) use windows_sys::Win32::Networking::WinSock::{
25 SOCKADDR_STORAGE as sockaddr_storage, socklen_t,
26};
27
/// Raw Windows I/O object (a `HANDLE` or a `SOCKET`) stored as an `isize`.
pub type RawFd = isize;

/// Types that can expose their underlying Windows handle or socket as a
/// [`RawFd`].
pub trait AsRawFd {
    /// Returns the raw handle or socket without transferring ownership.
    fn as_raw_fd(&self) -> RawFd;
}
38
39#[derive(Debug)]
41pub enum OwnedFd {
42 File(OwnedHandle),
44 Socket(OwnedSocket),
46}
47
48impl AsRawFd for OwnedFd {
49 fn as_raw_fd(&self) -> RawFd {
50 match self {
51 Self::File(fd) => fd.as_raw_handle() as _,
52 Self::Socket(s) => s.as_raw_socket() as _,
53 }
54 }
55}
56
57impl AsRawFd for RawFd {
58 fn as_raw_fd(&self) -> RawFd {
59 *self
60 }
61}
62
63impl AsRawFd for std::fs::File {
64 fn as_raw_fd(&self) -> RawFd {
65 self.as_raw_handle() as _
66 }
67}
68
69impl AsRawFd for OwnedHandle {
70 fn as_raw_fd(&self) -> RawFd {
71 self.as_raw_handle() as _
72 }
73}
74
75impl AsRawFd for socket2::Socket {
76 fn as_raw_fd(&self) -> RawFd {
77 self.as_raw_socket() as _
78 }
79}
80
81impl AsRawFd for OwnedSocket {
82 fn as_raw_fd(&self) -> RawFd {
83 self.as_raw_socket() as _
84 }
85}
86
87impl AsRawFd for std::process::ChildStdin {
88 fn as_raw_fd(&self) -> RawFd {
89 self.as_raw_handle() as _
90 }
91}
92
93impl AsRawFd for std::process::ChildStdout {
94 fn as_raw_fd(&self) -> RawFd {
95 self.as_raw_handle() as _
96 }
97}
98
99impl AsRawFd for std::process::ChildStderr {
100 fn as_raw_fd(&self) -> RawFd {
101 self.as_raw_handle() as _
102 }
103}
104
105impl From<OwnedHandle> for OwnedFd {
106 fn from(value: OwnedHandle) -> Self {
107 Self::File(value)
108 }
109}
110
111impl From<std::fs::File> for OwnedFd {
112 fn from(value: std::fs::File) -> Self {
113 Self::File(OwnedHandle::from(value))
114 }
115}
116
117impl From<std::process::ChildStdin> for OwnedFd {
118 fn from(value: std::process::ChildStdin) -> Self {
119 Self::File(OwnedHandle::from(value))
120 }
121}
122
123impl From<std::process::ChildStdout> for OwnedFd {
124 fn from(value: std::process::ChildStdout) -> Self {
125 Self::File(OwnedHandle::from(value))
126 }
127}
128
129impl From<std::process::ChildStderr> for OwnedFd {
130 fn from(value: std::process::ChildStderr) -> Self {
131 Self::File(OwnedHandle::from(value))
132 }
133}
134
135impl From<OwnedSocket> for OwnedFd {
136 fn from(value: OwnedSocket) -> Self {
137 Self::Socket(value)
138 }
139}
140
141impl From<socket2::Socket> for OwnedFd {
142 fn from(value: socket2::Socket) -> Self {
143 Self::Socket(OwnedSocket::from(value))
144 }
145}
146
147pub enum OpType {
149 Overlapped,
151 Blocking,
154 Event(RawFd),
158}
159
160pub trait OpCode {
162 fn op_type(&self) -> OpType {
165 OpType::Overlapped
166 }
167
168 unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;
182
183 unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
191 let _optr = optr; Ok(())
193 }
194}
195
196pub(crate) struct Driver {
198 port: cp::Port,
199 waits: HashMap<usize, wait::Wait>,
200 pool: AsyncifyPool,
201 notify_overlapped: Arc<Overlapped>,
202}
203
204impl Driver {
205 pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
206 instrument!(compio_log::Level::TRACE, "new", ?builder);
207
208 let port = cp::Port::new()?;
209 let driver = port.as_raw_handle() as _;
210 Ok(Self {
211 port,
212 waits: HashMap::default(),
213 pool: builder.create_or_get_thread_pool(),
214 notify_overlapped: Arc::new(Overlapped::new(driver)),
215 })
216 }
217
218 pub fn create_op<T: OpCode + 'static>(&self, op: T) -> Key<T> {
219 Key::new(self.port.as_raw_handle() as _, op)
220 }
221
222 pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
223 self.port.attach(fd)
224 }
225
226 pub fn cancel(&mut self, op: &mut Key<dyn OpCode>) {
227 instrument!(compio_log::Level::TRACE, "cancel", ?op);
228 trace!("cancel RawOp");
229 let overlapped_ptr = op.as_mut_ptr();
230 if let Some(w) = self.waits.get_mut(&op.user_data()) {
231 if w.cancel().is_ok() {
232 self.port.post_raw(overlapped_ptr).ok();
235 }
236 }
237 let op = op.as_op_pin();
238 trace!("call OpCode::cancel");
240 unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
241 }
242
243 pub fn push(&mut self, op: &mut Key<dyn OpCode>) -> Poll<io::Result<usize>> {
244 instrument!(compio_log::Level::TRACE, "push", ?op);
245 let user_data = op.user_data();
246 trace!("push RawOp");
247 let optr = op.as_mut_ptr();
248 let op_pin = op.as_op_pin();
249 match op_pin.op_type() {
250 OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
251 OpType::Blocking => loop {
252 if self.push_blocking(user_data) {
253 break Poll::Pending;
254 } else {
255 unsafe {
258 self.poll(None)?;
259 }
260 }
261 },
262 OpType::Event(e) => {
263 self.waits
264 .insert(user_data, wait::Wait::new(&self.port, e, op)?);
265 Poll::Pending
266 }
267 }
268 }
269
270 fn push_blocking(&mut self, user_data: usize) -> bool {
271 let port = self.port.handle();
272 self.pool
273 .dispatch(move || {
274 let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
275 let optr = op.as_mut_ptr();
276 let res = op.operate_blocking();
277 port.post(res, optr).ok();
278 })
279 .is_ok()
280 }
281
282 fn create_entry(
283 notify_user_data: usize,
284 waits: &mut HashMap<usize, wait::Wait>,
285 entry: Entry,
286 ) -> Option<Entry> {
287 let user_data = entry.user_data();
288 if user_data != notify_user_data {
289 if let Some(w) = waits.remove(&user_data) {
290 if w.is_cancelled() {
291 Some(Entry::new(
292 user_data,
293 Err(io::Error::from_raw_os_error(ERROR_CANCELLED as _)),
294 ))
295 } else if entry.result.is_err() {
296 Some(entry)
297 } else {
298 let mut op = unsafe { Key::<dyn OpCode>::new_unchecked(user_data) };
299 let result = op.operate_blocking();
300 Some(Entry::new(user_data, result))
301 }
302 } else {
303 Some(entry)
304 }
305 } else {
306 None
307 }
308 }
309
310 pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
311 instrument!(compio_log::Level::TRACE, "poll", ?timeout);
312
313 let notify_user_data = self.notify_overlapped.as_ref() as *const Overlapped as usize;
314
315 for e in self.port.poll(timeout)? {
316 if let Some(e) = Self::create_entry(notify_user_data, &mut self.waits, e) {
317 e.notify();
318 }
319 }
320
321 Ok(())
322 }
323
324 pub fn handle(&self) -> NotifyHandle {
325 NotifyHandle::new(self.port.handle(), self.notify_overlapped.clone())
326 }
327}
328
329impl AsRawFd for Driver {
330 fn as_raw_fd(&self) -> RawFd {
331 self.port.as_raw_handle() as _
332 }
333}
334
335pub struct NotifyHandle {
337 port: cp::PortHandle,
338 overlapped: Arc<Overlapped>,
339}
340
341impl NotifyHandle {
342 fn new(port: cp::PortHandle, overlapped: Arc<Overlapped>) -> Self {
343 Self { port, overlapped }
344 }
345
346 pub fn notify(&self) -> io::Result<()> {
348 self.port.post_raw(self.overlapped.as_ref())
349 }
350}
351
352#[repr(C)]
354pub struct Overlapped {
355 pub base: OVERLAPPED,
357 pub driver: RawFd,
359}
360
361impl Overlapped {
362 pub(crate) fn new(driver: RawFd) -> Self {
363 Self {
364 base: unsafe { std::mem::zeroed() },
365 driver,
366 }
367 }
368}
369
370unsafe impl Send for Overlapped {}
372unsafe impl Sync for Overlapped {}