1use crate::decode::Decode;
2use crate::encode::{Encode, IsNull};
3use crate::error::BoxDynError;
4use crate::types::Type;
5use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
6use sqlx_core::bytes::Buf;
7use sqlx_core::Error;
8use std::mem;
9use std::str::FromStr;
10
/// Size in bytes of a single encoded coordinate; every coordinate in this
/// file is read and written as an `f64`.
const BYTE_WIDTH: usize = 8;

/// Largest dimension count accepted both when writing and when reading a
/// CUBE header.
// NOTE(review): presumably mirrors the server-side limit of the `cube`
// extension — confirm against the PostgreSQL documentation.
const MAX_DIMENSIONS: usize = 100;

/// Highest bit of the packed 32-bit header word. When set, the value is a
/// point and carries one coordinate per dimension instead of two.
const IS_POINT_FLAG: u32 = 1 << 31;
17
/// A value of the Postgres `cube` type.
///
/// Which variant is produced depends on the header's point flag and
/// dimension count (binary format) or on the shape of the text
/// representation (text format).
#[derive(Debug, Clone, PartialEq)]
pub enum PgCube {
    /// A point with a single coordinate, e.g. the text forms `(2)` or `2`.
    Point(f64),
    /// A point with several coordinates, e.g. `(2,3,4)`. Encoded with the
    /// point flag set and one coordinate per dimension.
    ZeroVolume(Vec<f64>),

    /// An interval in one dimension, e.g. `((7),(8))`.
    OneDimensionInterval(f64, f64),

