// compio_driver/iour/mod.rs

1#[cfg_attr(all(doc, docsrs), doc(cfg(all())))]
2#[allow(unused_imports)]
3pub use std::os::fd::{AsRawFd, OwnedFd, RawFd};
4use std::{io, os::fd::FromRawFd, pin::Pin, sync::Arc, task::Poll, time::Duration};
5
6use compio_log::{instrument, trace, warn};
7use crossbeam_queue::SegQueue;
8cfg_if::cfg_if! {
9    if #[cfg(feature = "io-uring-cqe32")] {
10        use io_uring::cqueue::Entry32 as CEntry;
11    } else {
12        use io_uring::cqueue::Entry as CEntry;
13    }
14}
15cfg_if::cfg_if! {
16    if #[cfg(feature = "io-uring-sqe128")] {
17        use io_uring::squeue::Entry128 as SEntry;
18    } else {
19        use io_uring::squeue::Entry as SEntry;
20    }
21}
22use io_uring::{
23    IoUring,
24    cqueue::more,
25    opcode::{AsyncCancel, PollAdd},
26    types::{Fd, SubmitArgs, Timespec},
27};
28pub(crate) use libc::{sockaddr_storage, socklen_t};
29
30use crate::{AsyncifyPool, Entry, Key, ProactorBuilder, syscall};
31
32pub(crate) mod op;
33
34/// The created entry of [`OpCode`].
35pub enum OpEntry {
36    /// This operation creates an io-uring submission entry.
37    Submission(io_uring::squeue::Entry),
38    #[cfg(feature = "io-uring-sqe128")]
39    /// This operation creates an 128-bit io-uring submission entry.
40    Submission128(io_uring::squeue::Entry128),
41    /// This operation is a blocking one.
42    Blocking,
43}
44
45impl From<io_uring::squeue::Entry> for OpEntry {
46    fn from(value: io_uring::squeue::Entry) -> Self {
47        Self::Submission(value)
48    }
49}
50
#[cfg(feature = "io-uring-sqe128")]
impl From<io_uring::squeue::Entry128> for OpEntry {
    /// Wraps a 128-byte submission entry as [`OpEntry::Submission128`].
    fn from(sqe: io_uring::squeue::Entry128) -> Self {
        OpEntry::Submission128(sqe)
    }
}
57
58/// Abstraction of io-uring operations.
59pub trait OpCode {
60    /// Create submission entry.
61    fn create_entry(self: Pin<&mut Self>) -> OpEntry;
62
63    /// Call the operation in a blocking way. This method will only be called if
64    /// [`create_entry`] returns [`OpEntry::Blocking`].
65    fn call_blocking(self: Pin<&mut Self>) -> io::Result<usize> {
66        unreachable!("this operation is asynchronous")
67    }
68
69    /// Set the result when it successfully completes.
70    /// The operation stores the result and is responsible to release it if the
71    /// operation is cancelled.
72    ///
73    /// # Safety
74    ///
75    /// Users should not call it.
76    unsafe fn set_result(self: Pin<&mut Self>, _: usize) {}
77}
78
79/// Low-level driver of io-uring.
80pub(crate) struct Driver {
81    inner: IoUring<SEntry, CEntry>,
82    notifier: Notifier,
83    pool: AsyncifyPool,
84    pool_completed: Arc<SegQueue<Entry>>,
85}
86
87impl Driver {
88    const CANCEL: u64 = u64::MAX;
89    const NOTIFY: u64 = u64::MAX - 1;
90
91    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
92        instrument!(compio_log::Level::TRACE, "new", ?builder);
93        trace!("new iour driver");
94        let notifier = Notifier::new()?;
95        let mut io_uring_builder = IoUring::builder();
96        if let Some(sqpoll_idle) = builder.sqpoll_idle {
97            io_uring_builder.setup_sqpoll(sqpoll_idle.as_millis() as _);
98        }
99        let mut inner = io_uring_builder.build(builder.capacity)?;
100        #[allow(clippy::useless_conversion)]
101        unsafe {
102            inner
103                .submission()
104                .push(
105                    &PollAdd::new(Fd(notifier.as_raw_fd()), libc::POLLIN as _)
106                        .multi(true)
107                        .build()
108                        .user_data(Self::NOTIFY)
109                        .into(),
110                )
111                .expect("the squeue sould not be full");
112        }
113        Ok(Self {
114            inner,
115            notifier,
116            pool: builder.create_or_get_thread_pool(),
117            pool_completed: Arc::new(SegQueue::new()),
118        })
119    }
120
121    // Auto means that it choose to wait or not automatically.
122    fn submit_auto(&mut self, timeout: Option<Duration>) -> io::Result<()> {
123        instrument!(compio_log::Level::TRACE, "submit_auto", ?timeout);
124        let res = {
125            // Last part of submission queue, wait till timeout.
126            if let Some(duration) = timeout {
127                let timespec = timespec(duration);
128                let args = SubmitArgs::new().timespec(&timespec);
129                self.inner.submitter().submit_with_args(1, &args)
130            } else {
131                self.inner.submit_and_wait(1)
132            }
133        };
134        trace!("submit result: {res:?}");
135        match res {
136            Ok(_) => {
137                if self.inner.completion().is_empty() {
138                    Err(io::ErrorKind::TimedOut.into())
139                } else {
140                    Ok(())
141                }
142            }
143            Err(e) => match e.raw_os_error() {
144                Some(libc::ETIME) => Err(io::ErrorKind::TimedOut.into()),
145                Some(libc::EBUSY) | Some(libc::EAGAIN) => Err(io::ErrorKind::Interrupted.into()),
146                _ => Err(e),
147            },
148        }
149    }
150
151    fn poll_blocking(&mut self) {
152        // Cheaper than pop.
153        if !self.pool_completed.is_empty() {
154            while let Some(entry) = self.pool_completed.pop() {
155                unsafe {
156                    entry.notify();
157                }
158            }
159        }
160    }
161
162    fn poll_entries(&mut self) -> bool {
163        self.poll_blocking();
164
165        let mut cqueue = self.inner.completion();
166        cqueue.sync();
167        let has_entry = !cqueue.is_empty();
168        for entry in cqueue {
169            match entry.user_data() {
170                Self::CANCEL => {}
171                Self::NOTIFY => {
172                    let flags = entry.flags();
173                    debug_assert!(more(flags));
174                    self.notifier.clear().expect("cannot clear notifier");
175                }
176                _ => unsafe {
177                    create_entry(entry).notify();
178                },
179            }
180        }
181        has_entry
182    }
183
184    pub fn create_op<T: crate::sys::OpCode + 'static>(&self, op: T) -> Key<T> {
185        Key::new(self.as_raw_fd(), op)
186    }
187
188    pub fn attach(&mut self, _fd: RawFd) -> io::Result<()> {
189        Ok(())
190    }
191
192    pub fn cancel(&mut self, op: &mut Key<dyn crate::sys::OpCode>) {
193        instrument!(compio_log::Level::TRACE, "cancel", ?op);
194        trace!("cancel RawOp");
195        unsafe {
196            #[allow(clippy::useless_conversion)]
197            if self
198                .inner
199                .submission()
200                .push(
201                    &AsyncCancel::new(op.user_data() as _)
202                        .build()
203                        .user_data(Self::CANCEL)
204                        .into(),
205                )
206                .is_err()
207            {
208                warn!("could not push AsyncCancel entry");
209            }
210        }
211    }
212
213    fn push_raw(&mut self, entry: SEntry) -> io::Result<()> {
214        loop {
215            let mut squeue = self.inner.submission();
216            match unsafe { squeue.push(&entry) } {
217                Ok(()) => {
218                    squeue.sync();
219                    break Ok(());
220                }
221                Err(_) => {
222                    drop(squeue);
223                    self.poll_entries();
224                    match self.submit_auto(Some(Duration::ZERO)) {
225                        Ok(()) => {}
226                        Err(e)
227                            if matches!(
228                                e.kind(),
229                                io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
230                            ) => {}
231                        Err(e) => return Err(e),
232                    }
233                }
234            }
235        }
236    }
237
238    pub fn push(&mut self, op: &mut Key<dyn crate::sys::OpCode>) -> Poll<io::Result<usize>> {
239        instrument!(compio_log::Level::TRACE, "push", ?op);
240        let user_data = op.user_data();
241        let op_pin = op.as_op_pin();
242        trace!("push RawOp");
243        match op_pin.create_entry() {
244            OpEntry::Submission(entry) => {
245                #[allow(clippy::useless_conversion)]
246                self.push_raw(entry.user_data(user_data as _).into())?;
247                Poll::Pending
248            }
249            #[cfg(feature = "io-uring-sqe128")]
250            OpEntry::Submission128(entry) => {
251                self.push_raw(entry.user_data(user_data as _))?;
252                Poll::Pending
253            }
254            OpEntry::Blocking => loop {
255                if self.push_blocking(user_data) {
256                    break Poll::Pending;
257                } else {
258                    self.poll_blocking();
259                }
260            },
261        }
262    }
263
264    fn push_blocking(&mut self, user_data: usize) -> bool {
265        let handle = self.handle();
266        let completed = self.pool_completed.clone();
267        self.pool
268            .dispatch(move || {
269                let mut op = unsafe { Key::<dyn crate::sys::OpCode>::new_unchecked(user_data) };
270                let op_pin = op.as_op_pin();
271                let res = op_pin.call_blocking();
272                completed.push(Entry::new(user_data, res));
273                handle.notify().ok();
274            })
275            .is_ok()
276    }
277
278    pub unsafe fn poll(&mut self, timeout: Option<Duration>) -> io::Result<()> {
279        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
280        // Anyway we need to submit once, no matter there are entries in squeue.
281        trace!("start polling");
282
283        if !self.poll_entries() {
284            self.submit_auto(timeout)?;
285            self.poll_entries();
286        }
287
288        Ok(())
289    }
290
291    pub fn handle(&self) -> NotifyHandle {
292        self.notifier.handle()
293    }
294}
295
296impl AsRawFd for Driver {
297    fn as_raw_fd(&self) -> RawFd {
298        self.inner.as_raw_fd()
299    }
300}
301
302fn create_entry(cq_entry: CEntry) -> Entry {
303    let result = cq_entry.result();
304    let result = if result < 0 {
305        let result = if result == -libc::ECANCELED {
306            libc::ETIMEDOUT
307        } else {
308            -result
309        };
310        Err(io::Error::from_raw_os_error(result))
311    } else {
312        Ok(result as _)
313    };
314    let mut entry = Entry::new(cq_entry.user_data() as _, result);
315    entry.set_flags(cq_entry.flags());
316
317    entry
318}
319
320fn timespec(duration: std::time::Duration) -> Timespec {
321    Timespec::new()
322        .sec(duration.as_secs())
323        .nsec(duration.subsec_nanos())
324}
325
/// Eventfd-based wake-up source for the driver.
#[derive(Debug)]
struct Notifier {
    /// The eventfd, shared with every [`NotifyHandle`] created from it.
    fd: Arc<OwnedFd>,
}
330
331impl Notifier {
332    /// Create a new notifier.
333    fn new() -> io::Result<Self> {
334        let fd = syscall!(libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK))?;
335        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
336        Ok(Self { fd: Arc::new(fd) })
337    }
338
339    pub fn clear(&self) -> io::Result<()> {
340        loop {
341            let mut buffer = [0u64];
342            let res = syscall!(libc::read(
343                self.fd.as_raw_fd(),
344                buffer.as_mut_ptr().cast(),
345                std::mem::size_of::<u64>()
346            ));
347            match res {
348                Ok(len) => {
349                    debug_assert_eq!(len, std::mem::size_of::<u64>() as _);
350                    break Ok(());
351                }
352                // Clear the next time:)
353                Err(e) if e.kind() == io::ErrorKind::WouldBlock => break Ok(()),
354                // Just like read_exact
355                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
356                Err(e) => break Err(e),
357            }
358        }
359    }
360
361    pub fn handle(&self) -> NotifyHandle {
362        NotifyHandle::new(self.fd.clone())
363    }
364}
365
366impl AsRawFd for Notifier {
367    fn as_raw_fd(&self) -> RawFd {
368        self.fd.as_raw_fd()
369    }
370}
371
/// A notify handle to the inner driver.
///
/// Cheap to create; it shares the driver's eventfd.
pub struct NotifyHandle {
    /// Eventfd shared with the driver's notifier.
    fd: Arc<OwnedFd>,
}
376
377impl NotifyHandle {
378    pub(crate) fn new(fd: Arc<OwnedFd>) -> Self {
379        Self { fd }
380    }
381
382    /// Notify the inner driver.
383    pub fn notify(&self) -> io::Result<()> {
384        let data = 1u64;
385        syscall!(libc::write(
386            self.fd.as_raw_fd(),
387            &data as *const _ as *const _,
388            std::mem::size_of::<u64>(),
389        ))?;
390        Ok(())
391    }
392}