1#![deny(unsafe_op_in_unsafe_fn)]
11#![allow(non_snake_case)]
12
13use core::ffi::CStr;
14use std::any::Any;
15use std::cell::Cell;
16use std::fmt::{Display, Formatter};
17use std::hint::unreachable_unchecked;
18use std::panic::{
19 catch_unwind, panic_any, resume_unwind, AssertUnwindSafe, Location, PanicInfo, UnwindSafe,
20};
21
22use crate::elog::PgLogLevel;
23use crate::errcodes::PgSqlErrorCode;
24use crate::{pfree, AsPgCStr, MemoryContextSwitchTo};
25
/// Extracts the success value of a fallible computation, handing any failure to Postgres'
/// error reporting instead of returning it to the caller.
///
/// How the failure is raised is up to the implementor; the `Result` implementation in this
/// module raises it at `ERROR` level.
pub trait ErrorReportable {
    /// The value produced when there is nothing to report.
    type Inner;

    /// Returns the success value, or reports the failure.
    fn unwrap_or_report(self) -> Self::Inner;
}
33
34impl<T, E> ErrorReportable for Result<T, E>
35where
36 E: Any + Display,
37{
38 type Inner = T;
39
40 fn unwrap_or_report(self) -> Self::Inner {
46 self.unwrap_or_else(|e| {
47 let any: Box<&dyn Any> = Box::new(&e);
48 if any.downcast_ref::<ErrorReport>().is_some() {
49 let any: Box<dyn Any> = Box::new(e);
50 any.downcast::<ErrorReport>().unwrap().report(PgLogLevel::ERROR);
51 unreachable!();
52 } else {
53 ereport!(ERROR, PgSqlErrorCode::ERRCODE_DATA_EXCEPTION, &format!("{e}"));
54 }
55 })
56 }
57}
58
/// Where an error report originated: source file, optional function name, line and
/// column, plus an optional backtrace.
#[derive(Debug)]
pub struct ErrorReportLocation {
    pub(crate) file: String,
    pub(crate) funcname: Option<String>,
    pub(crate) line: u32,
    pub(crate) col: u32,
    pub(crate) backtrace: Option<std::backtrace::Backtrace>,
}

impl Default for ErrorReportLocation {
    /// An unknown location: file `<unknown>`, line and column 0, no function name and no
    /// backtrace.
    fn default() -> Self {
        let file = "<unknown>".to_owned();
        Self { file, funcname: None, line: 0, col: 0, backtrace: None }
    }
}

impl Display for ErrorReportLocation {
    /// Formats as `funcname, file:line:col`, or `file:line:col` when there is no function
    /// name.
    ///
    /// A backtrace is appended on a new line only if it was actually captured.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { file, funcname, line, col, backtrace } = self;
        match funcname.as_deref() {
            Some(name) => write!(f, "{}, {}:{}:{}", name, file, line, col)?,
            None => write!(f, "{}:{}:{}", file, line, col)?,
        }

        let captured = backtrace
            .as_ref()
            .filter(|bt| bt.status() == std::backtrace::BacktraceStatus::Captured);
        match captured {
            Some(bt) => write!(f, "\n{bt}"),
            None => Ok(()),
        }
    }
}

impl From<&Location<'_>> for ErrorReportLocation {
    /// Copies file, line and column from a `core::panic::Location`.
    ///
    /// The function name and backtrace are left empty.
    fn from(location: &Location<'_>) -> Self {
        Self {
            line: location.line(),
            col: location.column(),
            file: String::from(location.file()),
            backtrace: None,
            funcname: None,
        }
    }
}

