compio_driver/
fd.rs

1#[cfg(unix)]
2use std::os::fd::FromRawFd;
3#[cfg(windows)]
4use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
5use std::{
6    future::{Future, poll_fn},
7    mem::ManuallyDrop,
8    ops::Deref,
9    panic::RefUnwindSafe,
10    sync::{
11        Arc,
12        atomic::{AtomicBool, Ordering},
13    },
14    task::Poll,
15};
16
17use futures_util::task::AtomicWaker;
18
19use crate::{AsRawFd, RawFd};
20
/// State shared by every clone of a [`SharedFd`], kept behind one `Arc`.
#[derive(Debug)]
struct Inner<T> {
    /// The owned fd value wrapped by the `SharedFd`.
    fd: T,
    /// Whether a `SharedFd::take` future is (or was) waiting for the fd.
    /// Set by the first `take`; nothing in this file resets it to `false`.
    waits: AtomicBool,
    /// Waker of the waiting `take` future. `SharedFd`'s `Drop` wakes it when
    /// a handle is dropped while the strong count is 2 and `waits` is set.
    waker: AtomicWaker,
}
28
// Declares `Inner<T>` ref-unwind-safe for every `T`. Because `Arc<X>` is
// `UnwindSafe` only when `X: RefUnwindSafe`, this is what lets `SharedFd<T>`
// (a wrapper around `Arc<Inner<T>>`) cross `catch_unwind` boundaries.
// NOTE(review): presumably needed because `AtomicWaker` is not
// `RefUnwindSafe` by itself; it also vouches for any interior mutability
// inside `T` — confirm that is intended.
impl<T> RefUnwindSafe for Inner<T> {}
30
/// A shared fd. It is passed to the operations to make sure the fd won't be
/// closed before the operations complete.
///
/// Cloning only bumps an `Arc` reference count; the wrapped value is dropped
/// when the last clone goes away. Ownership can be reclaimed earlier with
/// [`SharedFd::try_unwrap`] (only if this is the last handle) or
/// [`SharedFd::take`] (waits until it is).
#[derive(Debug)]
pub struct SharedFd<T>(Arc<Inner<T>>);
35
36impl<T> SharedFd<T> {
37    /// Create the shared fd from an owned fd.
38    pub fn new(fd: T) -> Self {
39        Self(Arc::new(Inner {
40            fd,
41            waits: AtomicBool::new(false),
42            waker: AtomicWaker::new(),
43        }))
44    }
45
46    /// Try to take the inner owned fd.
47    pub fn try_unwrap(self) -> Result<T, Self> {
48        let this = ManuallyDrop::new(self);
49        if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
50            Ok(fd)
51        } else {
52            Err(ManuallyDrop::into_inner(this))
53        }
54    }
55
56    // SAFETY: if `Some` is returned, the method should not be called again.
57    unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
58        let ptr = ManuallyDrop::new(std::ptr::read(&this.0));
59        // The ptr is duplicated without increasing the strong count, should forget.
60        match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
61            Ok(inner) => Some(inner.fd),
62            Err(ptr) => {
63                std::mem::forget(ptr);
64                None
65            }
66        }
67    }
68
69    /// Wait and take the inner owned fd.
70    pub fn take(self) -> impl Future<Output = Option<T>> {
71        let this = ManuallyDrop::new(self);
72        async move {
73            if !this.0.waits.swap(true, Ordering::AcqRel) {
74                poll_fn(move |cx| {
75                    if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
76                        return Poll::Ready(Some(fd));
77                    }
78
79                    this.0.waker.register(cx.waker());
80
81                    if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
82                        Poll::Ready(Some(fd))
83                    } else {
84                        Poll::Pending
85                    }
86                })
87                .await
88            } else {
89                None
90            }
91        }
92    }
93}
94
95impl<T> Drop for SharedFd<T> {
96    fn drop(&mut self) {
97        // It's OK to wake multiple times.
98        if Arc::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
99            self.0.waker.wake()
100        }
101    }
102}
103
104impl<T: AsRawFd> AsRawFd for SharedFd<T> {
105    fn as_raw_fd(&self) -> RawFd {
106        self.0.fd.as_raw_fd()
107    }
108}
109
#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Builds `T` from the raw handle and wraps it in a new [`SharedFd`].
    ///
    /// # Safety
    ///
    /// The caller must uphold [`FromRawHandle::from_raw_handle`]'s contract
    /// for `handle`; it is forwarded unchanged to `T`.
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: `handle` satisfies `T::from_raw_handle`'s contract by this
        // function's own precondition.
        Self::new(unsafe { T::from_raw_handle(handle) })
    }
}
116
#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Builds `T` from the raw socket and wraps it in a new [`SharedFd`].
    ///
    /// # Safety
    ///
    /// The caller must uphold [`FromRawSocket::from_raw_socket`]'s contract
    /// for `sock`; it is forwarded unchanged to `T`.
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: `sock` satisfies `T::from_raw_socket`'s contract by this
        // function's own precondition.
        Self::new(unsafe { T::from_raw_socket(sock) })
    }
}
123
124#[cfg(unix)]
125impl<T: FromRawFd> FromRawFd for SharedFd<T> {
126    unsafe fn from_raw_fd(fd: RawFd) -> Self {
127        Self::new(T::from_raw_fd(fd))
128    }
129}
130
131impl<T> From<T> for SharedFd<T> {
132    fn from(value: T) -> Self {
133        Self::new(value)
134    }
135}
136
137impl<T> Clone for SharedFd<T> {
138    fn clone(&self) -> Self {
139        Self(self.0.clone())
140    }
141}
142
143impl<T> Deref for SharedFd<T> {
144    type Target = T;
145
146    fn deref(&self) -> &Self::Target {
147        &self.0.fd
148    }
149}
150
/// Get a clone of [`SharedFd`].
///
/// Lets a value hand out another [`SharedFd`] handle to the fd it refers to.
/// In this file it is implemented for `SharedFd` itself; presumably other
/// fd-owning wrappers implement it too — confirm against callers.
pub trait ToSharedFd<T> {
    /// Return a cloned [`SharedFd`].
    fn to_shared_fd(&self) -> SharedFd<T>;
}
156
157impl<T> ToSharedFd<T> for SharedFd<T> {
158    fn to_shared_fd(&self) -> SharedFd<T> {
159        self.clone()
160    }
161}