sqlx_postgres/types/cube.rs

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
use sqlx_core::bytes::Buf;
use sqlx_core::Error;
use std::mem;
use std::str::FromStr;

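// Width in bytes of each coordinate on the wire (a big-endian `f64`).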
const BYTE_WIDTH: usize = 8;

/// <https://github.com/postgres/postgres/blob/e3ec9dc1bf4983fcedb6f43c71ea12ee26aefc7a/contrib/cube/cubedata.h#L7>
const MAX_DIMENSIONS: usize = 100;

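// High bit of the packed binary header; set when the cube is stored as a single
// point (see the `cubedata.h` reference above).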
const IS_POINT_FLAG: u32 = 1 << 31;

// FIXME(breaking): these variants are confusingly named and structured
// consider changing them or making this an opaque wrapper around `Vec<f64>`
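/// The PostgreSQL [`cube`](https://www.postgresql.org/docs/current/cube.html) extension type,
/// representing a point or an axis-aligned box in N-dimensional space.
///
/// Binding works like any other parameter; a minimal sketch, assuming the `cube`
/// extension is installed, a hypothetical `boxes (bounds cube)` table exists, and
/// the type is reachable via the facade re-export `sqlx::postgres::types::PgCube`:
///
/// ```rust,ignore
/// use sqlx::postgres::types::PgCube;
///
/// async fn insert_box(pool: &sqlx::PgPool) -> Result<(), sqlx::Error> {
///     // A 2-D cube with lower-left corner (1, 2) and upper-right corner (3, 4).
///     sqlx::query("INSERT INTO boxes (bounds) VALUES ($1)") // `boxes` is illustrative only
///         .bind(PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]))
///         .execute(pool)
///         .await?;
///     Ok(())
/// }
/// ```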
#[derive(Debug, Clone, PartialEq)]
pub enum PgCube {
    /// A one-dimensional point.
    // FIXME: `Point1D(f64)`
    Point(f64),
    /// An N-dimensional point ("represented internally as a zero-volume cube").
    // FIXME: `PointND(Vec<f64>)`
    ZeroVolume(Vec<f64>),

    /// A one-dimensional interval with starting and ending points.
    // FIXME: `Interval1D { start: f64, end: f64 }`
    OneDimensionInterval(f64, f64),

    // FIXME: add `Cube3D { lower_left: [f64; 3], upper_right: [f64; 3] }`?
    /// An N-dimensional cube with points representing lower-left and upper-right corners, respectively.
    // FIXME: `CubeND { lower_left: Vec<f64>, upper_right: Vec<f64> }`
    MultiDimension(Vec<Vec<f64>>),
}

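// Decoded form of the 4-byte binary header that precedes the coordinate data:
// the high bit of the packed big-endian `u32` marks a zero-volume point and the
// remaining bits carry the dimension count.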
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct Header {
    dimensions: usize,
    is_point: bool,
}

#[derive(Debug, thiserror::Error)]
#[error("error decoding CUBE (is_point: {is_point}, dimensions: {dimensions}): {message}")]
struct DecodeError {
    is_point: bool,
    dimensions: usize,
    message: String,
}

impl Type<Postgres> for PgCube {
    fn type_info() -> PgTypeInfo {
        PgTypeInfo::with_name("cube")
    }
}

impl PgHasArrayType for PgCube {
    fn array_type_info() -> PgTypeInfo {
        PgTypeInfo::with_name("_cube")
    }
}

impl<'r> Decode<'r, Postgres> for PgCube {
    fn decode(value: PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        match value.format() {
            PgValueFormat::Text => Ok(PgCube::from_str(value.as_str()?)?),
            PgValueFormat::Binary => Ok(PgCube::from_bytes(value.as_bytes()?)?),
        }
    }
}

impl<'q> Encode<'q, Postgres> for PgCube {
    fn produces(&self) -> Option<PgTypeInfo> {
        Some(PgTypeInfo::with_name("cube"))
    }

    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
        self.serialize(buf)?;
        Ok(IsNull::No)
    }

    fn size_hint(&self) -> usize {
        self.header().encoded_size()
    }
}

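// Parses the `cube` text representations exercised by the tests below, e.g.
// `2`, `(2)`, `(2,3,4)`, `(7),(8)` and `((1,2),(3,4))`, with or without spaces.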
impl FromStr for PgCube {
    type Err = Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let content = s
            .trim_start_matches('(')
            .trim_start_matches('[')
            .trim_end_matches(')')
            .trim_end_matches(']')
            .replace(' ', "");

        if !content.contains('(') && !content.contains(',') {
            return parse_point(&content);
        }

        if !content.contains("),(") {
            return parse_zero_volume(&content);
        }

        let point_vecs = content.split("),(").collect::<Vec<&str>>();
        if point_vecs.len() == 2 && !point_vecs.iter().any(|pv| pv.contains(',')) {
            return parse_one_dimensional_interval(point_vecs);
        }

        parse_multidimensional_interval(point_vecs)
    }
}

impl PgCube {
    fn header(&self) -> Header {
        match self {
            PgCube::Point(..) => Header {
                is_point: true,
                dimensions: 1,
            },
            PgCube::ZeroVolume(values) => Header {
                is_point: true,
                dimensions: values.len(),
            },
            PgCube::OneDimensionInterval(..) => Header {
                is_point: false,
                dimensions: 1,
            },
            PgCube::MultiDimension(multi_values) => Header {
                is_point: false,
                dimensions: multi_values.first().map(|arr| arr.len()).unwrap_or(0),
            },
        }
    }

    fn from_bytes(mut bytes: &[u8]) -> Result<Self, BoxDynError> {
        let header = Header::try_read(&mut bytes)?;

        if bytes.len() != header.data_size() {
            return Err(DecodeError::new(
                &header,
                format!(
                    "expected {} bytes after header, got {}",
                    header.data_size(),
                    bytes.len()
                ),
            )
            .into());
        }

        match (header.is_point, header.dimensions) {
            (true, 1) => Ok(PgCube::Point(bytes.get_f64())),
            (true, _) => Ok(PgCube::ZeroVolume(
                read_vec(&mut bytes).map_err(|e| DecodeError::new(&header, e))?,
            )),
            (false, 1) => Ok(PgCube::OneDimensionInterval(
                bytes.get_f64(),
                bytes.get_f64(),
            )),
            (false, _) => Ok(PgCube::MultiDimension(read_cube(&header, bytes)?)),
        }
    }

    fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> {
        let header = self.header();

        buff.reserve(header.data_size());

        header.try_write(buff)?;

        match self {
            PgCube::Point(value) => {
                buff.extend_from_slice(&value.to_be_bytes());
            }
            PgCube::ZeroVolume(values) => {
                buff.extend(values.iter().flat_map(|v| v.to_be_bytes()));
            }
            PgCube::OneDimensionInterval(x, y) => {
                buff.extend_from_slice(&x.to_be_bytes());
                buff.extend_from_slice(&y.to_be_bytes());
            }
            PgCube::MultiDimension(multi_values) => {
                if multi_values.len() != 2 {
                    return Err(format!("invalid CUBE value: {self:?}"));
                }

                buff.extend(
                    multi_values
                        .iter()
                        .flat_map(|point| point.iter().flat_map(|scalar| scalar.to_be_bytes())),
                );
            }
        };
        Ok(())
    }

    #[cfg(test)]
    fn serialize_to_vec(&self) -> Vec<u8> {
        let mut buff = PgArgumentBuffer::default();
        self.serialize(&mut buff).unwrap();
        buff.to_vec()
    }
}

fn read_vec(bytes: &mut &[u8]) -> Result<Vec<f64>, String> {
    if bytes.len() % BYTE_WIDTH != 0 {
        return Err(format!(
            "data length not divisible by {BYTE_WIDTH}: {}",
            bytes.len()
        ));
    }

    let mut out = Vec::with_capacity(bytes.len() / BYTE_WIDTH);

    while bytes.has_remaining() {
        out.push(bytes.get_f64());
    }

    Ok(out)
}

fn read_cube(header: &Header, mut bytes: &[u8]) -> Result<Vec<Vec<f64>>, String> {
    if bytes.len() != header.data_size() {
        return Err(format!(
            "expected {} bytes, got {}",
            header.data_size(),
            bytes.len()
        ));
    }

    let mut out = Vec::with_capacity(2);

    // Expecting exactly 2 N-dimensional points
    for _ in 0..2 {
        let mut point = Vec::new();

        for _ in 0..header.dimensions {
            point.push(bytes.get_f64());
        }

        out.push(point);
    }

    Ok(out)
}

fn parse_float_from_str(s: &str, error_msg: &str) -> Result<f64, Error> {
    s.parse().map_err(|_| Error::Decode(error_msg.into()))
}

fn parse_point(str: &str) -> Result<PgCube, Error> {
    Ok(PgCube::Point(parse_float_from_str(
        str,
        "Failed to parse point",
    )?))
}

fn parse_zero_volume(content: &str) -> Result<PgCube, Error> {
    content
        .split(',')
        .map(|p| parse_float_from_str(p, "Failed to parse into zero-volume cube"))
        .collect::<Result<Vec<_>, _>>()
        .map(PgCube::ZeroVolume)
}

fn parse_one_dimensional_interval(point_vecs: Vec<&str>) -> Result<PgCube, Error> {
    let x = parse_float_from_str(
        &remove_parentheses(point_vecs.first().ok_or(Error::Decode(
            format!("Could not decode cube interval x: {:?}", point_vecs).into(),
        ))?),
        "Failed to parse X in one-dimensional interval",
    )?;
    let y = parse_float_from_str(
        &remove_parentheses(point_vecs.get(1).ok_or(Error::Decode(
            format!("Could not decode cube interval y: {:?}", point_vecs).into(),
        ))?),
        "Failed to parse Y in one-dimensional interval",
    )?;
    Ok(PgCube::OneDimensionInterval(x, y))
}

fn parse_multidimensional_interval(point_vecs: Vec<&str>) -> Result<PgCube, Error> {
    point_vecs
        .iter()
        .map(|&point_vec| {
            point_vec
                .split(',')
                .map(|point| {
                    parse_float_from_str(
                        &remove_parentheses(point),
                        "Failed to parse into multi-dimension cube",
                    )
                })
                .collect::<Result<Vec<_>, _>>()
        })
        .collect::<Result<Vec<_>, _>>()
        .map(PgCube::MultiDimension)
}

fn remove_parentheses(s: &str) -> String {
    s.trim_matches(|c| c == '(' || c == ')').to_string()
}

impl Header {
    const PACKED_WIDTH: usize = mem::size_of::<u32>();

    fn encoded_size(&self) -> usize {
        Self::PACKED_WIDTH + self.data_size()
    }

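    // Number of payload bytes following the header: one `f64` per dimension for
    // a point, two per dimension (both corners) otherwise.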
    fn data_size(&self) -> usize {
        if self.is_point {
            self.dimensions * BYTE_WIDTH
        } else {
            self.dimensions * BYTE_WIDTH * 2
        }
    }

    fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> {
        if self.dimensions > MAX_DIMENSIONS {
            return Err(format!(
                "CUBE dimensionality exceeds allowed maximum ({} > {MAX_DIMENSIONS})",
                self.dimensions
            ));
        }

        // Cannot overflow thanks to the above check.
        #[allow(clippy::cast_possible_truncation)]
        let mut packed = self.dimensions as u32;

        // https://github.com/postgres/postgres/blob/e3ec9dc1bf4983fcedb6f43c71ea12ee26aefc7a/contrib/cube/cubedata.h#L18-L24
        if self.is_point {
            packed |= IS_POINT_FLAG;
        }

        buff.extend(packed.to_be_bytes());

        Ok(())
    }

    fn try_read(buf: &mut &[u8]) -> Result<Self, String> {
        if buf.len() < Self::PACKED_WIDTH {
            return Err(format!(
                "expected CUBE data to contain at least {} bytes, got {}",
                Self::PACKED_WIDTH,
                buf.len()
            ));
        }

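        // Big-endian `u32`: the high bit is the point flag, the remaining bits
        // are the dimension count.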
        let packed = buf.get_u32();

        let is_point = packed & IS_POINT_FLAG != 0;
        let dimensions = packed & !IS_POINT_FLAG;

        // can only overflow on 16-bit platforms
        let dimensions = usize::try_from(dimensions)
            .ok()
            .filter(|&it| it <= MAX_DIMENSIONS)
            .ok_or_else(|| format!("received CUBE data with higher than expected dimensionality: {dimensions} (is_point: {is_point})"))?;

        Ok(Self {
            is_point,
            dimensions,
        })
    }
}

impl DecodeError {
    fn new(header: &Header, message: String) -> Self {
        DecodeError {
            is_point: header.is_point,
            dimensions: header.dimensions,
            message,
        }
    }
}

#[cfg(test)]
mod cube_tests {

    use std::str::FromStr;

    use super::PgCube;

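    // Expected wire images: a 4-byte big-endian header (point flag plus
    // dimension count) followed by each coordinate as a big-endian `f64`.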
    const POINT_BYTES: &[u8] = &[128, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 0];
    const ZERO_VOLUME_BYTES: &[u8] = &[
        128, 0, 0, 2, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
    ];
    const ONE_DIMENSIONAL_INTERVAL_BYTES: &[u8] = &[
        0, 0, 0, 1, 64, 28, 0, 0, 0, 0, 0, 0, 64, 32, 0, 0, 0, 0, 0, 0,
    ];
    const MULTI_DIMENSION_2_DIM_BYTES: &[u8] = &[
        0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
        64, 16, 0, 0, 0, 0, 0, 0,
    ];
    const MULTI_DIMENSION_3_DIM_BYTES: &[u8] = &[
        0, 0, 0, 3, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, 64, 16, 0, 0, 0, 0, 0, 0, 64,
        20, 0, 0, 0, 0, 0, 0, 64, 24, 0, 0, 0, 0, 0, 0, 64, 28, 0, 0, 0, 0, 0, 0,
    ];

    #[test]
    fn can_deserialise_point_type_bytes() {
        let cube = PgCube::from_bytes(POINT_BYTES).unwrap();
        assert_eq!(cube, PgCube::Point(2.))
    }

    #[test]
    fn can_deserialise_point_type_str() {
        let cube_1 = PgCube::from_str("(2)").unwrap();
        assert_eq!(cube_1, PgCube::Point(2.));
        let cube_2 = PgCube::from_str("2").unwrap();
        assert_eq!(cube_2, PgCube::Point(2.));
    }

    #[test]
    fn can_serialise_point_type() {
        assert_eq!(PgCube::Point(2.).serialize_to_vec(), POINT_BYTES)
    }

    #[test]
    fn can_deserialise_zero_volume_bytes() {
        let cube = PgCube::from_bytes(ZERO_VOLUME_BYTES).unwrap();
        assert_eq!(cube, PgCube::ZeroVolume(vec![2., 3.]));
    }

    #[test]
    fn can_deserialise_zero_volume_string() {
        let cube_1 = PgCube::from_str("(2,3,4)").unwrap();
        assert_eq!(cube_1, PgCube::ZeroVolume(vec![2., 3., 4.]));
        let cube_2 = PgCube::from_str("2,3,4").unwrap();
        assert_eq!(cube_2, PgCube::ZeroVolume(vec![2., 3., 4.]));
    }

    #[test]
    fn can_serialise_zero_volume() {
        assert_eq!(
            PgCube::ZeroVolume(vec![2., 3.]).serialize_to_vec(),
            ZERO_VOLUME_BYTES
        );
    }

    #[test]
    fn can_deserialise_one_dimension_interval_bytes() {
        let cube = PgCube::from_bytes(ONE_DIMENSIONAL_INTERVAL_BYTES).unwrap();
        assert_eq!(cube, PgCube::OneDimensionInterval(7., 8.))
    }

    #[test]
    fn can_deserialise_one_dimension_interval_string() {
        let cube_1 = PgCube::from_str("((7),(8))").unwrap();
        assert_eq!(cube_1, PgCube::OneDimensionInterval(7., 8.));
        let cube_2 = PgCube::from_str("(7),(8)").unwrap();
        assert_eq!(cube_2, PgCube::OneDimensionInterval(7., 8.));
    }

    #[test]
    fn can_serialise_one_dimension_interval() {
        assert_eq!(
            PgCube::OneDimensionInterval(7., 8.).serialize_to_vec(),
            ONE_DIMENSIONAL_INTERVAL_BYTES
        )
    }

    #[test]
    fn can_deserialise_multi_dimension_2_dimension_bytes() {
        let cube = PgCube::from_bytes(MULTI_DIMENSION_2_DIM_BYTES).unwrap();
        assert_eq!(
            cube,
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]])
        )
    }

    #[test]
    fn can_deserialise_multi_dimension_2_dimension_string() {
        let cube_1 = PgCube::from_str("((1,2),(3,4))").unwrap();
        assert_eq!(
            cube_1,
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]])
        );
        let cube_2 = PgCube::from_str("((1, 2), (3, 4))").unwrap();
        assert_eq!(
            cube_2,
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]])
        );
        let cube_3 = PgCube::from_str("(1,2),(3,4)").unwrap();
        assert_eq!(
            cube_3,
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]])
        );
        let cube_4 = PgCube::from_str("(1, 2), (3, 4)").unwrap();
        assert_eq!(
            cube_4,
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]])
        )
    }

    #[test]
    fn can_serialise_multi_dimension_2_dimension() {
        assert_eq!(
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]).serialize_to_vec(),
            MULTI_DIMENSION_2_DIM_BYTES
        )
    }

    #[test]
    fn can_deserialise_multi_dimension_3_dimension_bytes() {
        let cube = PgCube::from_bytes(MULTI_DIMENSION_3_DIM_BYTES).unwrap();
        assert_eq!(
            cube,
            PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]])
        )
    }

    #[test]
    fn can_deserialise_multi_dimension_3_dimension_string() {
        let cube = PgCube::from_str("((2,3,4),(5,6,7))").unwrap();
        assert_eq!(
            cube,
            PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]])
        );
        let cube_2 = PgCube::from_str("(2,3,4),(5,6,7)").unwrap();
        assert_eq!(
            cube_2,
            PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]])
        );
    }

    #[test]
    fn can_serialise_multi_dimension_3_dimension() {
        assert_eq!(
            PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]]).serialize_to_vec(),
            MULTI_DIMENSION_3_DIM_BYTES
        )
    }
}