compio_driver/
fd.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#[cfg(unix)]
use std::os::fd::FromRawFd;
#[cfg(windows)]
use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
use std::{
    future::{Future, poll_fn},
    mem::ManuallyDrop,
    ops::Deref,
    panic::RefUnwindSafe,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
    task::Poll,
};

use futures_util::task::AtomicWaker;

use crate::{AsRawFd, RawFd};

/// State shared by every clone of a [`SharedFd`].
#[derive(Debug)]
struct Inner<T> {
    /// The owned fd value.
    fd: T,
    /// Whether a `SharedFd::take` future is currently waiting for this fd.
    waits: AtomicBool,
    /// Waker of the task waiting in `SharedFd::take`; woken from `SharedFd`'s `Drop`.
    waker: AtomicWaker,
}

// Asserts `RefUnwindSafe` for every `T` (the impl has no bound on `T`), which
// lets `SharedFd<T>` be `UnwindSafe` (std's `Arc<U>` is `UnwindSafe` only when
// `U: RefUnwindSafe`). NOTE(review): presumably needed because the interior
// mutability of `AtomicWaker` would otherwise make `Inner` `!RefUnwindSafe`;
// consider bounding on `T: RefUnwindSafe` if no caller relies on the
// unbounded form.
impl<T> RefUnwindSafe for Inner<T> {}

/// A reference-counted, shareable fd.
///
/// Clones are passed to operations so that the fd is not dropped (and thus
/// closed) before those operations complete. The owned value can be recovered
/// with [`SharedFd::try_unwrap`], or awaited with [`SharedFd::take`], once all
/// other clones are gone.
#[derive(Debug)]
pub struct SharedFd<T>(Arc<Inner<T>>);

impl<T> SharedFd<T> {
    /// Create the shared fd from an owned fd.
    pub fn new(fd: T) -> Self {
        Self(Arc::new(Inner {
            fd,
            waits: AtomicBool::new(false),
            waker: AtomicWaker::new(),
        }))
    }

    /// Try to take the inner owned fd.
    ///
    /// Succeeds only when `self` is the last handle. Otherwise the same handle
    /// is returned in `Err`, with the reference count unchanged and without
    /// running `SharedFd`'s `Drop` logic.
    pub fn try_unwrap(self) -> Result<T, Self> {
        // Suppress `SharedFd::drop`: the `Arc` is moved out below and is either
        // consumed by `Arc::try_unwrap` or re-wrapped into the returned handle.
        let this = ManuallyDrop::new(self);
        // SAFETY: `this` is never dropped, so reading the `Arc` out of it moves
        // its strong reference into `arc` instead of duplicating it.
        let arc = unsafe { std::ptr::read(&this.0) };
        match Arc::try_unwrap(arc) {
            Ok(inner) => Ok(inner.fd),
            Err(arc) => Err(Self(arc)),
        }
    }

    /// Wait until every other handle is dropped, then take the inner owned fd.
    ///
    /// Only one `take` may wait at a time. If another `take` future is already
    /// waiting, the returned future resolves to `None` and releases this handle,
    /// which may wake the existing waiter. If the returned future is dropped
    /// before completing, its handle is released normally and the waiting flag
    /// is cleared, so a later `take` may wait again.
    pub fn take(self) -> impl Future<Output = Option<T>> {
        // Holds the waiting handle between polls. If the wait is abandoned (the
        // future is dropped while pending), clear `waits` so later `take` calls
        // are not refused; the handle itself is then dropped normally.
        struct Waiter<U>(Option<SharedFd<U>>);

        impl<U> Drop for Waiter<U> {
            fn drop(&mut self) {
                if let Some(fd) = &self.0 {
                    fd.0.waits.store(false, Ordering::Release);
                }
            }
        }

        async move {
            if self.0.waits.swap(true, Ordering::AcqRel) {
                // Another `take` is already waiting. Returning drops `self`
                // normally, releasing its reference so that waiter can finish.
                return None;
            }
            let mut waiter = Waiter(Some(self));
            poll_fn(move |cx| {
                let this = waiter.0.take().expect("`SharedFd::take` polled after completion");
                let this = match this.try_unwrap() {
                    Ok(fd) => return Poll::Ready(Some(fd)),
                    Err(this) => this,
                };

                // Register before re-checking, so a handle dropped between the
                // two attempts still finds our waker.
                this.0.waker.register(cx.waker());

                match this.try_unwrap() {
                    Ok(fd) => Poll::Ready(Some(fd)),
                    Err(this) => {
                        waiter.0 = Some(this);
                        Poll::Pending
                    }
                }
            })
            .await
        }
    }
}

impl<T> Drop for SharedFd<T> {
    /// Wakes the task waiting in [`SharedFd::take`] when, once this handle is
    /// released, only one other handle remains (expected to be the waiter's).
    fn drop(&mut self) {
        // `drop` runs before the `Arc` field is released, so a count of 2 here
        // means "this handle + one other". It's OK to wake multiple times.
        // NOTE(review): the count is read before this handle's reference is
        // actually released. If handles can be dropped concurrently on
        // different threads (two drops may both observe a count of 3), or the
        // woken task can be polled on another thread before the release
        // completes, the waiter could miss its wake-up. Confirm that drops
        // only race on a single thread.
        if Arc::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
            self.0.waker.wake()
        }
    }
}

impl<T: AsRawFd> AsRawFd for SharedFd<T> {
    fn as_raw_fd(&self) -> RawFd {
        self.0.fd.as_raw_fd()
    }
}

#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Wraps a raw handle into a new [`SharedFd`] that owns it.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `T::from_raw_handle`, to which
    /// `handle` is forwarded unchanged.
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: the caller's obligations for this function are exactly those
        // of `T::from_raw_handle`.
        Self::new(unsafe { T::from_raw_handle(handle) })
    }
}

#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Wraps a raw socket into a new [`SharedFd`] that owns it.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `T::from_raw_socket`, to which
    /// `sock` is forwarded unchanged.
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: the caller's obligations for this function are exactly those
        // of `T::from_raw_socket`.
        Self::new(unsafe { T::from_raw_socket(sock) })
    }
}

#[cfg(unix)]
impl<T: FromRawFd> FromRawFd for SharedFd<T> {
    /// Wraps a raw fd into a new [`SharedFd`] that owns it.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `T::from_raw_fd`, to which `fd`
    /// is forwarded unchanged.
    unsafe fn from_raw_fd(fd: RawFd) -> Self {
        // SAFETY: the caller's obligations for this function are exactly those
        // of `T::from_raw_fd`.
        Self::new(unsafe { T::from_raw_fd(fd) })
    }
}

impl<T> From<T> for SharedFd<T> {
    fn from(value: T) -> Self {
        Self::new(value)
    }
}

impl<T> Clone for SharedFd<T> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<T> Deref for SharedFd<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.0.fd
    }
}

/// Types that can hand out a clone of a [`SharedFd`].
pub trait ToSharedFd<T> {
    /// Return a new [`SharedFd`] handle that refers to the same fd.
    fn to_shared_fd(&self) -> SharedFd<T>;
}

impl<T> ToSharedFd<T> for SharedFd<T> {
    fn to_shared_fd(&self) -> SharedFd<T> {
        self.clone()
    }
}