    /// A cube described by two corner points, e.g. `((1,2),(3,4))`.
    ///
    /// Serialization rejects values whose outer vector does not hold exactly
    /// two corners; both corners are expected to have the same length.
    MultiDimension(Vec<Vec<f64>>),
}
38
/// Unpacked form of the 32-bit word that precedes the coordinates in the
/// binary representation of a CUBE.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct Header {
    /// Number of dimensions; stored in the header bits below the point flag.
    dimensions: usize,
    /// Whether the point flag (`IS_POINT_FLAG`) is set, in which case the
    /// payload holds one coordinate per dimension rather than two.
    is_point: bool,
}
44
45#[derive(Debug, thiserror::Error)]
46#[error("error decoding CUBE (is_point: {is_point}, dimensions: {dimensions})")]
47struct DecodeError {
48 is_point: bool,
49 dimensions: usize,
50 message: String,
51}
52
/// Associates [`PgCube`] with the Postgres type looked up by the name `cube`.
impl Type<Postgres> for PgCube {
    fn type_info() -> PgTypeInfo {
        PgTypeInfo::with_name("cube")
    }
}
58
/// Associates arrays of [`PgCube`] with the Postgres type looked up by the
/// name `_cube`.
impl PgHasArrayType for PgCube {
    fn array_type_info() -> PgTypeInfo {
        PgTypeInfo::with_name("_cube")
    }
}
64
65impl<'r> Decode<'r, Postgres> for PgCube {
66 fn decode(value: PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
67 match value.format() {
68 PgValueFormat::Text => Ok(PgCube::from_str(value.as_str()?)?),
69 PgValueFormat::Binary => Ok(PgCube::from_bytes(value.as_bytes()?)?),
70 }
71 }
72}
73
74impl<'q> Encode<'q, Postgres> for PgCube {
75 fn produces(&self) -> Option<PgTypeInfo> {
76 Some(PgTypeInfo::with_name("cube"))
77 }
78
79 fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
80 self.serialize(buf)?;
81 Ok(IsNull::No)
82 }
83
84 fn size_hint(&self) -> usize {
85 self.header().encoded_size()
86 }
87}
88
89impl FromStr for PgCube {
90 type Err = Error;
91
92 fn from_str(s: &str) -> Result<Self, Self::Err> {
93 let content = s
94 .trim_start_matches('(')
95 .trim_start_matches('[')
96 .trim_end_matches(')')
97 .trim_end_matches(']')
98 .replace(' ', "");
99
100 if !content.contains('(') && !content.contains(',') {
101 return parse_point(&content);
102 }
103
104 if !content.contains("),(") {
105 return parse_zero_volume(&content);
106 }
107
108 let point_vecs = content.split("),(").collect::<Vec<&str>>();
109 if point_vecs.len() == 2 && !point_vecs.iter().any(|pv| pv.contains(',')) {
110 return parse_one_dimensional_interval(point_vecs);
111 }
112
113 parse_multidimensional_interval(point_vecs)
114 }
115}
116
117impl PgCube {
118 fn header(&self) -> Header {
119 match self {
120 PgCube::Point(..) => Header {
121 is_point: true,
122 dimensions: 1,
123 },
124 PgCube::ZeroVolume(values) => Header {
125 is_point: true,
126 dimensions: values.len(),
127 },
128 PgCube::OneDimensionInterval(..) => Header {
129 is_point: false,
130 dimensions: 1,
131 },
132 PgCube::MultiDimension(multi_values) => Header {
133 is_point: false,
134 dimensions: multi_values.first().map(|arr| arr.len()).unwrap_or(0),
135 },
136 }
137 }
138
139 fn from_bytes(mut bytes: &[u8]) -> Result<Self, BoxDynError> {
140 let header = Header::try_read(&mut bytes)?;
141
142 if bytes.len() != header.data_size() {
143 return Err(DecodeError::new(
144 &header,
145 format!(
146 "expected {} bytes after header, got {}",
147 header.data_size(),
148 bytes.len()
149 ),
150 )
151 .into());
152 }
153
154 match (header.is_point, header.dimensions) {
155 (true, 1) => Ok(PgCube::Point(bytes.get_f64())),
156 (true, _) => Ok(PgCube::ZeroVolume(
157 read_vec(&mut bytes).map_err(|e| DecodeError::new(&header, e))?,
158 )),
159 (false, 1) => Ok(PgCube::OneDimensionInterval(
160 bytes.get_f64(),
161 bytes.get_f64(),
162 )),
163 (false, _) => Ok(PgCube::MultiDimension(read_cube(&header, bytes)?)),
164 }
165 }
166
167 fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> {
168 let header = self.header();
169
170 buff.reserve(header.data_size());
171
172 header.try_write(buff)?;
173
174 match self {
175 PgCube::Point(value) => {
176 buff.extend_from_slice(&value.to_be_bytes());
177 }
178 PgCube::ZeroVolume(values) => {
179 buff.extend(values.iter().flat_map(|v| v.to_be_bytes()));
180 }
181 PgCube::OneDimensionInterval(x, y) => {
182 buff.extend_from_slice(&x.to_be_bytes());
183 buff.extend_from_slice(&y.to_be_bytes());
184 }
185 PgCube::MultiDimension(multi_values) => {
186 if multi_values.len() != 2 {
187 return Err(format!("invalid CUBE value: {self:?}"));
188 }
189
190 buff.extend(
191 multi_values
192 .iter()
193 .flat_map(|point| point.iter().flat_map(|scalar| scalar.to_be_bytes())),
194 );
195 }
196 };
197 Ok(())
198 }
199
200 #[cfg(test)]
201 fn serialize_to_vec(&self) -> Vec<u8> {
202 let mut buff = PgArgumentBuffer::default();
203 self.serialize(&mut buff).unwrap();
204 buff.to_vec()
205 }
206}
207
208fn read_vec(bytes: &mut &[u8]) -> Result<Vec<f64>, String> {
209 if bytes.len() % BYTE_WIDTH != 0 {
210 return Err(format!(
211 "data length not divisible by {BYTE_WIDTH}: {}",
212 bytes.len()
213 ));
214 }
215
216 let mut out = Vec::with_capacity(bytes.len() / BYTE_WIDTH);
217
218 while bytes.has_remaining() {
219 out.push(bytes.get_f64());
220 }
221
222 Ok(out)
223}
224
225fn read_cube(header: &Header, mut bytes: &[u8]) -> Result<Vec<Vec<f64>>, String> {
226 if bytes.len() != header.data_size() {
227 return Err(format!(
228 "expected {} bytes, got {}",
229 header.data_size(),
230 bytes.len()
231 ));
232 }
233
234 let mut out = Vec::with_capacity(2);
235
236 for _ in 0..2 {
238 let mut point = Vec::new();
239
240 for _ in 0..header.dimensions {
241 point.push(bytes.get_f64());
242 }
243
244 out.push(point);
245 }
246
247 Ok(out)
248}
249
250fn parse_float_from_str(s: &str, error_msg: &str) -> Result<f64, Error> {
251 s.parse().map_err(|_| Error::Decode(error_msg.into()))
252}
253
254fn parse_point(str: &str) -> Result<PgCube, Error> {
255 Ok(PgCube::Point(parse_float_from_str(
256 str,
257 "Failed to parse point",
258 )?))
259}
260
261fn parse_zero_volume(content: &str) -> Result<PgCube, Error> {
262 content
263 .split(',')
264 .map(|p| parse_float_from_str(p, "Failed to parse into zero-volume cube"))
265 .collect::<Result<Vec<_>, _>>()
266 .map(PgCube::ZeroVolume)
267}
268
269fn parse_one_dimensional_interval(point_vecs: Vec<&str>) -> Result<PgCube, Error> {
270 let x = parse_float_from_str(
271 &remove_parentheses(point_vecs.first().ok_or(Error::Decode(
272 format!("Could not decode cube interval x: {:?}", point_vecs).into(),
273 ))?),
274 "Failed to parse X in one-dimensional interval",
275 )?;
276 let y = parse_float_from_str(
277 &remove_parentheses(point_vecs.get(1).ok_or(Error::Decode(
278 format!("Could not decode cube interval y: {:?}", point_vecs).into(),
279 ))?),
280 "Failed to parse Y in one-dimensional interval",
281 )?;
282 Ok(PgCube::OneDimensionInterval(x, y))
283}
284
285fn parse_multidimensional_interval(point_vecs: Vec<&str>) -> Result<PgCube, Error> {
286 point_vecs
287 .iter()
288 .map(|&point_vec| {
289 point_vec
290 .split(',')
291 .map(|point| {
292 parse_float_from_str(
293 &remove_parentheses(point),
294 "Failed to parse into multi-dimension cube",
295 )
296 })
297 .collect::<Result<Vec<_>, _>>()
298 })
299 .collect::<Result<Vec<_>, _>>()
300 .map(PgCube::MultiDimension)
301}
302
/// Returns `s` without any leading or trailing `(` / `)` characters.
/// Parentheses in the middle of the string are left untouched.
fn remove_parentheses(s: &str) -> String {
    let trimmed = s.trim_matches(|ch: char| matches!(ch, '(' | ')'));
    trimmed.to_owned()
}
306
307impl Header {
308 const PACKED_WIDTH: usize = mem::size_of::<u32>();
309
310 fn encoded_size(&self) -> usize {
311 Self::PACKED_WIDTH + self.data_size()
312 }
313
314 fn data_size(&self) -> usize {
315 if self.is_point {
316 self.dimensions * BYTE_WIDTH
317 } else {
318 self.dimensions * BYTE_WIDTH * 2
319 }
320 }
321
322 fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> {
323 if self.dimensions > MAX_DIMENSIONS {
324 return Err(format!(
325 "CUBE dimensionality exceeds allowed maximum ({} > {MAX_DIMENSIONS})",
326 self.dimensions
327 ));
328 }
329
330 #[allow(clippy::cast_possible_truncation)]
332 let mut packed = self.dimensions as u32;
333
334 if self.is_point {
336 packed |= IS_POINT_FLAG;
337 }
338
339 buff.extend(packed.to_be_bytes());
340
341 Ok(())
342 }
343
344 fn try_read(buf: &mut &[u8]) -> Result<Self, String> {
345 if buf.len() < Self::PACKED_WIDTH {
346 return Err(format!(
347 "expected CUBE data to contain at least {} bytes, got {}",
348 Self::PACKED_WIDTH,
349 buf.len()
350 ));
351 }
352
353 let packed = buf.get_u32();
354
355 let is_point = packed & IS_POINT_FLAG != 0;
356 let dimensions = packed & !IS_POINT_FLAG;
357
358 let dimensions = usize::try_from(dimensions)
360 .ok()
361 .filter(|&it| it <= MAX_DIMENSIONS)
362 .ok_or_else(|| format!("received CUBE data with higher than expected dimensionality: {dimensions} (is_point: {is_point})"))?;
363
364 Ok(Self {
365 is_point,
366 dimensions,
367 })
368 }
369}
370
371impl DecodeError {
372 fn new(header: &Header, message: String) -> Self {
373 DecodeError {
374 is_point: header.is_point,
375 dimensions: header.dimensions,
376 message,
377 }
378 }
379}
380
/// Round-trip and error-path tests for the text and binary forms of `PgCube`.
#[cfg(test)]
mod cube_tests {

