1use std::{
4 cell::RefCell, collections::HashMap, io, mem::MaybeUninit, os::fd::FromRawFd, ptr::null_mut,
5 thread_local,
6};
7
8use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit};
9use compio_driver::{OwnedFd, SharedFd, op::Recv, syscall};
10
thread_local! {
    /// Per-thread registration counts, keyed by signal number.
    ///
    /// Starts out empty; entries are created on demand by the register/unregister helpers.
    static REG_MAP: RefCell<HashMap<i32, usize>> = RefCell::default();
}
14
15fn sigset(sig: i32) -> io::Result<libc::sigset_t> {
16 let mut set: MaybeUninit<libc::sigset_t> = MaybeUninit::uninit();
17 syscall!(libc::sigemptyset(set.as_mut_ptr()))?;
18 syscall!(libc::sigaddset(set.as_mut_ptr(), sig))?;
19 Ok(unsafe { set.assume_init() })
21}
22
23fn register_signal(sig: i32) -> io::Result<libc::sigset_t> {
24 REG_MAP.with_borrow_mut(|map| {
25 let count = map.entry(sig).or_default();
26 let set = sigset(sig)?;
27 if *count == 0 {
28 syscall!(libc::pthread_sigmask(libc::SIG_BLOCK, &set, null_mut()))?;
29 }
30 *count += 1;
31 Ok(set)
32 })
33}
34
35fn unregister_signal(sig: i32) -> io::Result<libc::sigset_t> {
36 REG_MAP.with_borrow_mut(|map| {
37 let count = map.entry(sig).or_default();
38 if *count > 0 {
39 *count -= 1;
40 }
41 let set = sigset(sig)?;
42 if *count == 0 {
43 syscall!(libc::pthread_sigmask(libc::SIG_UNBLOCK, &set, null_mut()))?;
44 }
45 Ok(set)
46 })
47}
48
49#[derive(Debug)]
51struct SignalFd {
52 fd: SharedFd<OwnedFd>,
53 sig: i32,
54}
55
56impl SignalFd {
57 fn new(sig: i32) -> io::Result<Self> {
58 let set = register_signal(sig)?;
59 let mut flag = libc::SFD_CLOEXEC;
60 if cfg!(not(feature = "io-uring")) {
61 flag |= libc::SFD_NONBLOCK;
62 }
63 let fd = syscall!(libc::signalfd(-1, &set, flag))?;
64 let fd = unsafe { OwnedFd::from_raw_fd(fd) };
65 Ok(Self {
66 fd: SharedFd::new(fd),
67 sig,
68 })
69 }
70
71 async fn wait(self) -> io::Result<()> {
72 const INFO_SIZE: usize = std::mem::size_of::<libc::signalfd_siginfo>();
73
74 struct SignalInfo(MaybeUninit<libc::signalfd_siginfo>);
75
76 unsafe impl IoBuf for SignalInfo {
77 fn as_buf_ptr(&self) -> *const u8 {
78 self.0.as_ptr().cast()
79 }
80
81 fn buf_len(&self) -> usize {
82 0
83 }
84
85 fn buf_capacity(&self) -> usize {
86 INFO_SIZE
87 }
88 }
89
90 unsafe impl IoBufMut for SignalInfo {
91 fn as_buf_mut_ptr(&mut self) -> *mut u8 {
92 self.0.as_mut_ptr().cast()
93 }
94 }
95
96 impl SetBufInit for SignalInfo {
97 unsafe fn set_buf_init(&mut self, len: usize) {
98 debug_assert!(len <= INFO_SIZE)
99 }
100 }
101
102 let info = SignalInfo(MaybeUninit::<libc::signalfd_siginfo>::uninit());
103 let op = Recv::new(self.fd.clone(), info);
104 let BufResult(res, op) = compio_runtime::submit(op).await;
105 let len = res?;
106 debug_assert_eq!(len, INFO_SIZE);
107 let info = op.into_inner();
108 let info = unsafe { info.0.assume_init() };
109 debug_assert_eq!(info.ssi_signo, self.sig as u32);
110 Ok(())
111 }
112}
113
114impl Drop for SignalFd {
115 fn drop(&mut self) {
116 unregister_signal(self.sig).ok();
117 }
118}
119
120pub async fn signal(sig: i32) -> io::Result<()> {
125 let fd = SignalFd::new(sig)?;
126 fd.wait().await?;
127 Ok(())
128}