use std::str::from_utf8;
use memchr::memchr;
use sqlx_core::bytes::Bytes;
use crate::error::Error;
use crate::io::Decode;
/// Severity level attached to a Postgres notice or error message.
///
/// The variants correspond one-to-one to the upper-case severity strings
/// accepted by this type's `TryFrom<&str>` implementation.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
pub enum PgSeverity {
    Panic = 0,
    Fatal = 1,
    Error = 2,
    Warning = 3,
    Notice = 4,
    Debug = 5,
    Info = 6,
    Log = 7,
}

impl PgSeverity {
    /// Returns `true` for the error severities (`Panic`, `Fatal`, `Error`)
    /// and `false` for every other level.
    #[inline]
    pub fn is_error(self) -> bool {
        match self {
            Self::Panic | Self::Fatal | Self::Error => true,
            Self::Warning | Self::Notice | Self::Debug | Self::Info | Self::Log => false,
        }
    }
}
impl TryFrom<&str> for PgSeverity {
type Error = Error;
fn try_from(s: &str) -> Result<PgSeverity, Error> {
let result = match s {
"PANIC" => PgSeverity::Panic,
"FATAL" => PgSeverity::Fatal,
"ERROR" => PgSeverity::Error,
"WARNING" => PgSeverity::Warning,
"NOTICE" => PgSeverity::Notice,
"DEBUG" => PgSeverity::Debug,
"INFO" => PgSeverity::Info,
"LOG" => PgSeverity::Log,
severity => {
return Err(err_protocol!("unknown severity: {:?}", severity));
}
};
Ok(result)
}
}
/// A decoded Postgres notice or error message body.
///
/// The raw body is kept in `storage`. The message and code are cached as
/// `(start, end)` byte ranges into it, so they can be borrowed later without
/// re-scanning the fields.
#[derive(Debug)]
pub struct Notice {
    // Raw message body; every cached range indexes into this buffer.
    storage: Bytes,
    severity: PgSeverity,
    // `(start, end)` byte range of the message (`M` field) within `storage`;
    // `(0, 0)` if the body had no such field.
    message: (u16, u16),
    // `(start, end)` byte range of the code (`C` field) within `storage`;
    // `(0, 0)` if the body had no such field.
    code: (u16, u16),
}
impl Notice {
    /// Returns the severity resolved when this notice was decoded.
    #[inline]
    pub fn severity(&self) -> PgSeverity {
        self.severity
    }

    /// Returns the code (`C` field), or `""` if the body had none.
    #[inline]
    pub fn code(&self) -> &str {
        self.get_cached_str(self.code)
    }

    /// Returns the message (`M` field), or `""` if the body had none.
    #[inline]
    pub fn message(&self) -> &str {
        self.get_cached_str(self.message)
    }

    /// Looks up a field by its one-byte type code and returns it as text.
    ///
    /// Yields `None` when the field is absent or its value is not valid UTF-8.
    #[inline]
    pub fn get(&self, ty: u8) -> Option<&str> {
        from_utf8(self.get_raw(ty)?).ok()
    }

