compio_signal/
linux.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
//! Linux-specific types for signal handling.

use std::{
    cell::RefCell, collections::HashMap, io, mem::MaybeUninit, os::fd::FromRawFd, ptr::null_mut,
    thread_local,
};

use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit};
use compio_driver::{OwnedFd, SharedFd, op::Recv, syscall};

thread_local! {
    /// Per-thread reference counts of active listeners, keyed by signal number.
    ///
    /// The counts are per-thread because `pthread_sigmask` only changes the
    /// calling thread's mask: a signal is blocked when its count goes from 0
    /// to 1 and unblocked when the count drops back to 0.
    static REG_MAP: RefCell<HashMap<i32, usize>> = RefCell::default();
}

/// Builds a signal set that contains exactly one signal, `sig`.
///
/// # Errors
///
/// Propagates any failure reported by `sigemptyset` or `sigaddset` (for
/// example, `sigaddset` rejects an invalid signal number).
fn sigset(sig: i32) -> io::Result<libc::sigset_t> {
    let mut storage = MaybeUninit::<libc::sigset_t>::uninit();
    let set_ptr = storage.as_mut_ptr();
    // `sigemptyset` must run first: it is what initializes the storage.
    syscall!(libc::sigemptyset(set_ptr))?;
    syscall!(libc::sigaddset(set_ptr, sig))?;
    // SAFETY: `sigemptyset` succeeded above, so the set is fully initialized.
    Ok(unsafe { storage.assume_init() })
}

/// Registers one more listener for `sig` on the current thread and returns a
/// signal set containing only `sig`.
///
/// The first registration on a thread blocks `sig` in that thread's signal
/// mask. Later registrations only bump the per-thread reference count.
///
/// # Errors
///
/// Returns an error if the signal set cannot be built or if the signal mask
/// cannot be updated. On error the reference count is left unchanged.
fn register_signal(sig: i32) -> io::Result<libc::sigset_t> {
    // Build the set before touching the registry so a failure here cannot
    // leave a stale entry behind.
    let set = sigset(sig)?;
    REG_MAP.with_borrow_mut(|map| {
        let count = map.get(&sig).copied().unwrap_or(0);
        if count == 0 {
            // `pthread_sigmask` reports failure through its return value (an
            // error number), not through `-1`/`errno`, so check it by hand.
            // SAFETY: `set` is a fully initialized `sigset_t`, and a null
            // `oldset` pointer is permitted.
            let ret = unsafe { libc::pthread_sigmask(libc::SIG_BLOCK, &set, null_mut()) };
            if ret != 0 {
                return Err(io::Error::from_raw_os_error(ret));
            }
        }
        // Only record the registration once the mask change has succeeded.
        map.insert(sig, count + 1);
        Ok(set)
    })
}

/// Releases one listener for `sig` on the current thread and returns a signal
/// set containing only `sig`.
///
/// When the last listener on this thread is released, the registry entry is
/// removed and `sig` is unblocked in the thread's signal mask. If `sig` was
/// never registered on this thread, the mask is left untouched.
///
/// # Errors
///
/// Returns an error if the signal set cannot be built (in which case nothing
/// is changed) or if the signal mask cannot be updated.
fn unregister_signal(sig: i32) -> io::Result<libc::sigset_t> {
    // Build the set first so a failure here leaves the count unchanged.
    let set = sigset(sig)?;
    REG_MAP.with_borrow_mut(|map| {
        match map.get(&sig).copied().unwrap_or(0) {
            // Not registered here: this module never blocked the signal, so
            // it must not unblock it either. Drop any stale zero entry.
            0 => {
                map.remove(&sig);
            }
            // Last listener on this thread: forget it and unblock the signal.
            1 => {
                map.remove(&sig);
                // `pthread_sigmask` returns an error number rather than
                // `-1`/`errno`, so check the return value directly.
                // SAFETY: `set` is a fully initialized `sigset_t`, and a null
                // `oldset` pointer is permitted.
                let ret = unsafe { libc::pthread_sigmask(libc::SIG_UNBLOCK, &set, null_mut()) };
                if ret != 0 {
                    return Err(io::Error::from_raw_os_error(ret));
                }
            }
            // Other listeners still rely on the signal being blocked.
            remaining => {
                map.insert(sig, remaining - 1);
            }
        }
        Ok(set)
    })
}

/// A listener for a single Unix signal, backed by a `signalfd` descriptor.
///
/// Constructing one registers `sig` for the current thread; dropping it
/// releases that registration.
#[derive(Debug)]
struct SignalFd {
    /// The `signalfd` descriptor. It is cloned into in-flight read operations.
    fd: SharedFd<OwnedFd>,
    /// The signal number this listener waits for.
    sig: i32,
}

impl SignalFd {
    fn new(sig: i32) -> io::Result<Self> {
        let set = register_signal(sig)?;
        let mut flag = libc::SFD_CLOEXEC;
        if cfg!(not(feature = "io-uring")) {
            flag |= libc::SFD_NONBLOCK;
        }
        let fd = syscall!(libc::signalfd(-1, &set, flag))?;
        let fd = unsafe { OwnedFd::from_raw_fd(fd) };
        Ok(Self {
            fd: SharedFd::new(fd),
            sig,
        })
    }

    async fn wait(self) -> io::Result<()> {
        const INFO_SIZE: usize = std::mem::size_of::<libc::signalfd_siginfo>();

        struct SignalInfo(MaybeUninit<libc::signalfd_siginfo>);

        unsafe impl IoBuf for SignalInfo {
            fn as_buf_ptr(&self) -> *const u8 {
                self.0.as_ptr().cast()
            }

            fn buf_len(&self) -> usize {
                0
            }

            fn buf_capacity(&self) -> usize {
                INFO_SIZE
            }
        }

        unsafe impl IoBufMut for SignalInfo {
            fn as_buf_mut_ptr(&mut self) -> *mut u8 {
                self.0.as_mut_ptr().cast()
            }
        }

        impl SetBufInit for SignalInfo {
            unsafe fn set_buf_init(&mut self, len: usize) {
                debug_assert!(len <= INFO_SIZE)
            }
        }

        let info = SignalInfo(MaybeUninit::<libc::signalfd_siginfo>::uninit());
        let op = Recv::new(self.fd.clone(), info);
        let BufResult(res, op) = compio_runtime::submit(op).await;
        let len = res?;
        debug_assert_eq!(len, INFO_SIZE);
        let info = op.into_inner();
        let info = unsafe { info.0.assume_init() };
        debug_assert_eq!(info.ssi_signo, self.sig as u32);
        Ok(())
    }
}

impl Drop for SignalFd {
    /// Releases this listener's registration for its signal.
    ///
    /// Any error is discarded because `drop` has no way to report it.
    fn drop(&mut self) {
        let _ = unregister_signal(self.sig);
    }
}

/// Completes once the current process receives the signal `sig`.
///
/// While waiting, `sig` is blocked in the calling thread's signal mask and is
/// consumed through a `signalfd`. It is unblocked again when the last pending
/// listener for `sig` on this thread finishes.
///
/// # Errors
///
/// Returns an error if the signal cannot be registered, if the `signalfd`
/// cannot be created, or if reading from it fails.
pub async fn signal(sig: i32) -> io::Result<()> {
    SignalFd::new(sig)?.wait().await
}