1#[cfg(unix)]
2use std::os::fd::FromRawFd;
3#[cfg(windows)]
4use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
5use std::{
6 future::{Future, poll_fn},
7 mem::ManuallyDrop,
8 ops::Deref,
9 panic::RefUnwindSafe,
10 sync::{
11 Arc,
12 atomic::{AtomicBool, Ordering},
13 },
14 task::Poll,
15};
16
17use futures_util::task::AtomicWaker;
18
19use crate::{AsRawFd, RawFd};
20
/// State shared by every `SharedFd` handle that points at the same value.
#[derive(Debug)]
struct Inner<T> {
    /// The wrapped value; `SharedFd::try_unwrap` and `SharedFd::take` move it
    /// back out once only one handle is left.
    fd: T,
    /// Set to `true` by the first `SharedFd::take` call; read by
    /// `SharedFd`'s `Drop` to decide whether the waker should be triggered.
    waits: AtomicBool,
    /// Waker registered by a pending `SharedFd::take` future and woken from
    /// `SharedFd`'s `Drop`.
    waker: AtomicWaker,
}
28
// Declared unconditionally: no bound is placed on `T`.
// NOTE(review): presumably required because `AtomicWaker` is not
// `RefUnwindSafe` on its own; confirm that asserting it for every `T`
// (including a `T` that is not `RefUnwindSafe`) is intended.
impl<T> RefUnwindSafe for Inner<T> {}
30
/// A clonable, reference-counted handle to a value of type `T` (for example
/// something that owns a file descriptor, handle or socket).
///
/// All clones share a single `Inner<T>` through an `Arc`. The wrapped value
/// can be recovered with `try_unwrap` (immediately, if this is the only
/// handle) or with `take` (asynchronously, once the other handles are gone).
#[derive(Debug)]
pub struct SharedFd<T>(Arc<Inner<T>>);
35
36impl<T> SharedFd<T> {
37 pub fn new(fd: T) -> Self {
39 Self(Arc::new(Inner {
40 fd,
41 waits: AtomicBool::new(false),
42 waker: AtomicWaker::new(),
43 }))
44 }
45
46 pub fn try_unwrap(self) -> Result<T, Self> {
48 let this = ManuallyDrop::new(self);
49 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
50 Ok(fd)
51 } else {
52 Err(ManuallyDrop::into_inner(this))
53 }
54 }
55
56 unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
58 let ptr = ManuallyDrop::new(std::ptr::read(&this.0));
59 match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
61 Ok(inner) => Some(inner.fd),
62 Err(ptr) => {
63 std::mem::forget(ptr);
64 None
65 }
66 }
67 }
68
69 pub fn take(self) -> impl Future<Output = Option<T>> {
71 let this = ManuallyDrop::new(self);
72 async move {
73 if !this.0.waits.swap(true, Ordering::AcqRel) {
74 poll_fn(move |cx| {
75 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
76 return Poll::Ready(Some(fd));
77 }
78
79 this.0.waker.register(cx.waker());
80
81 if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
82 Poll::Ready(Some(fd))
83 } else {
84 Poll::Pending
85 }
86 })
87 .await
88 } else {
89 None
90 }
91 }
92 }
93}
94
95impl<T> Drop for SharedFd<T> {
96 fn drop(&mut self) {
97 if Arc::strong_count(&self.0) == 2 && self.0.waits.load(Ordering::Acquire) {
99 self.0.waker.wake()
100 }
101 }
102}
103
104impl<T: AsRawFd> AsRawFd for SharedFd<T> {
105 fn as_raw_fd(&self) -> RawFd {
106 self.0.fd.as_raw_fd()
107 }
108}
109
#[cfg(windows)]
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
    /// Builds the inner `T` from `handle` and wraps it in a fresh shared
    /// handle.
    ///
    /// # Safety
    ///
    /// `handle` is passed unchanged to `T::from_raw_handle`, so the caller
    /// must satisfy that function's requirements.
    unsafe fn from_raw_handle(handle: RawHandle) -> Self {
        // SAFETY: the requirements are forwarded to the caller, see above.
        let owned = unsafe { T::from_raw_handle(handle) };
        Self::new(owned)
    }
}
116
#[cfg(windows)]
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
    /// Builds the inner `T` from `sock` and wraps it in a fresh shared handle.
    ///
    /// # Safety
    ///
    /// `sock` is passed unchanged to `T::from_raw_socket`, so the caller must
    /// satisfy that function's requirements.
    unsafe fn from_raw_socket(sock: RawSocket) -> Self {
        // SAFETY: the requirements are forwarded to the caller, see above.
        let owned = unsafe { T::from_raw_socket(sock) };
        Self::new(owned)
    }
}
123
124#[cfg(unix)]
125impl<T: FromRawFd> FromRawFd for SharedFd<T> {
126 unsafe fn from_raw_fd(fd: RawFd) -> Self {
127 Self::new(T::from_raw_fd(fd))
128 }
129}
130
131impl<T> From<T> for SharedFd<T> {
132 fn from(value: T) -> Self {
133 Self::new(value)
134 }
135}
136
137impl<T> Clone for SharedFd<T> {
138 fn clone(&self) -> Self {
139 Self(self.0.clone())
140 }
141}
142
143impl<T> Deref for SharedFd<T> {
144 type Target = T;
145
146 fn deref(&self) -> &Self::Target {
147 &self.0.fd
148 }
149}
150
/// Types from which a `SharedFd<T>` can be obtained through a shared
/// reference.
pub trait ToSharedFd<T> {
    /// Returns a `SharedFd<T>` derived from `self`.
    fn to_shared_fd(&self) -> SharedFd<T>;
}
156
157impl<T> ToSharedFd<T> for SharedFd<T> {
158 fn to_shared_fd(&self) -> SharedFd<T> {
159 self.clone()
160 }
161}