compio_signal/
linux.rs

1//! Linux-specific types for signal handling.
2
3use std::{
4    cell::RefCell, collections::HashMap, io, mem::MaybeUninit, os::fd::FromRawFd, ptr::null_mut,
5    thread_local,
6};
7
8use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit};
9use compio_driver::{OwnedFd, SharedFd, op::Recv, syscall};
10
thread_local! {
    /// Per-thread registry mapping a signal number to the number of live
    /// registrations for it on this thread.
    static REG_MAP: RefCell<HashMap<i32, usize>> = RefCell::default();
}
14
15fn sigset(sig: i32) -> io::Result<libc::sigset_t> {
16    let mut set: MaybeUninit<libc::sigset_t> = MaybeUninit::uninit();
17    syscall!(libc::sigemptyset(set.as_mut_ptr()))?;
18    syscall!(libc::sigaddset(set.as_mut_ptr(), sig))?;
19    // SAFETY: sigemptyset initializes the set.
20    Ok(unsafe { set.assume_init() })
21}
22
23fn register_signal(sig: i32) -> io::Result<libc::sigset_t> {
24    REG_MAP.with_borrow_mut(|map| {
25        let count = map.entry(sig).or_default();
26        let set = sigset(sig)?;
27        if *count == 0 {
28            syscall!(libc::pthread_sigmask(libc::SIG_BLOCK, &set, null_mut()))?;
29        }
30        *count += 1;
31        Ok(set)
32    })
33}
34
35fn unregister_signal(sig: i32) -> io::Result<libc::sigset_t> {
36    REG_MAP.with_borrow_mut(|map| {
37        let count = map.entry(sig).or_default();
38        if *count > 0 {
39            *count -= 1;
40        }
41        let set = sigset(sig)?;
42        if *count == 0 {
43            syscall!(libc::pthread_sigmask(libc::SIG_UNBLOCK, &set, null_mut()))?;
44        }
45        Ok(set)
46    })
47}
48
49/// Represents a listener to unix signal event.
50#[derive(Debug)]
51struct SignalFd {
52    fd: SharedFd<OwnedFd>,
53    sig: i32,
54}
55
56impl SignalFd {
57    fn new(sig: i32) -> io::Result<Self> {
58        let set = register_signal(sig)?;
59        let mut flag = libc::SFD_CLOEXEC;
60        if cfg!(not(feature = "io-uring")) {
61            flag |= libc::SFD_NONBLOCK;
62        }
63        let fd = syscall!(libc::signalfd(-1, &set, flag))?;
64        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
65        Ok(Self {
66            fd: SharedFd::new(fd),
67            sig,
68        })
69    }
70
71    async fn wait(self) -> io::Result<()> {
72        const INFO_SIZE: usize = std::mem::size_of::<libc::signalfd_siginfo>();
73
74        struct SignalInfo(MaybeUninit<libc::signalfd_siginfo>);
75
76        unsafe impl IoBuf for SignalInfo {
77            fn as_buf_ptr(&self) -> *const u8 {
78                self.0.as_ptr().cast()
79            }
80
81            fn buf_len(&self) -> usize {
82                0
83            }
84
85            fn buf_capacity(&self) -> usize {
86                INFO_SIZE
87            }
88        }
89
90        unsafe impl IoBufMut for SignalInfo {
91            fn as_buf_mut_ptr(&mut self) -> *mut u8 {
92                self.0.as_mut_ptr().cast()
93            }
94        }
95
96        impl SetBufInit for SignalInfo {
97            unsafe fn set_buf_init(&mut self, len: usize) {
98                debug_assert!(len <= INFO_SIZE)
99            }
100        }
101
102        let info = SignalInfo(MaybeUninit::<libc::signalfd_siginfo>::uninit());
103        let op = Recv::new(self.fd.clone(), info);
104        let BufResult(res, op) = compio_runtime::submit(op).await;
105        let len = res?;
106        debug_assert_eq!(len, INFO_SIZE);
107        let info = op.into_inner();
108        let info = unsafe { info.0.assume_init() };
109        debug_assert_eq!(info.ssi_signo, self.sig as u32);
110        Ok(())
111    }
112}
113
114impl Drop for SignalFd {
115    fn drop(&mut self) {
116        unregister_signal(self.sig).ok();
117    }
118}
119
120/// Creates a new listener which will receive notifications when the current
121/// process receives the specified signal.
122///
123/// It sets the signal mask of the current thread.
124pub async fn signal(sig: i32) -> io::Result<()> {
125    let fd = SignalFd::new(sig)?;
126    fd.wait().await?;
127    Ok(())
128}