1use std::fmt::{self, Debug, Display, Formatter};
2use std::ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive};
3
4use bitflags::bitflags;
5use sqlx_core::bytes::Buf;
6
7use crate::decode::Decode;
8use crate::encode::{Encode, IsNull};
9use crate::error::BoxDynError;
10use crate::type_info::PgTypeKind;
11use crate::types::Type;
12use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
13
14bitflags! {
16 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
17 struct RangeFlags: u8 {
18 const EMPTY = 0x01;
19 const LB_INC = 0x02;
20 const UB_INC = 0x04;
21 const LB_INF = 0x08;
22 const UB_INF = 0x10;
23 const LB_NULL = 0x20; const UB_NULL = 0x40; const CONTAIN_EMPTY = 0x80; }
27}
28
/// A PostgreSQL range value made of a lower bound (`start`) and an upper
/// bound (`end`), each of which may be inclusive, exclusive, or unbounded.
///
/// Values can be built from a pair of [`Bound`]s (array or tuple) or from any
/// of the standard Rust range types: `a..b`, `a..`, `a..=b`, `..b`, `..=b`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct PgRange<T> {
    /// Lower bound of the range.
    pub start: Bound<T>,
    /// Upper bound of the range.
    pub end: Bound<T>,
}

/// `[lower, upper]` → range with exactly those bounds.
impl<T> From<[Bound<T>; 2]> for PgRange<T> {
    fn from([start, end]: [Bound<T>; 2]) -> Self {
        Self { start, end }
    }
}

/// `(lower, upper)` → range with exactly those bounds.
impl<T> From<(Bound<T>, Bound<T>)> for PgRange<T> {
    fn from((start, end): (Bound<T>, Bound<T>)) -> Self {
        Self { start, end }
    }
}

/// `a..b` → `[a,b)`.
impl<T> From<Range<T>> for PgRange<T> {
    fn from(Range { start, end }: Range<T>) -> Self {
        Self {
            start: Bound::Included(start),
            end: Bound::Excluded(end),
        }
    }
}

/// `a..` → `[a,)`.
impl<T> From<RangeFrom<T>> for PgRange<T> {
    fn from(RangeFrom { start }: RangeFrom<T>) -> Self {
        Self {
            start: Bound::Included(start),
            end: Bound::Unbounded,
        }
    }
}

/// `a..=b` → `[a,b]`.
impl<T> From<RangeInclusive<T>> for PgRange<T> {
    fn from(v: RangeInclusive<T>) -> Self {
        // `RangeInclusive` has private fields, so unpack via `into_inner`.
        let (lower, upper) = v.into_inner();
        Self {
            start: Bound::Included(lower),
            end: Bound::Included(upper),
        }
    }
}

/// `..b` → `(,b)`.
impl<T> From<RangeTo<T>> for PgRange<T> {
    fn from(RangeTo { end }: RangeTo<T>) -> Self {
        Self {
            start: Bound::Unbounded,
            end: Bound::Excluded(end),
        }
    }
}

/// `..=b` → `(,b]`.
impl<T> From<RangeToInclusive<T>> for PgRange<T> {
    fn from(RangeToInclusive { end }: RangeToInclusive<T>) -> Self {
        Self {
            start: Bound::Unbounded,
            end: Bound::Included(end),
        }
    }
}

/// Exposes the bounds by reference so `PgRange` works with std range APIs
/// such as [`RangeBounds::contains`].
impl<T> RangeBounds<T> for PgRange<T> {
    fn start_bound(&self) -> Bound<&T> {
        match &self.start {
            Bound::Included(value) => Bound::Included(value),
            Bound::Excluded(value) => Bound::Excluded(value),
            Bound::Unbounded => Bound::Unbounded,
        }
    }