    use std::str::FromStr;

    use super::PgCube;

    // Binary fixtures: a 4-byte header followed by big-endian `f64` coordinates.
    const POINT_BYTES: &[u8] = &[128, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 0];
    const ZERO_VOLUME_BYTES: &[u8] = &[
        128, 0, 0, 2, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
    ];
    const ONE_DIMENSIONAL_INTERVAL_BYTES: &[u8] = &[
        0, 0, 0, 1, 64, 28, 0, 0, 0, 0, 0, 0, 64, 32, 0, 0, 0, 0, 0, 0,
    ];
    const MULTI_DIMENSION_2_DIM_BYTES: &[u8] = &[
        0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
        64, 16, 0, 0, 0, 0, 0, 0,
    ];
    const MULTI_DIMENSION_3_DIM_BYTES: &[u8] = &[
        0, 0, 0, 3, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, 64, 16, 0, 0, 0, 0, 0, 0, 64,
        20, 0, 0, 0, 0, 0, 0, 64, 24, 0, 0, 0, 0, 0, 0, 64, 28, 0, 0, 0, 0, 0, 0,
    ];

    #[test]
    fn can_deserialise_point_type_bytes() {
        let cube = PgCube::from_bytes(POINT_BYTES).unwrap();
        assert_eq!(cube, PgCube::Point(2.))
    }

