use std::ops::Range;
use std::str::from_utf8;
use memchr::memchr;
use sqlx_core::bytes::Bytes;
use crate::error::Error;
use crate::io::ProtocolDecode;
use crate::message::{BackendMessage, BackendMessageFormat};
/// Severity level attached to a notice or error message.
///
/// Each variant corresponds to one of the upper-case severity names accepted
/// by the `TryFrom<&str>` implementation for this type.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
pub enum PgSeverity {
    /// Parsed from `"PANIC"`.
    Panic,
    /// Parsed from `"FATAL"`.
    Fatal,
    /// Parsed from `"ERROR"`.
    Error,
    /// Parsed from `"WARNING"`.
    Warning,
    /// Parsed from `"NOTICE"`.
    Notice,
    /// Parsed from `"DEBUG"`.
    Debug,
    /// Parsed from `"INFO"`.
    Info,
    /// Parsed from `"LOG"`.
    Log,
}

impl PgSeverity {
    /// Returns `true` for the severities that represent an error
    /// (`Panic`, `Fatal` and `Error`), and `false` for every other level.
    #[inline]
    pub fn is_error(self) -> bool {
        // Listed exhaustively so that adding a variant forces a decision here.
        match self {
            Self::Panic | Self::Fatal | Self::Error => true,
            Self::Warning | Self::Notice | Self::Debug | Self::Info | Self::Log => false,
        }
    }
}
impl TryFrom<&str> for PgSeverity {
type Error = Error;
fn try_from(s: &str) -> Result<PgSeverity, Error> {
let result = match s {
"PANIC" => PgSeverity::Panic,
"FATAL" => PgSeverity::Fatal,
"ERROR" => PgSeverity::Error,
"WARNING" => PgSeverity::Warning,
"NOTICE" => PgSeverity::Notice,
"DEBUG" => PgSeverity::Debug,
"INFO" => PgSeverity::Info,
"LOG" => PgSeverity::Log,
severity => {
return Err(err_protocol!("unknown severity: {:?}", severity));
}
};
Ok(result)
}
}
/// A decoded notice message.
///
/// The raw message body is retained in `storage`; the message text and code
/// are kept as byte ranges into that buffer rather than as separate strings.
#[derive(Debug)]
pub struct Notice {
/// Raw message body that the ranges below index into.
storage: Bytes,
/// Severity resolved while decoding: the `V` field if present, otherwise the
/// `S` field if it parsed, otherwise `PgSeverity::Log`.
severity: PgSeverity,
/// Byte range of the `M` field's value inside `storage` (`0..0` if none was seen).
message: Range<usize>,
/// Byte range of the `C` field's value inside `storage` (`0..0` if none was seen).
code: Range<usize>,
}
impl Notice {
    /// Returns the severity that was resolved when this notice was decoded.
    #[inline]
    pub fn severity(&self) -> PgSeverity {
        self.severity
    }

    /// Returns the value of the `C` field (the SQLSTATE code in the PostgreSQL
    /// protocol), or an empty string if no such field was recorded.
    #[inline]
    pub fn code(&self) -> &str {
        let span = self.code.clone();
        self.get_cached_str(span)
    }

    /// Returns the value of the `M` field (the primary message text), or an
    /// empty string if no such field was recorded.
    #[inline]
    pub fn message(&self) -> &str {
        let span = self.message.clone();
        self.get_cached_str(span)
    }

    /// Looks up the first field tagged `ty` and returns its value as text.
    ///
    /// Returns `None` when there is no such field or when its bytes are not
    /// valid UTF-8.
    #[inline]
    pub fn get(&self, ty: u8) -> Option<&str> {
        let raw = self.get_raw(ty)?;
        from_utf8(raw).ok()
    }