impl From<&PanicInfo<'_>> for ErrorReportLocation {
    /// Uses the panic's location when it has one, otherwise the default (unknown) location.
    fn from(pi: &PanicInfo<'_>) -> Self {
        match pi.location() {
            Some(location) => Self::from(location),
            None => Self::default(),
        }
    }
}
120
121#[derive(Debug)]
124pub struct ErrorReport {
125 pub(crate) sqlerrcode: PgSqlErrorCode,
126 pub(crate) message: String,
127 pub(crate) hint: Option<String>,
128 pub(crate) detail: Option<String>,
129 pub(crate) location: ErrorReportLocation,
130}
131
132impl Display for ErrorReport {
133 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
134 write!(f, "{}: {}", self.sqlerrcode, self.message)?;
135 if let Some(hint) = &self.hint {
136 write!(f, "\nHINT: {hint}")?;
137 }
138 if let Some(detail) = &self.detail {
139 write!(f, "\nDETAIL: {detail}")?;
140 }
141 write!(f, "\nLOCATION: {}", self.location)
142 }
143}
144
145#[derive(Debug)]
146pub struct ErrorReportWithLevel {
147 pub(crate) level: PgLogLevel,
148 pub(crate) inner: ErrorReport,
149}
150
151impl ErrorReportWithLevel {
152 fn report(self) {
153 match self.level {
154 PgLogLevel::ERROR => panic_any(self),
156
157 PgLogLevel::FATAL | PgLogLevel::PANIC => {
159 do_ereport(self);
160 unreachable!()
161 }
162
163 _ => do_ereport(self),
165 }
166 }
167
168 pub fn level(&self) -> PgLogLevel {
170 self.level
171 }
172
173 pub fn sql_error_code(&self) -> PgSqlErrorCode {
175 self.inner.sqlerrcode
176 }
177
178 pub fn message(&self) -> &str {
180 self.inner.message()
181 }
182
183 pub fn detail(&self) -> Option<&str> {
185 self.inner.detail()
186 }
187
188 pub fn detail_with_backtrace(&self) -> Option<String> {
190 match (self.detail(), self.backtrace()) {
191 (Some(detail), Some(bt))
192 if bt.status() == std::backtrace::BacktraceStatus::Captured =>
193 {
194 Some(format!("{detail}\n{bt}"))
195 }
196 (Some(d), _) => Some(d.to_string()),
197 (None, Some(bt)) if bt.status() == std::backtrace::BacktraceStatus::Captured => {
198 Some(format!("\n{bt}"))
199 }
200 (None, _) => None,
201 }
202 }
203
204 pub fn hint(&self) -> Option<&str> {
206 self.inner.hint()
207 }
208
209 pub fn file(&self) -> &str {
211 &self.inner.location.file
212 }
213
214 pub fn line_number(&self) -> u32 {
216 self.inner.location.line
217 }
218
219 pub fn backtrace(&self) -> Option<&std::backtrace::Backtrace> {
221 self.inner.location.backtrace.as_ref()
222 }
223
224 pub fn function_name(&self) -> Option<&str> {
226 self.inner.location.funcname.as_deref()
227 }
228
229 fn context_message(&self) -> Option<String> {
231 None
233 }
234}
235
236impl ErrorReport {
237 #[track_caller]
242 pub fn new<S: Into<String>>(
243 sqlerrcode: PgSqlErrorCode,
244 message: S,
245 funcname: &'static str,
246 ) -> Self {
247 let mut location: ErrorReportLocation = Location::caller().into();
248 location.funcname = Some(funcname.to_string());
249
250 Self { sqlerrcode, message: message.into(), hint: None, detail: None, location }
251 }
252
253 fn with_location<S: Into<String>>(
258 sqlerrcode: PgSqlErrorCode,
259 message: S,
260 location: ErrorReportLocation,
261 ) -> Self {
262 Self { sqlerrcode, message: message.into(), hint: None, detail: None, location }
263 }
264
265 pub fn set_detail<S: Into<String>>(mut self, detail: S) -> Self {
267 self.detail = Some(detail.into());
268 self
269 }
270
271 pub fn set_hint<S: Into<String>>(mut self, hint: S) -> Self {
273 self.hint = Some(hint.into());
274 self
275 }
276
277 pub fn message(&self) -> &str {
279 &self.message
280 }
281
282 pub fn detail(&self) -> Option<&str> {
284 self.detail.as_deref()
285 }
286
287 pub fn hint(&self) -> Option<&str> {
289 self.hint.as_deref()
290 }
291
292 pub fn report(self, level: PgLogLevel) {
296 ErrorReportWithLevel { level, inner: self }.report()
297 }
298}
299
300thread_local! { static PANIC_LOCATION: Cell<Option<ErrorReportLocation>> = const { Cell::new(None) }}
301
302fn take_panic_location() -> ErrorReportLocation {
303 PANIC_LOCATION.with(|p| p.take().unwrap_or_default())
304}
305
306pub fn register_pg_guard_panic_hook() {
307 use super::thread_check::is_os_main_thread;
308
309 let default_hook = std::panic::take_hook();
310 std::panic::set_hook(Box::new(move |info: _| {
311 if is_os_main_thread() == Some(true) {
312 PANIC_LOCATION.with(|thread_local| {
314 thread_local.replace({
315 let mut info: ErrorReportLocation = info.into();
316 info.backtrace = Some(std::backtrace::Backtrace::capture());
317 Some(info)
318 })
319 });
320 } else {
321 default_hook(info)
323 }
324 }))
325}
326
327#[derive(Debug)]
329pub enum CaughtError {
330 PostgresError(ErrorReportWithLevel),
332
333 ErrorReport(ErrorReportWithLevel),
335
336 RustPanic { ereport: ErrorReportWithLevel, payload: Box<dyn Any + Send> },
338}
339
340impl CaughtError {
341 pub fn rethrow(self) -> ! {
345 resume_unwind(Box::new(self))
348 }
349}
350
351#[derive(Debug)]
352enum GuardAction<R> {
353 Return(R),
354 ReThrow,
355 Report(ErrorReportWithLevel),
356}
357
#[doc(hidden)]
/// Runs `f` at an `extern "C"` boundary and turns any Rust unwind that escapes it into
/// Postgres error handling, so the unwind does not continue into C.
///
/// * If `f` returns normally, its value is returned.
/// * If `f` unwinds with a `CaughtError::PostgresError`, control passes back to Postgres
///   via `pg_re_throw`, after setting `CurrentMemoryContext` to `ErrorContext`.
/// * Any other unwind (an `ErrorReport` or a plain Rust panic) is passed to `do_ereport`.
///
/// # Safety
///
/// NOTE(review): the original safety contract is not visible here. The error paths write
/// Postgres globals (`CurrentMemoryContext`) and call into Postgres' error machinery. This
/// presumably must only be called on the thread Postgres runs on, from code Postgres itself
/// invoked. Confirm.
pub unsafe fn pgrx_extern_c_guard<Func, R>(f: Func) -> R
where
    Func: FnOnce() -> R,
{
    // `Func` is not bounded by `UnwindSafe`. `AssertUnwindSafe` lets `run_guarded` pass it
    // to `catch_unwind` anyway.
    match unsafe { run_guarded(AssertUnwindSafe(f)) } {
        // `f` completed normally.
        GuardAction::Return(r) => r,
        // `f` unwound with a Postgres error: hand it back to Postgres.
        GuardAction::ReThrow => {
            #[cfg_attr(target_os = "windows", link(name = "postgres"))]
            extern "C" {
                fn pg_re_throw() -> !;
            }
            unsafe {
                // Switch to `ErrorContext` before re-throwing.
                // NOTE(review): presumably this matches what Postgres' error-recovery code
                // expects to be current when the error propagates. Confirm.
                crate::CurrentMemoryContext = crate::ErrorContext;
                pg_re_throw()
            }
        }
        // `f` unwound with an `ErrorReport` or a plain Rust panic: report it to Postgres.
        GuardAction::Report(ereport) => {
            // `do_ereport` does not return for levels >= ERROR. A report below ERROR would
            // return here and hit this `unreachable!`, panicking inside the guard.
            do_ereport(ereport);
            unreachable!("pgrx reported a CaughtError that wasn't raised at ERROR or above");
        }
    }
}
413
414#[inline(never)]
416unsafe fn run_guarded<F, R>(f: F) -> GuardAction<R>
417where
418 F: FnOnce() -> R + UnwindSafe,
419{
420 match catch_unwind(f) {
421 Ok(v) => GuardAction::Return(v),
422 Err(e) => match downcast_panic_payload(e) {
423 CaughtError::PostgresError(_) => {
424 GuardAction::ReThrow
427 }
428 CaughtError::ErrorReport(ereport) | CaughtError::RustPanic { ereport, .. } => {
429 GuardAction::Report(ereport)
430 }
431 },
432 }
433}
434
435pub(crate) fn downcast_panic_payload(e: Box<dyn Any + Send>) -> CaughtError {
437 if e.downcast_ref::<CaughtError>().is_some() {
438 *e.downcast::<CaughtError>().unwrap()
440 } else if e.downcast_ref::<ErrorReportWithLevel>().is_some() {
441 CaughtError::ErrorReport(*e.downcast().unwrap())
443 } else if e.downcast_ref::<ErrorReport>().is_some() {
444 CaughtError::ErrorReport(ErrorReportWithLevel {
446 level: PgLogLevel::ERROR,
447 inner: *e.downcast().unwrap(),
448 })
449 } else if let Some(message) = e.downcast_ref::<&str>() {
450 CaughtError::RustPanic {
452 ereport: ErrorReportWithLevel {
453 level: PgLogLevel::ERROR,
454 inner: ErrorReport::with_location(
455 PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
456 *message,
457 take_panic_location(),
458 ),
459 },
460 payload: e,
461 }
462 } else if let Some(message) = e.downcast_ref::<String>() {
463 CaughtError::RustPanic {
465 ereport: ErrorReportWithLevel {
466 level: PgLogLevel::ERROR,
467 inner: ErrorReport::with_location(
468 PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
469 message,
470 take_panic_location(),
471 ),
472 },
473 payload: e,
474 }
475 } else {
476 CaughtError::RustPanic {
478 ereport: ErrorReportWithLevel {
479 level: PgLogLevel::ERROR,
480 inner: ErrorReport::with_location(
481 PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
482 "Box<Any>",
483 take_panic_location(),
484 ),
485 },
486 payload: e,
487 }
488 }
489}
490
/// Hands `ereport` to Postgres' `errstart`/`errfinish` error-reporting machinery.
///
/// For levels `>= ERROR` the code relies on `errfinish` not returning (see the
/// `unreachable_unchecked` calls below). For lower levels the report is emitted, the
/// temporary C strings are freed, and this function returns.
fn do_ereport(ereport: ErrorReportWithLevel) {
    // Every message part is passed as the argument of a literal "%s" format. User text is
    // therefore never interpreted as a printf-style format string.
    const PERCENT_S: &CStr = c"%s";
    // NOTE(review): a null domain presumably selects Postgres' default message domain. Confirm.
    const DOMAIN: *const ::std::os::raw::c_char = std::ptr::null_mut();

    // NOTE(review): presumably verifies we are on the thread allowed to call into Postgres
    // before doing so. See `thread_check`.
    crate::thread_check::check_active_thread();

    // Postgres' ereport helper functions, common to all supported versions.
    #[cfg_attr(target_os = "windows", link(name = "postgres"))]
    extern "C" {
        fn errcode(sqlerrcode: ::std::os::raw::c_int) -> ::std::os::raw::c_int;
        fn errmsg(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errdetail(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errhint(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
        fn errcontext_msg(fmt: *const ::std::os::raw::c_char, ...) -> ::std::os::raw::c_int;
    }

    // pg13+: `errstart(elevel, domain)`. File, line and function are passed to `errfinish`.
    #[inline(always)]
    #[rustfmt::skip] #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16", feature = "pg17"))]
    fn do_ereport_impl(ereport: ErrorReportWithLevel) {

        #[cfg_attr(target_os = "windows", link(name = "postgres"))]
        extern "C" {
            fn errstart(elevel: ::std::os::raw::c_int, domain: *const ::std::os::raw::c_char) -> bool;
            fn errfinish(filename: *const ::std::os::raw::c_char, lineno: ::std::os::raw::c_int, funcname: *const ::std::os::raw::c_char);
        }

        let level = ereport.level();
        unsafe {
            // Only build the C strings if Postgres says this report will be emitted.
            if errstart(level as _, DOMAIN) {

                // Temporary C copies of the message parts (null when absent). They are
                // freed right after each is handed to Postgres below.
                let sqlerrcode = ereport.sql_error_code();
                let message = ereport.message().as_pg_cstr();
                let detail = ereport.detail_with_backtrace().as_pg_cstr();
                let hint = ereport.hint().as_pg_cstr();
                let context = ereport.context_message().as_pg_cstr();
                let lineno = ereport.line_number();

                // `file` and `funcname` are allocated in `ErrorContext` because `errfinish`
                // receives the pointers themselves. For levels >= ERROR control never comes
                // back to free them.
                // NOTE(review): presumably so they stay valid while Postgres handles the
                // error. Confirm.
                let prev_cxt = MemoryContextSwitchTo(crate::ErrorContext);
                let file = ereport.file().as_pg_cstr();
                let funcname = ereport.function_name().as_pg_cstr();
                MemoryContextSwitchTo(prev_cxt);

                // Drop the Rust-side report now. For levels >= ERROR `errfinish` does not
                // return, so its destructor would otherwise never run.
                drop(ereport);

                // Each `pfree` right after the call assumes the err* function has copied
                // the formatted text.
                errcode(sqlerrcode as _);
                if !message.is_null() { errmsg(PERCENT_S.as_ptr(), message); pfree(message.cast()); }
                if !detail.is_null() { errdetail(PERCENT_S.as_ptr(), detail); pfree(detail.cast()); }
                if !hint.is_null() { errhint(PERCENT_S.as_ptr(), hint); pfree(hint.cast()); }
                if !context.is_null() { errcontext_msg(PERCENT_S.as_ptr(), context); pfree(context.cast()); }

                errfinish(file, lineno as _, funcname);

                if level >= PgLogLevel::ERROR {
                    // SAFETY: relies on `errfinish` never returning for ERROR or above.
                    unreachable_unchecked()
                } else {
                    // Below ERROR, `errfinish` returned, so the location strings can be freed.
                    if !file.is_null() { pfree(file.cast()); }
                    if !funcname.is_null() { pfree(funcname.cast()); }
                }
            }
        }
    }

    // pg12: file, line and function are passed to `errstart`; `errfinish` takes a dummy
    // variadic argument.
    #[inline(always)]
    #[rustfmt::skip] #[cfg(feature = "pg12")]
    fn do_ereport_impl(ereport: ErrorReportWithLevel) {

        #[cfg_attr(target_os = "windows", link(name = "postgres"))]
        extern "C" {
            fn errstart(elevel: ::std::os::raw::c_int, filename: *const ::std::os::raw::c_char, lineno: ::std::os::raw::c_int, funcname: *const ::std::os::raw::c_char, domain: *const ::std::os::raw::c_char) -> bool;
            fn errfinish(dummy: ::std::os::raw::c_int, ...);
        }

        unsafe {
            // Location strings go into `ErrorContext` because `errstart` receives the
            // pointers themselves. For levels >= ERROR control never comes back to free them.
            let prev_cxt = MemoryContextSwitchTo(crate::ErrorContext);
            let file = ereport.file().as_pg_cstr();
            let lineno = ereport.line_number();
            let funcname = ereport.function_name().as_pg_cstr();
            MemoryContextSwitchTo(prev_cxt);

            let level = ereport.level();
            if errstart(level as _, file, lineno as _, funcname, DOMAIN) {

                let sqlerrcode = ereport.sql_error_code();
                let message = ereport.message().as_pg_cstr();
                let detail = ereport.detail_with_backtrace().as_pg_cstr();
                let hint = ereport.hint().as_pg_cstr();
                let context = ereport.context_message().as_pg_cstr();


                // Drop the Rust-side report before `errfinish`, which does not return for
                // levels >= ERROR.
                drop(ereport);

                errcode(sqlerrcode as _);
                if !message.is_null() { errmsg(PERCENT_S.as_ptr(), message); pfree(message.cast()); }
                if !detail.is_null() { errdetail(PERCENT_S.as_ptr(), detail); pfree(detail.cast()); }
                if !hint.is_null() { errhint(PERCENT_S.as_ptr(), hint); pfree(hint.cast()); }
                if !context.is_null() { errcontext_msg(PERCENT_S.as_ptr(), context); pfree(context.cast()); }

                errfinish(0);
            }

            // NOTE(review): this sits outside the `errstart` check. It relies on `errstart`
            // returning true for every level >= ERROR (so `errfinish` was reached and did not
            // return). Confirm against PG12's elog.c.
            if level >= PgLogLevel::ERROR {
                // SAFETY: see the note above.
                unreachable_unchecked()
            } else {
                if !file.is_null() { pfree(file.cast()); }
                if !funcname.is_null() { pfree(funcname.cast()); }
            }
        }
    }

    do_ereport_impl(ereport)
}