    fn end_bound(&self) -> Bound<&T> {
        match &self.end {
            Bound::Included(value) => Bound::Included(value),
            Bound::Excluded(value) => Bound::Excluded(value),
            Bound::Unbounded => Bound::Unbounded,
        }
    }
}
114
115impl Type<Postgres> for PgRange<i32> {
116 fn type_info() -> PgTypeInfo {
117 PgTypeInfo::INT4_RANGE
118 }
119
120 fn compatible(ty: &PgTypeInfo) -> bool {
121 range_compatible::<i32>(ty)
122 }
123}
124
125impl Type<Postgres> for PgRange<i64> {
126 fn type_info() -> PgTypeInfo {
127 PgTypeInfo::INT8_RANGE
128 }
129
130 fn compatible(ty: &PgTypeInfo) -> bool {
131 range_compatible::<i64>(ty)
132 }
133}
134
135#[cfg(feature = "bigdecimal")]
136impl Type<Postgres> for PgRange<bigdecimal::BigDecimal> {
137 fn type_info() -> PgTypeInfo {
138 PgTypeInfo::NUM_RANGE
139 }
140
141 fn compatible(ty: &PgTypeInfo) -> bool {
142 range_compatible::<bigdecimal::BigDecimal>(ty)
143 }
144}
145
146#[cfg(feature = "rust_decimal")]
147impl Type<Postgres> for PgRange<rust_decimal::Decimal> {
148 fn type_info() -> PgTypeInfo {
149 PgTypeInfo::NUM_RANGE
150 }
151
152 fn compatible(ty: &PgTypeInfo) -> bool {
153 range_compatible::<rust_decimal::Decimal>(ty)
154 }
155}
156
157#[cfg(feature = "chrono")]
158impl Type<Postgres> for PgRange<chrono::NaiveDate> {
159 fn type_info() -> PgTypeInfo {
160 PgTypeInfo::DATE_RANGE
161 }
162
163 fn compatible(ty: &PgTypeInfo) -> bool {
164 range_compatible::<chrono::NaiveDate>(ty)
165 }
166}
167
168#[cfg(feature = "chrono")]
169impl Type<Postgres> for PgRange<chrono::NaiveDateTime> {
170 fn type_info() -> PgTypeInfo {
171 PgTypeInfo::TS_RANGE
172 }
173
174 fn compatible(ty: &PgTypeInfo) -> bool {
175 range_compatible::<chrono::NaiveDateTime>(ty)
176 }
177}
178
179#[cfg(feature = "chrono")]
180impl<Tz: chrono::TimeZone> Type<Postgres> for PgRange<chrono::DateTime<Tz>> {
181 fn type_info() -> PgTypeInfo {
182 PgTypeInfo::TSTZ_RANGE
183 }
184
185 fn compatible(ty: &PgTypeInfo) -> bool {
186 range_compatible::<chrono::DateTime<Tz>>(ty)
187 }
188}
189
190#[cfg(feature = "time")]
191impl Type<Postgres> for PgRange<time::Date> {
192 fn type_info() -> PgTypeInfo {
193 PgTypeInfo::DATE_RANGE
194 }
195
196 fn compatible(ty: &PgTypeInfo) -> bool {
197 range_compatible::<time::Date>(ty)
198 }
199}
200
201#[cfg(feature = "time")]
202impl Type<Postgres> for PgRange<time::PrimitiveDateTime> {
203 fn type_info() -> PgTypeInfo {
204 PgTypeInfo::TS_RANGE
205 }
206
207 fn compatible(ty: &PgTypeInfo) -> bool {
208 range_compatible::<time::PrimitiveDateTime>(ty)
209 }
210}
211
212#[cfg(feature = "time")]
213impl Type<Postgres> for PgRange<time::OffsetDateTime> {
214 fn type_info() -> PgTypeInfo {
215 PgTypeInfo::TSTZ_RANGE
216 }
217
218 fn compatible(ty: &PgTypeInfo) -> bool {
219 range_compatible::<time::OffsetDateTime>(ty)
220 }
221}
222
223impl PgHasArrayType for PgRange<i32> {
224 fn array_type_info() -> PgTypeInfo {
225 PgTypeInfo::INT4_RANGE_ARRAY
226 }
227}
228
229impl PgHasArrayType for PgRange<i64> {
230 fn array_type_info() -> PgTypeInfo {
231 PgTypeInfo::INT8_RANGE_ARRAY
232 }
233}
234
235#[cfg(feature = "bigdecimal")]
236impl PgHasArrayType for PgRange<bigdecimal::BigDecimal> {
237 fn array_type_info() -> PgTypeInfo {
238 PgTypeInfo::NUM_RANGE_ARRAY
239 }
240}
241
242#[cfg(feature = "rust_decimal")]
243impl PgHasArrayType for PgRange<rust_decimal::Decimal> {
244 fn array_type_info() -> PgTypeInfo {
245 PgTypeInfo::NUM_RANGE_ARRAY
246 }
247}
248
249#[cfg(feature = "chrono")]
250impl PgHasArrayType for PgRange<chrono::NaiveDate> {
251 fn array_type_info() -> PgTypeInfo {
252 PgTypeInfo::DATE_RANGE_ARRAY
253 }
254}
255
256#[cfg(feature = "chrono")]
257impl PgHasArrayType for PgRange<chrono::NaiveDateTime> {
258 fn array_type_info() -> PgTypeInfo {
259 PgTypeInfo::TS_RANGE_ARRAY
260 }
261}
262
263#[cfg(feature = "chrono")]
264impl<Tz: chrono::TimeZone> PgHasArrayType for PgRange<chrono::DateTime<Tz>> {
265 fn array_type_info() -> PgTypeInfo {
266 PgTypeInfo::TSTZ_RANGE_ARRAY
267 }
268}
269
270#[cfg(feature = "time")]
271impl PgHasArrayType for PgRange<time::Date> {
272 fn array_type_info() -> PgTypeInfo {
273 PgTypeInfo::DATE_RANGE_ARRAY
274 }
275}
276
277#[cfg(feature = "time")]
278impl PgHasArrayType for PgRange<time::PrimitiveDateTime> {
279 fn array_type_info() -> PgTypeInfo {
280 PgTypeInfo::TS_RANGE_ARRAY
281 }
282}
283
284#[cfg(feature = "time")]
285impl PgHasArrayType for PgRange<time::OffsetDateTime> {
286 fn array_type_info() -> PgTypeInfo {
287 PgTypeInfo::TSTZ_RANGE_ARRAY
288 }
289}
290
291impl<'q, T> Encode<'q, Postgres> for PgRange<T>
292where
293 T: Encode<'q, Postgres>,
294{
295 fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
296 let mut flags = RangeFlags::empty();
299
300 flags |= match self.start {
301 Bound::Included(_) => RangeFlags::LB_INC,
302 Bound::Unbounded => RangeFlags::LB_INF,
303 Bound::Excluded(_) => RangeFlags::empty(),
304 };
305
306 flags |= match self.end {
307 Bound::Included(_) => RangeFlags::UB_INC,
308 Bound::Unbounded => RangeFlags::UB_INF,
309 Bound::Excluded(_) => RangeFlags::empty(),
310 };
311
312 buf.push(flags.bits());
313
314 if let Bound::Included(v) | Bound::Excluded(v) = &self.start {
315 buf.encode(v)?;
316 }
317
318 if let Bound::Included(v) | Bound::Excluded(v) = &self.end {
319 buf.encode(v)?;
320 }
321
322 Ok(IsNull::No)
324 }
325}
326
327impl<'r, T> Decode<'r, Postgres> for PgRange<T>
328where
329 T: Type<Postgres> + for<'a> Decode<'a, Postgres>,
330{
331 fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
332 match value.format {
333 PgValueFormat::Binary => {
334 let element_ty = if let PgTypeKind::Range(element) = &value.type_info.0.kind() {
335 element
336 } else {
337 return Err(format!("unexpected non-range type {}", value.type_info).into());
338 };
339
340 let mut buf = value.as_bytes()?;
341
342 let mut start = Bound::Unbounded;
343 let mut end = Bound::Unbounded;
344
345 let flags = RangeFlags::from_bits_truncate(buf.get_u8());
346
347 if flags.contains(RangeFlags::EMPTY) {
348 return Ok(PgRange { start, end });
349 }
350
351 if !flags.contains(RangeFlags::LB_INF) {
352 let value =
353 T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?;
354
355 start = if flags.contains(RangeFlags::LB_INC) {
356 Bound::Included(value)
357 } else {
358 Bound::Excluded(value)
359 };
360 }
361
362 if !flags.contains(RangeFlags::UB_INF) {
363 let value =
364 T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?;
365
366 end = if flags.contains(RangeFlags::UB_INC) {
367 Bound::Included(value)
368 } else {
369 Bound::Excluded(value)
370 };
371 }
372
373 Ok(PgRange { start, end })
374 }
375
376 PgValueFormat::Text => {
377 let mut start = None;
380 let mut end = None;
381
382 let s = value.as_str()?;
383
384 let sb = s.as_bytes();
386 let lower = sb[0] as char;
387 let upper = sb[sb.len() - 1] as char;
388
389 let s = &s[1..(s.len() - 1)];
391
392 let mut chars = s.chars();
393
394 let mut element = String::new();
395 let mut done = false;
396 let mut quoted = false;
397 let mut in_quotes = false;
398 let mut in_escape = false;
399 let mut prev_ch = '\0';
400 let mut count = 0;
401
402 while !done {
403 element.clear();
404
405 loop {
406 match chars.next() {
407 Some(ch) => {
408 match ch {
409 _ if in_escape => {
410 element.push(ch);
411 in_escape = false;
412 }
413
414 '"' if in_quotes => {
415 in_quotes = false;
416 }
417
418 '"' => {
419 in_quotes = true;
420 quoted = true;
421
422 if prev_ch == '"' {
423 element.push('"')
424 }
425 }
426
427 '\\' if !in_escape => {
428 in_escape = true;
429 }
430
431 ',' if !in_quotes => break,
432
433 _ => {
434 element.push(ch);
435 }
436 }
437 prev_ch = ch;
438 }
439
440 None => {
441 done = true;
442 break;
443 }
444 }
445 }
446
447 count += 1;
448 if !element.is_empty() || quoted {
449 let value = Some(T::decode(PgValueRef {
450 type_info: T::type_info(),
451 format: PgValueFormat::Text,
452 value: Some(element.as_bytes()),
453 row: None,
454 })?);
455
456 if count == 1 {
457 start = value;
458 } else if count == 2 {
459 end = value;
460 } else {
461 return Err("more than 2 elements found in a range".into());
462 }
463 }
464 }
465
466 let start = parse_bound(lower, start)?;
467 let end = parse_bound(upper, end)?;
468
469 Ok(PgRange { start, end })
470 }
471 }
472 }
473}
474
475fn parse_bound<T>(ch: char, value: Option<T>) -> Result<Bound<T>, BoxDynError> {
476 Ok(if let Some(value) = value {
477 match ch {
478 '(' | ')' => Bound::Excluded(value),
479 '[' | ']' => Bound::Included(value),
480
481 _ => {
482 return Err(format!(
483 "expected `(`, ')', '[', or `]` but found `{ch}` for range literal"
484 )
485 .into());
486 }
487 }
488 } else {
489 Bound::Unbounded
490 })
491}
492
493impl<T> Display for PgRange<T>
494where
495 T: Display,
496{
497 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
498 match &self.start {
499 Bound::Unbounded => f.write_str("(,")?,
500 Bound::Excluded(v) => write!(f, "({v},")?,
501 Bound::Included(v) => write!(f, "[{v},")?,
502 }
503
504 match &self.end {
505 Bound::Unbounded => f.write_str(")")?,
506 Bound::Excluded(v) => write!(f, "{v})")?,
507 Bound::Included(v) => write!(f, "{v}]")?,
508 }
509
510 Ok(())
511 }
512}
513
514fn range_compatible<E: Type<Postgres>>(ty: &PgTypeInfo) -> bool {
515 if let PgTypeKind::Range(element) = &ty.kind() {
518 return E::compatible(element);
519 }
520
521 false
522}