// pgrx_tests/framework/shutdown.rs
//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
//LICENSE
//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
//LICENSE
//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
//LICENSE
//LICENSE All rights reserved.
//LICENSE
//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
use std::panic::{self, AssertUnwindSafe, Location};
use std::sync::{Mutex, PoisonError};
use std::{any, io, mem, process};
/// Queue `func` to be invoked when the process exits.
///
/// Shutdown hooks only run on the client side, so they have to be registered
/// from your `setup` function rather than from inside a `#[pg_test]`.
///
/// The caller's source location is captured alongside the callback.
#[track_caller]
pub fn add_shutdown_hook<F>(func: F)
where
    F: FnOnce() + Send + 'static,
{
    // `Location::caller()` must be evaluated directly in this `#[track_caller]`
    // function so that it reports the registering call site.
    let hook = ShutdownHook { source: Location::caller(), callback: Box::new(func) };
    // A poisoned lock is recovered rather than treated as fatal.
    let mut registry = SHUTDOWN_HOOKS.lock().unwrap_or_else(PoisonError::into_inner);
    registry.push(hook);
}
pub(super) fn register_shutdown_hook() {
unsafe {
libc::atexit(run_shutdown_hooks);
}
}
/// The `atexit` callback.
///
/// If we panic from `atexit`, we end up causing `exit` to unwind. Unwinding
/// from a nounwind + noreturn function can cause some destructors to run twice,
/// causing (for example) libtest to SIGSEGV.
///
/// This ends up looking like a memory bug in either `pgrx` or the user code, and
/// is very hard to track down, so we go to some lengths to prevent it.
/// Essentially:
///
/// - Panics in each user hook are caught and reported.
/// - As a stop-gap an abort-on-drop panic guard is used to ensure there isn't a
///   place we missed.
///
/// We also write to stderr directly, since otherwise our output will
/// sometimes be redirected.
///
/// The hook list lock is never held while a hook runs, so a hook may itself
/// call `add_shutdown_hook`; hooks registered that way are run by a later
/// pass of the drain loop.
extern "C" fn run_shutdown_hooks() {
    let guard = PanicGuard;
    let mut any_panicked = false;
    loop {
        // The `MutexGuard` is a temporary of this statement, so the lock is
        // released before any user code runs. Holding it across the hooks
        // would make a hook that registers another hook re-lock the
        // (non-reentrant) mutex on the same thread.
        let pending =
            mem::take(&mut *SHUTDOWN_HOOKS.lock().unwrap_or_else(PoisonError::into_inner));
        if pending.is_empty() {
            break;
        }
        // Note: run hooks in the opposite order they were registered.
        for hook in pending.into_iter().rev() {
            any_panicked |= hook.run().is_err();
        }
    }
    if any_panicked {
        write_stderr("error: one or more shutdown hooks panicked (see `stderr` for details).\n");
        process::abort()
    }
    // Everything completed without unwinding: defuse the abort-on-drop guard.
    mem::forget(guard);
}
/// Abort-on-drop guard for code that must never unwind.
///
/// Dropping the guard writes a message to stderr and aborts the process, so
/// the guard has to be defused with `mem::forget` once the protected code has
/// finished. Intended usage is like:
/// ```ignore
/// let guard = PanicGuard;
/// // ...code that absolutely must never unwind goes here...
/// core::mem::forget(guard);
/// ```
struct PanicGuard;

impl Drop for PanicGuard {
    /// Report the escaped panic and abort the process.
    fn drop(&mut self) {
        write_stderr("Failed to catch panic in the `atexit` callback, aborting!\n");
        process::abort()
    }
}
/// Hooks pushed by `add_shutdown_hook`, in registration order, until
/// `run_shutdown_hooks` takes them out of the list at process exit.
static SHUTDOWN_HOOKS: Mutex<Vec<ShutdownHook>> = Mutex::new(Vec::new());
/// A registered shutdown callback together with where it was registered.
struct ShutdownHook {
    /// Source location of the `add_shutdown_hook` call; printed when the
    /// callback panics.
    source: &'static Location<'static>,
    /// The user-supplied callback, consumed when the hook is run.
    callback: Box<dyn FnOnce() + Send>,
}
impl ShutdownHook {
    /// Run the callback, catching any panic it raises.
    ///
    /// Returns `Ok(())` when the callback completes normally. When it panics,
    /// the registration location and the panic message are written to stderr
    /// and `Err(())` is returned.
    fn run(self) -> Result<(), ()> {
        let Self { source, callback } = self;
        match panic::catch_unwind(AssertUnwindSafe(callback)) {
            Ok(()) => Ok(()),
            Err(payload) => {
                // Deref the box explicitly. `&payload` is a `&Box<dyn Any + Send>`,
                // and since the `Box` itself implements `Any + Send` it would be
                // unsize-coerced with the *box* as the concrete type, making
                // every downcast in `failure_message` fail.
                let msg = failure_message(&*payload);
                write_stderr(&format!(
                    "error: shutdown hook (registered at {source}) panicked: {msg}\n"
                ));
                Err(())
            }
        }
    }
}
/// Extract a printable message from a panic payload.
///
/// Payloads that are a `String` or a `&'static str` yield their text; any
/// other payload type yields a fixed placeholder.
fn failure_message(e: &(dyn any::Any + Send)) -> &str {
    if let Some(owned) = e.downcast_ref::<String>() {
        return owned.as_str();
    }
    match e.downcast_ref::<&'static str>() {
        Some(literal) => *literal,
        None => "<panic payload of unknown type>",
    }
}
/// Write to stderr, bypassing libtest's output redirection. Doesn't append `\n`.
///
/// Keeps writing until all of `s` has been written, retrying on `EINTR` and
/// continuing after short writes. Any other error (or a zero-length write)
/// stops the attempt; failures are not reported, since there is nowhere left
/// to report them.
fn write_stderr(s: &str) {
    let mut remaining = s.as_bytes();
    while !remaining.is_empty() {
        // SAFETY: the pointer and length describe the live `remaining` slice,
        // which `write` only reads from.
        let res = unsafe {
            libc::write(libc::STDERR_FILENO, remaining.as_ptr().cast(), remaining.len())
        };
        if res > 0 {
            // `res` is strictly positive here, so the cast cannot wrap. Clamp
            // defensively so the slice below can never go out of bounds.
            let written = (res as usize).min(remaining.len());
            remaining = &remaining[written..];
        } else if res == 0 {
            // No progress was made; give up rather than spin forever.
            break;
        } else if io::Error::last_os_error().kind() != io::ErrorKind::Interrupted {
            // Handle EINTR to ensure we don't drop messages; anything else is fatal.
            // `Error::last_os_error()` just reads from errno, so it's fine to use here.
            break;
        }
    }
}