sqlx_postgres/types/
array.rs

1use sqlx_core::bytes::Buf;
2use sqlx_core::types::Text;
3use std::borrow::Cow;
4
5use crate::decode::Decode;
6use crate::encode::{Encode, IsNull};
7use crate::error::BoxDynError;
8use crate::type_info::PgType;
9use crate::types::Oid;
10use crate::types::Type;
11use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
12
13/// Provides information necessary to encode and decode Postgres arrays as compatible Rust types.
14///
15/// Implementing this trait for some type `T` enables relevant `Type`,`Encode` and `Decode` impls
16/// for `Vec<T>`, `&[T]` (slices), `[T; N]` (arrays), etc.
17///
18/// ### Note: `#[derive(sqlx::Type)]`
19/// If you have the `postgres` feature enabled, `#[derive(sqlx::Type)]` will also generate
20/// an impl of this trait for your type if your wrapper is marked `#[sqlx(transparent)]`:
21///
22/// ```rust,ignore
23/// #[derive(sqlx::Type)]
24/// #[sqlx(transparent)]
25/// struct UserId(i64);
26///
/// let user_ids: Vec<UserId> = sqlx::query_scalar("select '{ 123, 456 }'::int8[]")
///    .fetch_one(&mut pg_connection)
///    .await?;
30/// ```
31///
32/// However, this may cause an error if the type being wrapped does not implement `PgHasArrayType`,
33/// e.g. `Vec` itself, because we don't currently support multidimensional arrays:
34///
35/// ```rust,ignore
36/// #[derive(sqlx::Type)] // ERROR: `Vec<i64>` does not implement `PgHasArrayType`
37/// #[sqlx(transparent)]
38/// struct UserIds(Vec<i64>);
39/// ```
40///
41/// To remedy this, add `#[sqlx(no_pg_array)]`, which disables the generation
42/// of the `PgHasArrayType` impl:
43///
44/// ```rust,ignore
45/// #[derive(sqlx::Type)]
46/// #[sqlx(transparent, no_pg_array)]
47/// struct UserIds(Vec<i64>);
48/// ```
49///
50/// See [the documentation of `Type`][Type] for more details.
51pub trait PgHasArrayType {
52    fn array_type_info() -> PgTypeInfo;
53    fn array_compatible(ty: &PgTypeInfo) -> bool {
54        *ty == Self::array_type_info()
55    }
56}
57
58impl<T> PgHasArrayType for &T
59where
60    T: PgHasArrayType,
61{
62    fn array_type_info() -> PgTypeInfo {
63        T::array_type_info()
64    }
65
66    fn array_compatible(ty: &PgTypeInfo) -> bool {
67        T::array_compatible(ty)
68    }
69}
70
71impl<T> PgHasArrayType for Option<T>
72where
73    T: PgHasArrayType,
74{
75    fn array_type_info() -> PgTypeInfo {
76        T::array_type_info()
77    }
78
79    fn array_compatible(ty: &PgTypeInfo) -> bool {
80        T::array_compatible(ty)
81    }
82}
83
84impl<T> PgHasArrayType for Text<T> {
85    fn array_type_info() -> PgTypeInfo {
86        String::array_type_info()
87    }
88
89    fn array_compatible(ty: &PgTypeInfo) -> bool {
90        String::array_compatible(ty)
91    }
92}
93
94impl<T> Type<Postgres> for [T]
95where
96    T: PgHasArrayType,
97{
98    fn type_info() -> PgTypeInfo {
99        T::array_type_info()
100    }
101
102    fn compatible(ty: &PgTypeInfo) -> bool {
103        T::array_compatible(ty)
104    }
105}
106
107impl<T> Type<Postgres> for Vec<T>
108where
109    T: PgHasArrayType,
110{
111    fn type_info() -> PgTypeInfo {
112        T::array_type_info()
113    }
114
115    fn compatible(ty: &PgTypeInfo) -> bool {
116        T::array_compatible(ty)
117    }
118}
119
120impl<T, const N: usize> Type<Postgres> for [T; N]
121where
122    T: PgHasArrayType,
123{
124    fn type_info() -> PgTypeInfo {
125        T::array_type_info()
126    }
127
128    fn compatible(ty: &PgTypeInfo) -> bool {
129        T::array_compatible(ty)
130    }
131}
132
133impl<'q, T> Encode<'q, Postgres> for Vec<T>
134where
135    for<'a> &'a [T]: Encode<'q, Postgres>,
136    T: Encode<'q, Postgres>,
137{
138    #[inline]
139    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
140        self.as_slice().encode_by_ref(buf)
141    }
142}
143
144impl<'q, T, const N: usize> Encode<'q, Postgres> for [T; N]
145where
146    for<'a> &'a [T]: Encode<'q, Postgres>,
147    T: Encode<'q, Postgres>,
148{
149    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
150        self.as_slice().encode_by_ref(buf)
151    }
152}
153
154impl<'q, T> Encode<'q, Postgres> for &'_ [T]
155where
156    T: Encode<'q, Postgres> + Type<Postgres>,
157{
158    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
159        let type_info = self
160            .first()
161            .and_then(Encode::produces)
162            .unwrap_or_else(T::type_info);
163
164        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
165        buf.extend(&0_i32.to_be_bytes()); // flags
166
167        // element type
168        match type_info.0 {
169            PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
170            PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
171
172            ty => {
173                buf.extend(&ty.oid().0.to_be_bytes());
174            }
175        }
176
177        let array_len = i32::try_from(self.len()).map_err(|_| {
178            format!(
179                "encoded array length is too large for Postgres: {}",
180                self.len()
181            )
182        })?;
183
184        buf.extend(array_len.to_be_bytes()); // len
185        buf.extend(&1_i32.to_be_bytes()); // lower bound
186
187        for element in self.iter() {
188            buf.encode(element)?;
189        }
190
191        Ok(IsNull::No)
192    }
193}
194
195impl<'r, T, const N: usize> Decode<'r, Postgres> for [T; N]
196where
197    T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
198{
199    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
200        // This could be done more efficiently by refactoring the Vec decoding below so that it can
201        // be used for arrays and Vec.
202        let vec: Vec<T> = Decode::decode(value)?;
203        let array: [T; N] = vec.try_into().map_err(|_| "wrong number of elements")?;
204        Ok(array)
205    }
206}
207
208impl<'r, T> Decode<'r, Postgres> for Vec<T>
209where
210    T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
211{
212    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
213        let format = value.format();
214
215        match format {
216            PgValueFormat::Binary => {
217                // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L1548
218
219                let mut buf = value.as_bytes()?;
220
221                // number of dimensions in the array
222                let ndim = buf.get_i32();
223
224                if ndim == 0 {
225                    // zero dimensions is an empty array
226                    return Ok(Vec::new());
227                }
228
229                if ndim != 1 {
230                    return Err(format!("encountered an array of {ndim} dimensions; only one-dimensional arrays are supported").into());
231                }
232
233                // appears to have been used in the past to communicate potential NULLS
234                // but reading source code back through our supported postgres versions (9.5+)
235                // this is never used for anything
236                let _flags = buf.get_i32();
237
238                // the OID of the element
239                let element_type_oid = Oid(buf.get_u32());
240                let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid)
241                    .or_else(|| value.type_info.try_array_element().map(Cow::into_owned))
242                    .ok_or_else(|| {
243                        BoxDynError::from(format!(
244                            "failed to resolve array element type for oid {}",
245                            element_type_oid.0
246                        ))
247                    })?;
248
249                // length of the array axis
250                let len = buf.get_i32();
251
252                let len = usize::try_from(len)
253                    .map_err(|_| format!("overflow converting array len ({len}) to usize"))?;
254
255                // the lower bound, we only support arrays starting from "1"
256                let lower = buf.get_i32();
257
258                if lower != 1 {
259                    return Err(format!("encountered an array with a lower bound of {lower} in the first dimension; only arrays starting at one are supported").into());
260                }
261
262                let mut elements = Vec::with_capacity(len);
263
264                for _ in 0..len {
265                    let value_ref = PgValueRef::get(&mut buf, format, element_type_info.clone())?;
266
267                    elements.push(T::decode(value_ref)?);
268                }
269
270                Ok(elements)
271            }
272
273            PgValueFormat::Text => {
274                // no type is provided from the database for the element
275                let element_type_info = T::type_info();
276
277                let s = value.as_str()?;
278
279                // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L718
280
281                // trim the wrapping braces
282                let s = &s[1..(s.len() - 1)];
283
284                if s.is_empty() {
285                    // short-circuit empty arrays up here
286                    return Ok(Vec::new());
287                }
288
289                // NOTE: Nearly *all* types use ',' as the sequence delimiter. Yes, there is one
290                //       that does not. The BOX (not PostGIS) type uses ';' as a delimiter.
291
292                // TODO: When we add support for BOX we need to figure out some way to make the
293                //       delimiter selection
294
295                let delimiter = ',';
296                let mut done = false;
297                let mut in_quotes = false;
298                let mut in_escape = false;
299                let mut value = String::with_capacity(10);
300                let mut chars = s.chars();
301                let mut elements = Vec::with_capacity(4);
302
303                while !done {
304                    loop {
305                        match chars.next() {
306                            Some(ch) => match ch {
307                                _ if in_escape => {
308                                    value.push(ch);
309                                    in_escape = false;
310                                }
311
312                                '"' => {
313                                    in_quotes = !in_quotes;
314                                }
315
316                                '\\' => {
317                                    in_escape = true;
318                                }
319
320                                _ if ch == delimiter && !in_quotes => {
321                                    break;
322                                }
323
324                                _ => {
325                                    value.push(ch);
326                                }
327                            },
328
329                            None => {
330                                done = true;
331                                break;
332                            }
333                        }
334                    }
335
336                    let value_opt = if value == "NULL" {
337                        None
338                    } else {
339                        Some(value.as_bytes())
340                    };
341
342                    elements.push(T::decode(PgValueRef {
343                        value: value_opt,
344                        row: None,
345                        type_info: element_type_info.clone(),
346                        format,
347                    })?);
348
349                    value.clear();
350                }
351
352                Ok(elements)
353            }
354        }
355    }
356}