// sqlx_postgres/types/range.rs

1use std::fmt::{self, Debug, Display, Formatter};
2use std::ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive};
3
4use bitflags::bitflags;
5use sqlx_core::bytes::Buf;
6
7use crate::decode::Decode;
8use crate::encode::{Encode, IsNull};
9use crate::error::BoxDynError;
10use crate::type_info::PgTypeKind;
11use crate::types::Type;
12use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
13
14// https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/include/utils/rangetypes.h#L35-L44
15bitflags! {
16    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
17    struct RangeFlags: u8 {
18        const EMPTY = 0x01;
19        const LB_INC = 0x02;
20        const UB_INC = 0x04;
21        const LB_INF = 0x08;
22        const UB_INF = 0x10;
23        const LB_NULL = 0x20; // not used
24        const UB_NULL = 0x40; // not used
25        const CONTAIN_EMPTY = 0x80; // internal
26    }
27}
28
/// A range value with a lower (`start`) and an upper (`end`) bound, each of which is
/// inclusive, exclusive, or unbounded.
///
/// There is no separate representation for an empty range; decoding one yields
/// `(Unbounded, Unbounded)`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PgRange<T> {
    /// Lower bound of the range.
    pub start: Bound<T>,
    /// Upper bound of the range.
    pub end: Bound<T>,
}

impl<T> From<[Bound<T>; 2]> for PgRange<T> {
    /// Builds a range from `[start, end]`.
    fn from([start, end]: [Bound<T>; 2]) -> Self {
        Self { start, end }
    }
}

impl<T> From<(Bound<T>, Bound<T>)> for PgRange<T> {
    /// Builds a range from `(start, end)`.
    fn from((start, end): (Bound<T>, Bound<T>)) -> Self {
        Self { start, end }
    }
}

impl<T> From<Range<T>> for PgRange<T> {
    /// `a..b` becomes `[a,b)`.
    fn from(Range { start, end }: Range<T>) -> Self {
        Self {
            start: Bound::Included(start),
            end: Bound::Excluded(end),
        }
    }
}

impl<T> From<RangeFrom<T>> for PgRange<T> {
    /// `a..` becomes `[a,)`.
    fn from(RangeFrom { start }: RangeFrom<T>) -> Self {
        Self {
            start: Bound::Included(start),
            end: Bound::Unbounded,
        }
    }
}

impl<T> From<RangeInclusive<T>> for PgRange<T> {
    /// `a..=b` becomes `[a,b]`.
    fn from(range: RangeInclusive<T>) -> Self {
        // `RangeInclusive`'s fields are private, so it has to be taken apart with `into_inner`.
        let (low, high) = range.into_inner();
        Self {
            start: Bound::Included(low),
            end: Bound::Included(high),
        }
    }
}

impl<T> From<RangeTo<T>> for PgRange<T> {
    /// `..b` becomes `(,b)`.
    fn from(RangeTo { end }: RangeTo<T>) -> Self {
        Self {
            start: Bound::Unbounded,
            end: Bound::Excluded(end),
        }
    }
}

impl<T> From<RangeToInclusive<T>> for PgRange<T> {
    /// `..=b` becomes `(,b]`.
    fn from(RangeToInclusive { end }: RangeToInclusive<T>) -> Self {
        Self {
            start: Bound::Unbounded,
            end: Bound::Included(end),
        }
    }
}

impl<T> RangeBounds<T> for PgRange<T> {
    /// Borrows the lower bound.
    fn start_bound(&self) -> Bound<&T> {
        match &self.start {
            Bound::Included(value) => Bound::Included(value),
            Bound::Excluded(value) => Bound::Excluded(value),
            Bound::Unbounded => Bound::Unbounded,
        }
    }

