sqlx_postgres/types/lquery.rs

1use crate::decode::Decode;
2use crate::encode::{Encode, IsNull};
3use crate::error::BoxDynError;
4use crate::types::Type;
5use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
6use bitflags::bitflags;
7use std::fmt::{self, Display, Formatter};
8use std::io::Write;
9use std::ops::Deref;
10use std::str::FromStr;
11
12use crate::types::ltree::{PgLTreeLabel, PgLTreeParseError};
13
14/// Represents lquery specific errors
15#[derive(Debug, thiserror::Error)]
16#[non_exhaustive]
17pub enum PgLQueryParseError {
18    #[error("lquery cannot be empty")]
19    EmptyString,
20    #[error("unexpected character in lquery")]
21    UnexpectedCharacter,
22    #[error("error parsing integer: {0}")]
23    ParseIntError(#[from] std::num::ParseIntError),
24    #[error("error parsing integer: {0}")]
25    LTreeParrseError(#[from] PgLTreeParseError),
26    /// LQuery version not supported
27    #[error("lquery version not supported")]
28    InvalidLqueryVersion,
29}
30
31/// Container for a Label Tree Query (`lquery`) in Postgres.
32///
33/// See <https://www.postgresql.org/docs/current/ltree.html>
34///
35/// ### Note: Requires Postgres 13+
36///
37/// This integration requires that the `lquery` type support the binary format in the Postgres
38/// wire protocol, which only became available in Postgres 13.
39/// ([Postgres 13.0 Release Notes, Additional Modules](https://www.postgresql.org/docs/13/release-13.html#id-1.11.6.11.5.14))
40///
41/// Ideally, SQLx's Postgres driver should support falling back to text format for types
42/// which don't have `typsend` and `typrecv` entries in `pg_type`, but that work still needs
43/// to be done.
44///
45/// ### Note: Extension Required
46/// The `ltree` extension is not enabled by default in Postgres. You will need to do so explicitly:
47///
48/// ```ignore
49/// CREATE EXTENSION IF NOT EXISTS "ltree";
50/// ```
51#[derive(Clone, Debug, Default, PartialEq)]
52pub struct PgLQuery {
53    levels: Vec<PgLQueryLevel>,
54}
55
56// TODO: maybe a QueryBuilder pattern would be nice here
57impl PgLQuery {
58    /// creates default/empty lquery
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    pub fn from(levels: Vec<PgLQueryLevel>) -> Self {
64        Self { levels }
65    }
66
67    /// push a query level
68    pub fn push(&mut self, level: PgLQueryLevel) {
69        self.levels.push(level);
70    }
71
72    /// pop a query level
73    pub fn pop(&mut self) -> Option<PgLQueryLevel> {
74        self.levels.pop()
75    }
76
77    /// creates lquery from an iterator with checking labels
78    // TODO: this should just be removed but I didn't want to bury it in a massive diff
79    #[deprecated = "renamed to `try_from_iter()`"]
80    #[allow(clippy::should_implement_trait)]
81    pub fn from_iter<I, S>(levels: I) -> Result<Self, PgLQueryParseError>
82    where
83        S: Into<String>,
84        I: IntoIterator<Item = S>,
85    {
86        let mut lquery = Self::default();
87        for level in levels {
88            lquery.push(PgLQueryLevel::from_str(&level.into())?);
89        }
90        Ok(lquery)
91    }
92
93    /// Create an `LQUERY` from an iterator of label strings.
94    ///
95    /// Returns an error if any label fails to parse according to [`PgLQueryLevel::from_str()`].
96    pub fn try_from_iter<I, S>(levels: I) -> Result<Self, PgLQueryParseError>
97    where
98        S: AsRef<str>,
99        I: IntoIterator<Item = S>,
100    {
101        levels
102            .into_iter()
103            .map(|level| level.as_ref().parse::<PgLQueryLevel>())
104            .collect()
105    }
106}
107
108impl FromIterator<PgLQueryLevel> for PgLQuery {
109    fn from_iter<T: IntoIterator<Item = PgLQueryLevel>>(iter: T) -> Self {
110        Self::from(iter.into_iter().collect())
111    }
112}
113
114impl IntoIterator for PgLQuery {
115    type Item = PgLQueryLevel;
116    type IntoIter = std::vec::IntoIter<Self::Item>;
117
118    fn into_iter(self) -> Self::IntoIter {
119        self.levels.into_iter()
120    }
121}
122
123impl FromStr for PgLQuery {
124    type Err = PgLQueryParseError;
125
126    fn from_str(s: &str) -> Result<Self, Self::Err> {
127        Ok(Self {
128            levels: s
129                .split('.')
130                .map(PgLQueryLevel::from_str)
131                .collect::<Result<_, Self::Err>>()?,
132        })
133    }
134}
135
136impl Display for PgLQuery {
137    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
138        let mut iter = self.levels.iter();
139        if let Some(label) = iter.next() {
140            write!(f, "{label}")?;
141            for label in iter {
142                write!(f, ".{label}")?;
143            }
144        }
145        Ok(())
146    }
147}
148
149impl Deref for PgLQuery {
150    type Target = [PgLQueryLevel];
151
152    fn deref(&self) -> &Self::Target {
153        &self.levels
154    }
155}
156
157impl Type<Postgres> for PgLQuery {
158    fn type_info() -> PgTypeInfo {
159        // Since `ltree` is enabled by an extension, it does not have a stable OID.
160        PgTypeInfo::with_name("lquery")
161    }
162}
163
164impl PgHasArrayType for PgLQuery {
165    fn array_type_info() -> PgTypeInfo {
166        PgTypeInfo::with_name("_lquery")
167    }
168}
169
170impl Encode<'_, Postgres> for PgLQuery {
171    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
172        buf.extend(1i8.to_le_bytes());
173        write!(buf, "{self}")?;
174
175        Ok(IsNull::No)
176    }
177}
178
179impl<'r> Decode<'r, Postgres> for PgLQuery {
180    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
181        match value.format() {
182            PgValueFormat::Binary => {
183                let bytes = value.as_bytes()?;
184                let version = i8::from_le_bytes([bytes[0]; 1]);
185                if version != 1 {
186                    return Err(Box::new(PgLQueryParseError::InvalidLqueryVersion));
187                }
188                Ok(Self::from_str(std::str::from_utf8(&bytes[1..])?)?)
189            }
190            PgValueFormat::Text => Ok(Self::from_str(value.as_str()?)?),
191        }
192    }
193}
194
195bitflags! {
196    /// Modifiers that can be set to non-star labels
197    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
198    pub struct PgLQueryVariantFlag: u16 {
199        /// * - Match any label with this prefix, for example foo* matches foobar
200        const ANY_END = 0x01;
201        /// @ - Match case-insensitively, for example a@ matches A
202        const IN_CASE = 0x02;
203        /// % - Match initial underscore-separated words
204        const SUBLEXEME = 0x04;
205    }
206}
207
208impl Display for PgLQueryVariantFlag {
209    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
210        if self.contains(PgLQueryVariantFlag::ANY_END) {
211            write!(f, "*")?;
212        }
213        if self.contains(PgLQueryVariantFlag::IN_CASE) {
214            write!(f, "@")?;
215        }
216        if self.contains(PgLQueryVariantFlag::SUBLEXEME) {
217            write!(f, "%")?;
218        }
219
220        Ok(())
221    }
222}
223
224#[derive(Clone, Debug, PartialEq)]
225pub struct PgLQueryVariant {
226    label: PgLTreeLabel,
227    modifiers: PgLQueryVariantFlag,
228}
229
230impl Display for PgLQueryVariant {
231    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
232        write!(f, "{}{}", self.label, self.modifiers)
233    }
234}
235
236#[derive(Clone, Debug, PartialEq)]
237pub enum PgLQueryLevel {
238    /// match any label (*) with optional at least / at most numbers
239    Star(Option<u16>, Option<u16>),
240    /// match any of specified labels with optional flags
241    NonStar(Vec<PgLQueryVariant>),
242    /// match none of specified labels with optional flags
243    NotNonStar(Vec<PgLQueryVariant>),
244}
245
246impl FromStr for PgLQueryLevel {
247    type Err = PgLQueryParseError;
248
249    fn from_str(s: &str) -> Result<Self, Self::Err> {
250        let bytes = s.as_bytes();
251        if bytes.is_empty() {
252            Err(PgLQueryParseError::EmptyString)
253        } else {
254            match bytes[0] {
255                b'*' => {
256                    if bytes.len() > 1 {
257                        let parts = s[2..s.len() - 1].split(',').collect::<Vec<_>>();
258                        match parts.len() {
259                            1 => {
260                                let number = parts[0].parse()?;
261                                Ok(PgLQueryLevel::Star(Some(number), Some(number)))
262                            }
263                            2 => Ok(PgLQueryLevel::Star(
264                                Some(parts[0].parse()?),
265                                Some(parts[1].parse()?),
266                            )),
267                            _ => Err(PgLQueryParseError::UnexpectedCharacter),
268                        }
269                    } else {
270                        Ok(PgLQueryLevel::Star(None, None))
271                    }
272                }
273                b'!' => Ok(PgLQueryLevel::NotNonStar(
274                    s[1..]
275                        .split('|')
276                        .map(PgLQueryVariant::from_str)
277                        .collect::<Result<Vec<_>, PgLQueryParseError>>()?,
278                )),
279                _ => Ok(PgLQueryLevel::NonStar(
280                    s.split('|')
281                        .map(PgLQueryVariant::from_str)
282                        .collect::<Result<Vec<_>, PgLQueryParseError>>()?,
283                )),
284            }
285        }
286    }
287}
288
289impl FromStr for PgLQueryVariant {
290    type Err = PgLQueryParseError;
291
292    fn from_str(s: &str) -> Result<Self, Self::Err> {
293        let mut label_length = s.len();
294        let mut modifiers = PgLQueryVariantFlag::empty();
295
296        for b in s.bytes().rev() {
297            match b {
298                b'@' => modifiers.insert(PgLQueryVariantFlag::IN_CASE),
299                b'*' => modifiers.insert(PgLQueryVariantFlag::ANY_END),
300                b'%' => modifiers.insert(PgLQueryVariantFlag::SUBLEXEME),
301                _ => break,
302            }
303            label_length -= 1;
304        }
305
306        Ok(PgLQueryVariant {
307            label: PgLTreeLabel::new(&s[0..label_length])?,
308            modifiers,
309        })
310    }
311}
312
313fn write_variants(f: &mut Formatter<'_>, variants: &[PgLQueryVariant], not: bool) -> fmt::Result {
314    let mut iter = variants.iter();
315    if let Some(variant) = iter.next() {
316        write!(f, "{}{}", if not { "!" } else { "" }, variant)?;
317        for variant in iter {
318            write!(f, ".{variant}")?;
319        }
320    }
321    Ok(())
322}
323
324impl Display for PgLQueryLevel {
325    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
326        match self {
327            PgLQueryLevel::Star(Some(at_least), Some(at_most)) => {
328                if at_least == at_most {
329                    write!(f, "*{{{at_least}}}")
330                } else {
331                    write!(f, "*{{{at_least},{at_most}}}")
332                }
333            }
334            PgLQueryLevel::Star(Some(at_least), _) => write!(f, "*{{{at_least},}}"),
335            PgLQueryLevel::Star(_, Some(at_most)) => write!(f, "*{{,{at_most}}}"),
336            PgLQueryLevel::Star(_, _) => write!(f, "*"),
337            PgLQueryLevel::NonStar(variants) => write_variants(f, variants, false),
338            PgLQueryLevel::NotNonStar(variants) => write_variants(f, variants, true),
339        }
340    }
341}