    /// Returns the raw value bytes of the first field whose type code is `ty`,
    /// or `None` if no such field exists.
    pub fn get_raw(&self, ty: u8) -> Option<&[u8]> {
        let (_, (start, end)) = self.fields().find(|&(field, _)| field == ty)?;
        Some(&self.storage[usize::from(start)..usize::from(end)])
    }
}
impl Notice {
    /// Iterates over every `(type, (start, end))` field of the raw body.
    #[inline]
    fn fields(&self) -> Fields<'_> {
        let storage: &[u8] = &self.storage;
        Fields { storage, offset: 0 }
    }

    /// Borrows the `(start, end)` byte range of `storage` as a `&str`.
    ///
    /// Panics if the range is out of bounds or the bytes are not valid UTF-8.
    #[inline]
    fn get_cached_str(&self, cache: (u16, u16)) -> &str {
        let (start, end) = cache;
        let bytes = &self.storage[usize::from(start)..usize::from(end)];
        from_utf8(bytes).unwrap()
    }
}
impl Decode<'_> for Notice {
    /// Decodes a notice/error message body.
    ///
    /// The severity, message (`M`) and code (`C`) fields are located eagerly.
    /// The raw body is retained so any other field can be looked up later
    /// with `Notice::get` / `Notice::get_raw`.
    ///
    /// The severity is taken from `V` when present, otherwise from `S`,
    /// otherwise `PgSeverity::Log`.
    ///
    /// # Errors
    ///
    /// Returns an error if an `S`, `V`, `M` or `C` field is not valid UTF-8,
    /// or if a `V` field is not a recognized severity name.
    fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
        // Used when neither severity field is present (or `S` is unrecognized).
        const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log;

        let mut severity_v = None;
        let mut severity_s = None;
        // `(0, 0)` means "not seen yet": a real value always starts at
        // offset >= 1 because it follows its one-byte type code.
        let mut message = (0, 0);
        let mut code = (0, 0);

        // Borrows a field's value as `&str`, mapping invalid UTF-8 to the
        // dedicated protocol error.
        let field_str = |(start, end): (u16, u16)| {
            from_utf8(&buf[usize::from(start)..usize::from(end)])
                .map_err(|_| notice_protocol_err())
        };

        let fields = Fields {
            storage: &buf,
            offset: 0,
        };

        for (field, v) in fields {
            // Stop scanning once both the message and the code are known.
            // This assumes any severity field appears before them; a severity
            // field placed after both `M` and `C` is not consulted.
            if message.0 != 0 && code.0 != 0 {
                break;
            }

            match field {
                // `S` may be localized, so an unrecognized value is ignored
                // rather than rejected.
                b'S' => {
                    severity_s = field_str(v)?.try_into().ok();
                }
                // `V` must be one of the known severity names.
                b'V' => {
                    severity_v = Some(field_str(v)?.try_into()?);
                }
                // Validate `M` and `C` up front: `message()` and `code()`
                // unwrap the UTF-8 conversion, so invalid bytes must be
                // rejected here instead of panicking later.
                b'M' => {
                    field_str(v)?;
                    message = v;
                }
                b'C' => {
                    field_str(v)?;
                    code = v;
                }
                _ => {}
            }
        }

        Ok(Self {
            severity: severity_v.or(severity_s).unwrap_or(DEFAULT_SEVERITY),
            message,
            code,
            storage: buf,
        })
    }
}
/// Iterator over the fields of a raw notice/error message body.
///
/// The body is a sequence of fields. Each field is a one-byte type code
/// followed by a NUL-terminated value, and the sequence is closed by a single
/// `0` byte. Each item is `(type, (start, end))`, where `start..end` is the
/// value's byte range within `storage`, excluding the terminating NUL.
struct Fields<'a> {
    storage: &'a [u8],
    // Position of the next field's type code. Kept as `usize` so that
    // advancing past a long value cannot overflow.
    offset: usize,
}

impl Iterator for Fields<'_> {
    type Item = (u8, (u16, u16));

    fn next(&mut self) -> Option<Self::Item> {
        // A truncated body (missing the final `0` byte) ends iteration instead
        // of panicking on an out-of-bounds index.
        let ty = *self.storage.get(self.offset)?;
        if ty == 0 {
            return None;
        }

        let start = self.offset + 1;
        let len = memchr(b'\0', self.storage.get(start..)?)?;
        let end = start + len;

        // Skip over the value and its NUL terminator.
        self.offset = end + 1;

        // Ranges are reported as `u16` positions. A value extending past
        // `u16::MAX` cannot be represented, and neither can any later field,
        // so iteration stops instead of yielding a truncated, wrong range.
        let start = u16::try_from(start).ok()?;
        let end = u16::try_from(end).ok()?;
        Some((ty, (start, end)))
    }
}
fn notice_protocol_err() -> Error {
Error::Protocol(
"Postgres returned a non-UTF-8 string for its error message. \
This is most likely due to an error that occurred during authentication and \
the default lc_messages locale is not binary-compatible with UTF-8. \
See the server logs for the error details."
.into(),
)
}
/// Sample notice body (a "uuid-ossp already exists" NOTICE) shared by the
/// test and benchmarks below.
#[cfg(test)]
const NOTICE_DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0";

#[test]
fn test_decode_error_response() {
    let m = Notice::decode(Bytes::from_static(NOTICE_DATA)).unwrap();

    assert_eq!(
        m.message(),
        "extension \"uuid-ossp\" already exists, skipping"
    );
    assert_eq!(m.severity(), PgSeverity::Notice);
    assert!(!m.severity().is_error());
    assert_eq!(m.code(), "42710");

    // Fields that are not cached are still reachable through `get`.
    assert_eq!(m.get(b'F'), Some("extension.c"));
    assert_eq!(m.get(b'R'), Some("CreateExtension"));
    assert_eq!(m.get(b'X'), None);
}

#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_error_response_get_message(b: &mut test::Bencher) {
    let res = Notice::decode(test::black_box(Bytes::from_static(NOTICE_DATA))).unwrap();

    b.iter(|| {
        let _ = test::black_box(&res).message();
    });
}

#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_decode_error_response(b: &mut test::Bencher) {
    b.iter(|| {
        let _ = Notice::decode(test::black_box(Bytes::from_static(NOTICE_DATA)));
    });
}