    /// Borrows the upper bound.
    fn end_bound(&self) -> Bound<&T> {
        match &self.end {
            Bound::Included(value) => Bound::Included(value),
            Bound::Excluded(value) => Bound::Excluded(value),
            Bound::Unbounded => Bound::Unbounded,
        }
    }
}
114
115impl Type<Postgres> for PgRange<i32> {
116    fn type_info() -> PgTypeInfo {
117        PgTypeInfo::INT4_RANGE
118    }
119
120    fn compatible(ty: &PgTypeInfo) -> bool {
121        range_compatible::<i32>(ty)
122    }
123}
124
125impl Type<Postgres> for PgRange<i64> {
126    fn type_info() -> PgTypeInfo {
127        PgTypeInfo::INT8_RANGE
128    }
129
130    fn compatible(ty: &PgTypeInfo) -> bool {
131        range_compatible::<i64>(ty)
132    }
133}
134
135#[cfg(feature = "bigdecimal")]
136impl Type<Postgres> for PgRange<bigdecimal::BigDecimal> {
137    fn type_info() -> PgTypeInfo {
138        PgTypeInfo::NUM_RANGE
139    }
140
141    fn compatible(ty: &PgTypeInfo) -> bool {
142        range_compatible::<bigdecimal::BigDecimal>(ty)
143    }
144}
145
146#[cfg(feature = "rust_decimal")]
147impl Type<Postgres> for PgRange<rust_decimal::Decimal> {
148    fn type_info() -> PgTypeInfo {
149        PgTypeInfo::NUM_RANGE
150    }
151
152    fn compatible(ty: &PgTypeInfo) -> bool {
153        range_compatible::<rust_decimal::Decimal>(ty)
154    }
155}
156
157#[cfg(feature = "chrono")]
158impl Type<Postgres> for PgRange<chrono::NaiveDate> {
159    fn type_info() -> PgTypeInfo {
160        PgTypeInfo::DATE_RANGE
161    }
162
163    fn compatible(ty: &PgTypeInfo) -> bool {
164        range_compatible::<chrono::NaiveDate>(ty)
165    }
166}
167
168#[cfg(feature = "chrono")]
169impl Type<Postgres> for PgRange<chrono::NaiveDateTime> {
170    fn type_info() -> PgTypeInfo {
171        PgTypeInfo::TS_RANGE
172    }
173
174    fn compatible(ty: &PgTypeInfo) -> bool {
175        range_compatible::<chrono::NaiveDateTime>(ty)
176    }
177}
178
179#[cfg(feature = "chrono")]
180impl<Tz: chrono::TimeZone> Type<Postgres> for PgRange<chrono::DateTime<Tz>> {
181    fn type_info() -> PgTypeInfo {
182        PgTypeInfo::TSTZ_RANGE
183    }
184
185    fn compatible(ty: &PgTypeInfo) -> bool {
186        range_compatible::<chrono::DateTime<Tz>>(ty)
187    }
188}
189
190#[cfg(feature = "time")]
191impl Type<Postgres> for PgRange<time::Date> {
192    fn type_info() -> PgTypeInfo {
193        PgTypeInfo::DATE_RANGE
194    }
195
196    fn compatible(ty: &PgTypeInfo) -> bool {
197        range_compatible::<time::Date>(ty)
198    }
199}
200
201#[cfg(feature = "time")]
202impl Type<Postgres> for PgRange<time::PrimitiveDateTime> {
203    fn type_info() -> PgTypeInfo {
204        PgTypeInfo::TS_RANGE
205    }
206
207    fn compatible(ty: &PgTypeInfo) -> bool {
208        range_compatible::<time::PrimitiveDateTime>(ty)
209    }
210}
211
212#[cfg(feature = "time")]
213impl Type<Postgres> for PgRange<time::OffsetDateTime> {
214    fn type_info() -> PgTypeInfo {
215        PgTypeInfo::TSTZ_RANGE
216    }
217
218    fn compatible(ty: &PgTypeInfo) -> bool {
219        range_compatible::<time::OffsetDateTime>(ty)
220    }
221}
222
223impl PgHasArrayType for PgRange<i32> {
224    fn array_type_info() -> PgTypeInfo {
225        PgTypeInfo::INT4_RANGE_ARRAY
226    }
227}
228
229impl PgHasArrayType for PgRange<i64> {
230    fn array_type_info() -> PgTypeInfo {
231        PgTypeInfo::INT8_RANGE_ARRAY
232    }
233}
234
235#[cfg(feature = "bigdecimal")]
236impl PgHasArrayType for PgRange<bigdecimal::BigDecimal> {
237    fn array_type_info() -> PgTypeInfo {
238        PgTypeInfo::NUM_RANGE_ARRAY
239    }
240}
241
242#[cfg(feature = "rust_decimal")]
243impl PgHasArrayType for PgRange<rust_decimal::Decimal> {
244    fn array_type_info() -> PgTypeInfo {
245        PgTypeInfo::NUM_RANGE_ARRAY
246    }
247}
248
249#[cfg(feature = "chrono")]
250impl PgHasArrayType for PgRange<chrono::NaiveDate> {
251    fn array_type_info() -> PgTypeInfo {
252        PgTypeInfo::DATE_RANGE_ARRAY
253    }
254}
255
256#[cfg(feature = "chrono")]
257impl PgHasArrayType for PgRange<chrono::NaiveDateTime> {
258    fn array_type_info() -> PgTypeInfo {
259        PgTypeInfo::TS_RANGE_ARRAY
260    }
261}
262
263#[cfg(feature = "chrono")]
264impl<Tz: chrono::TimeZone> PgHasArrayType for PgRange<chrono::DateTime<Tz>> {
265    fn array_type_info() -> PgTypeInfo {
266        PgTypeInfo::TSTZ_RANGE_ARRAY
267    }
268}
269
270#[cfg(feature = "time")]
271impl PgHasArrayType for PgRange<time::Date> {
272    fn array_type_info() -> PgTypeInfo {
273        PgTypeInfo::DATE_RANGE_ARRAY
274    }
275}
276
277#[cfg(feature = "time")]
278impl PgHasArrayType for PgRange<time::PrimitiveDateTime> {
279    fn array_type_info() -> PgTypeInfo {
280        PgTypeInfo::TS_RANGE_ARRAY
281    }
282}
283
284#[cfg(feature = "time")]
285impl PgHasArrayType for PgRange<time::OffsetDateTime> {
286    fn array_type_info() -> PgTypeInfo {
287        PgTypeInfo::TSTZ_RANGE_ARRAY
288    }
289}
290
291impl<'q, T> Encode<'q, Postgres> for PgRange<T>
292where
293    T: Encode<'q, Postgres>,
294{
295    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
296        // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L245
297
298        let mut flags = RangeFlags::empty();
299
300        flags |= match self.start {
301            Bound::Included(_) => RangeFlags::LB_INC,
302            Bound::Unbounded => RangeFlags::LB_INF,
303            Bound::Excluded(_) => RangeFlags::empty(),
304        };
305
306        flags |= match self.end {
307            Bound::Included(_) => RangeFlags::UB_INC,
308            Bound::Unbounded => RangeFlags::UB_INF,
309            Bound::Excluded(_) => RangeFlags::empty(),
310        };
311
312        buf.push(flags.bits());
313
314        if let Bound::Included(v) | Bound::Excluded(v) = &self.start {
315            buf.encode(v)?;
316        }
317
318        if let Bound::Included(v) | Bound::Excluded(v) = &self.end {
319            buf.encode(v)?;
320        }
321
322        // ranges are themselves never null
323        Ok(IsNull::No)
324    }
325}
326
327impl<'r, T> Decode<'r, Postgres> for PgRange<T>
328where
329    T: Type<Postgres> + for<'a> Decode<'a, Postgres>,
330{
331    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
332        match value.format {
333            PgValueFormat::Binary => {
334                let element_ty = if let PgTypeKind::Range(element) = &value.type_info.0.kind() {
335                    element
336                } else {
337                    return Err(format!("unexpected non-range type {}", value.type_info).into());
338                };
339
340                let mut buf = value.as_bytes()?;
341
342                let mut start = Bound::Unbounded;
343                let mut end = Bound::Unbounded;
344
345                let flags = RangeFlags::from_bits_truncate(buf.get_u8());
346
347                if flags.contains(RangeFlags::EMPTY) {
348                    return Ok(PgRange { start, end });
349                }
350
351                if !flags.contains(RangeFlags::LB_INF) {
352                    let value =
353                        T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?;
354
355                    start = if flags.contains(RangeFlags::LB_INC) {
356                        Bound::Included(value)
357                    } else {
358                        Bound::Excluded(value)
359                    };
360                }
361
362                if !flags.contains(RangeFlags::UB_INF) {
363                    let value =
364                        T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?;
365
366                    end = if flags.contains(RangeFlags::UB_INC) {
367                        Bound::Included(value)
368                    } else {
369                        Bound::Excluded(value)
370                    };
371                }
372
373                Ok(PgRange { start, end })
374            }
375
376            PgValueFormat::Text => {
377                // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L2046
378
379                let mut start = None;
380                let mut end = None;
381
382                let s = value.as_str()?;
383
384                // remember the bounds
385                let sb = s.as_bytes();
386                let lower = sb[0] as char;
387                let upper = sb[sb.len() - 1] as char;
388
389                // trim the wrapping braces/brackets
390                let s = &s[1..(s.len() - 1)];
391
392                let mut chars = s.chars();
393
394                let mut element = String::new();
395                let mut done = false;
396                let mut quoted = false;
397                let mut in_quotes = false;
398                let mut in_escape = false;
399                let mut prev_ch = '\0';
400                let mut count = 0;
401
402                while !done {
403                    element.clear();
404
405                    loop {
406                        match chars.next() {
407                            Some(ch) => {
408                                match ch {
409                                    _ if in_escape => {
410                                        element.push(ch);
411                                        in_escape = false;
412                                    }
413
414                                    '"' if in_quotes => {
415                                        in_quotes = false;
416                                    }
417
418                                    '"' => {
419                                        in_quotes = true;
420                                        quoted = true;
421
422                                        if prev_ch == '"' {
423                                            element.push('"')
424                                        }
425                                    }
426
427                                    '\\' if !in_escape => {
428                                        in_escape = true;
429                                    }
430
431                                    ',' if !in_quotes => break,
432
433                                    _ => {
434                                        element.push(ch);
435                                    }
436                                }
437                                prev_ch = ch;
438                            }
439
440                            None => {
441                                done = true;
442                                break;
443                            }
444                        }
445                    }
446
447                    count += 1;
448                    if !element.is_empty() || quoted {
449                        let value = Some(T::decode(PgValueRef {
450                            type_info: T::type_info(),
451                            format: PgValueFormat::Text,
452                            value: Some(element.as_bytes()),
453                            row: None,
454                        })?);
455
456                        if count == 1 {
457                            start = value;
458                        } else if count == 2 {
459                            end = value;
460                        } else {
461                            return Err("more than 2 elements found in a range".into());
462                        }
463                    }
464                }
465
466                let start = parse_bound(lower, start)?;
467                let end = parse_bound(upper, end)?;
468
469                Ok(PgRange { start, end })
470            }
471        }
472    }
473}
474
475fn parse_bound<T>(ch: char, value: Option<T>) -> Result<Bound<T>, BoxDynError> {
476    Ok(if let Some(value) = value {
477        match ch {
478            '(' | ')' => Bound::Excluded(value),
479            '[' | ']' => Bound::Included(value),
480
481            _ => {
482                return Err(format!(
483                    "expected `(`, ')', '[', or `]` but found `{ch}` for range literal"
484                )
485                .into());
486            }
487        }
488    } else {
489        Bound::Unbounded
490    })
491}
492
493impl<T> Display for PgRange<T>
494where
495    T: Display,
496{
497    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
498        match &self.start {
499            Bound::Unbounded => f.write_str("(,")?,
500            Bound::Excluded(v) => write!(f, "({v},")?,
501            Bound::Included(v) => write!(f, "[{v},")?,
502        }
503
504        match &self.end {
505            Bound::Unbounded => f.write_str(")")?,
506            Bound::Excluded(v) => write!(f, "{v})")?,
507            Bound::Included(v) => write!(f, "{v}]")?,
508        }
509
510        Ok(())
511    }
512}
513
514fn range_compatible<E: Type<Postgres>>(ty: &PgTypeInfo) -> bool {
515    // we require the declared type to be a _range_ with an
516    // element type that is acceptable
517    if let PgTypeKind::Range(element) = &ty.kind() {
518        return E::compatible(element);
519    }
520
521    false
522}