    /// Looks up the first field tagged `ty` and returns its raw bytes
    /// (without the terminating NUL), or `None` when there is no such field.
    pub fn get_raw(&self, ty: u8) -> Option<&[u8]> {
        // Fields are scanned from the start of the buffer on every call.
        let (_, span) = self.fields().find(|(tag, _)| *tag == ty)?;
        Some(&self.storage[span])
    }
}
impl Notice {
    /// Creates an iterator over every `(tag, value range)` pair in the stored
    /// message body, starting from the first byte.
    #[inline]
    fn fields(&self) -> Fields<'_> {
        let storage: &[u8] = &self.storage;
        Fields { storage, offset: 0 }
    }

    /// Returns the bytes at `cache` within the stored body as a `&str`.
    ///
    /// # Panics
    ///
    /// Panics if the range is out of bounds or the bytes are not valid UTF-8.
    /// The ranges stored on `Notice` by `decode_with` are UTF-8-checked before
    /// being recorded, so the cached accessors do not hit this for them.
    #[inline]
    fn get_cached_str(&self, cache: Range<usize>) -> &str {
        let bytes = &self.storage[cache];
        from_utf8(bytes).unwrap()
    }
}
impl ProtocolDecode<'_> for Notice {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log;
let mut severity_v = None;
let mut severity_s = None;
let mut message = 0..0;
let mut code = 0..0;
let fields = Fields {
storage: &buf,
offset: 0,
};
for (field, v) in fields {
if !(message.is_empty() || code.is_empty()) {
break;
}
match field {
b'S' => {
severity_s = from_utf8(&buf[v.clone()])
.map_err(|_| notice_protocol_err())?
.try_into()
.ok();
}
b'V' => {
severity_v = Some(
from_utf8(&buf[v.clone()])
.map_err(|_| notice_protocol_err())?
.try_into()?,
);
}
b'M' => {
_ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?;
message = v;
}
b'C' => {
_ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?;
code = v;
}
_ => {}
}
}
Ok(Self {
severity: severity_v.or(severity_s).unwrap_or(DEFAULT_SEVERITY),
message,
code,
storage: buf,
})
}
}
// Associates `Notice` with the `NoticeResponse` backend message format.
impl BackendMessage for Notice {
/// The message format this type is decoded from.
const FORMAT: BackendMessageFormat = BackendMessageFormat::NoticeResponse;
/// Decodes the message body by delegating to `decode_with` with a unit context.
///
/// Returns whatever `decode_with` returns, including its errors.
fn decode_body(buf: Bytes) -> Result<Self, Error> {
Self::decode_with(buf, ())
}
}
/// Iterator state for walking the fields of a notice body.
///
/// Its `Iterator` implementation yields `(tag, range)` pairs, where `range`
/// locates the field's value inside `storage`.
struct Fields<'a> {
/// Buffer being scanned.
storage: &'a [u8],
/// Index within `storage` of the next field's tag byte.
offset: usize,
}
impl<'a> Iterator for Fields<'a> {
    type Item = (u8, Range<usize>);

    /// Yields the next `(tag, value range)` pair.
    ///
    /// Each field is laid out as one tag byte followed by a NUL-terminated
    /// value; the returned range covers the value without its terminator.
    /// Iteration ends at a zero tag byte, at the end of the buffer, or when a
    /// value has no terminating NUL.
    fn next(&mut self) -> Option<Self::Item> {
        // A missing byte or a zero tag both end the sequence without advancing.
        let tag = self
            .storage
            .get(self.offset)
            .copied()
            .filter(|&byte| byte != 0)?;

        // Step over the tag byte; the value starts immediately after it.
        let value_start = self.offset.checked_add(1)?;
        self.offset = value_start;

        let remainder = self.storage.get(value_start..)?;
        let value_len = memchr(b'\0', remainder)?;

        // Both sums stay within `storage.len()` because the NUL was found
        // inside the buffer.
        let value_end = value_start + value_len;
        self.offset = value_end + 1;

        Some((tag, value_start..value_end))
    }
}
fn notice_protocol_err() -> Error {
Error::Protocol(
"Postgres returned a non-UTF-8 string for its error message. \
This is most likely due to an error that occurred during authentication and \
the default lc_messages locale is not binary-compatible with UTF-8. \
See the server logs for the error details."
.into(),
)
}
/// Decodes a sample body and checks the cached message, severity and code.
#[test]
fn test_decode_error_response() {
    const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0";
    const EXPECTED_MESSAGE: &str = "extension \"uuid-ossp\" already exists, skipping";

    let notice = Notice::decode(Bytes::from_static(DATA)).unwrap();

    assert_eq!(notice.message(), EXPECTED_MESSAGE);
    assert!(matches!(notice.severity(), PgSeverity::Notice));
    assert_eq!(notice.code(), "42710");
}
// Benchmark: cost of calling `message()` on a notice that was decoded once,
// outside the timed closure.
// Compiled only for test builds with debug assertions disabled (see `cfg`).
// NOTE(review): the `black_box` calls are presumably there to stop the
// optimizer from eliding the measured work — keep their placement as-is.
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_error_response_get_message(b: &mut test::Bencher) {
const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0";
let res = Notice::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
b.iter(|| {
// Only the accessor call is inside the timed closure.
let _ = test::black_box(&res).message();
});
}
// Benchmark: cost of decoding the sample notice body from scratch on every
// iteration; the decode result is discarded.
// Compiled only for test builds with debug assertions disabled (see `cfg`).
// NOTE(review): the `black_box` call is presumably there to stop the
// optimizer from constant-folding the input — keep its placement as-is.
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_decode_error_response(b: &mut test::Bencher) {
const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0";
b.iter(|| {
// Buffer construction and decoding are both inside the timed closure.
let _ = Notice::decode(test::black_box(Bytes::from_static(DATA)));
});
}