    #[test]
    fn can_deserialise_point_type_str() {
        let cube_1 = PgCube::from_str("(2)").unwrap();
        assert_eq!(cube_1, PgCube::Point(2.));
        let cube_2 = PgCube::from_str("2").unwrap();
        assert_eq!(cube_2, PgCube::Point(2.));
    }

    #[test]
    fn can_serialise_point_type() {
        assert_eq!(PgCube::Point(2.).serialize_to_vec(), POINT_BYTES,)
    }

    #[test]
    fn can_deserialise_zero_volume_bytes() {
        let cube = PgCube::from_bytes(ZERO_VOLUME_BYTES).unwrap();
        assert_eq!(cube, PgCube::ZeroVolume(vec![2., 3.]));
    }

    #[test]
    fn can_deserialise_zero_volume_string() {
        let cube_1 = PgCube::from_str("(2,3,4)").unwrap();
        assert_eq!(cube_1, PgCube::ZeroVolume(vec![2., 3., 4.]));
        let cube_2 = PgCube::from_str("2,3,4").unwrap();
        assert_eq!(cube_2, PgCube::ZeroVolume(vec![2., 3., 4.]));
    }

    #[test]
    fn can_serialise_zero_volume() {
        assert_eq!(
            PgCube::ZeroVolume(vec![2., 3.]).serialize_to_vec(),
            ZERO_VOLUME_BYTES
        );
    }

    #[test]
    fn can_deserialise_one_dimension_interval_bytes() {
        let cube = PgCube::from_bytes(ONE_DIMENSIONAL_INTERVAL_BYTES).unwrap();
        assert_eq!(cube, PgCube::OneDimensionInterval(7., 8.))
    }

    #[test]
    fn can_deserialise_one_dimension_interval_string() {
        let cube_1 = PgCube::from_str("((7),(8))").unwrap();
        assert_eq!(cube_1, PgCube::OneDimensionInterval(7., 8.));
        let cube_2 = PgCube::from_str("(7),(8)").unwrap();
        assert_eq!(cube_2, PgCube::OneDimensionInterval(7., 8.));
    }

    #[test]
    fn can_serialise_one_dimension_interval() {
        assert_eq!(
            PgCube::OneDimensionInterval(7., 8.).serialize_to_vec(),
            ONE_DIMENSIONAL_INTERVAL_BYTES
        )
    }

    #[test]
    fn can_deserialise_multi_dimension_2_dimension_byte() {
        let cube = PgCube::from_bytes(MULTI_DIMENSION_2_DIM_BYTES).unwrap();
        assert_eq!(
            cube,
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]])
        )
    }

    #[test]
    fn can_deserialise_multi_dimension_2_dimension_string() {
        let cube_1 = PgCube::from_str("((1,2),(3,4))").unwrap();
        assert_eq!(
            cube_1,
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]])
        );
        let cube_2 = PgCube::from_str("((1, 2), (3, 4))").unwrap();
        assert_eq!(
            cube_2,
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]])
        );
        let cube_3 = PgCube::from_str("(1,2),(3,4)").unwrap();
        assert_eq!(
            cube_3,
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]])
        );
        let cube_4 = PgCube::from_str("(1, 2), (3, 4)").unwrap();
        assert_eq!(
            cube_4,
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]])
        )
    }

    #[test]
    fn can_serialise_multi_dimension_2_dimension() {
        assert_eq!(
            PgCube::MultiDimension(vec![vec![1., 2.], vec![3., 4.]]).serialize_to_vec(),
            MULTI_DIMENSION_2_DIM_BYTES
        )
    }

    #[test]
    fn can_deserialise_multi_dimension_3_dimension_bytes() {
        let cube = PgCube::from_bytes(MULTI_DIMENSION_3_DIM_BYTES).unwrap();
        assert_eq!(
            cube,
            PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]])
        )
    }

    #[test]
    fn can_deserialise_multi_dimension_3_dimension_string() {
        let cube = PgCube::from_str("((2,3,4),(5,6,7))").unwrap();
        assert_eq!(
            cube,
            PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]])
        );
        let cube_2 = PgCube::from_str("(2,3,4),(5,6,7)").unwrap();
        assert_eq!(
            cube_2,
            PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]])
        );
    }

    #[test]
    fn can_serialise_multi_dimension_3_dimension() {
        assert_eq!(
            PgCube::MultiDimension(vec![vec![2., 3., 4.], vec![5., 6., 7.]]).serialize_to_vec(),
            MULTI_DIMENSION_3_DIM_BYTES
        )
    }

    #[test]
    fn rejects_truncated_header() {
        // Fewer bytes than the packed header word needs.
        assert!(PgCube::from_bytes(&[0, 0, 0]).is_err());
    }

    #[test]
    fn rejects_payload_length_mismatch() {
        // Header announces one coordinate, but the payload is one byte short.
        let truncated = &POINT_BYTES[..POINT_BYTES.len() - 1];
        assert!(PgCube::from_bytes(truncated).is_err());
    }

    #[test]
    fn rejects_excessive_dimensions() {
        // 101 dimensions is above the accepted maximum.
        assert!(PgCube::from_bytes(&[0, 0, 0, 101]).is_err());
    }

    #[test]
    fn rejects_malformed_text() {
        assert!(PgCube::from_str("(1,a)").is_err());
    }
}