spin_sdk/
pg3.rs

1//! Conversions between Rust, WIT and **Postgres** types.
2//!
3//! # Types
4//!
5//! | Rust type               | WIT (db-value)                                | Postgres type(s)             |
6//! |-------------------------|-----------------------------------------------|----------------------------- |
7//! | `bool`                  | boolean(bool)                                 | BOOL                         |
8//! | `i16`                   | int16(s16)                                    | SMALLINT, SMALLSERIAL, INT2  |
9//! | `i32`                   | int32(s32)                                    | INT, SERIAL, INT4            |
10//! | `i64`                   | int64(s64)                                    | BIGINT, BIGSERIAL, INT8      |
11//! | `f32`                   | floating32(float32)                           | REAL, FLOAT4                 |
12//! | `f64`                   | floating64(float64)                           | DOUBLE PRECISION, FLOAT8     |
13//! | `String`                | str(string)                                   | VARCHAR, CHAR(N), TEXT       |
14//! | `Vec<u8>`               | binary(list\<u8\>)                            | BYTEA                        |
15//! | `chrono::NaiveDate`     | date(tuple<s32, u8, u8>)                      | DATE                         |
16//! | `chrono::NaiveTime`     | time(tuple<u8, u8, u8, u32>)                  | TIME                         |
17//! | `chrono::NaiveDateTime` | datetime(tuple<s32, u8, u8, u8, u8, u8, u32>) | TIMESTAMP                    |
18//! | `chrono::Duration`      | timestamp(s64)                                | BIGINT                       |
19
20#[doc(inline)]
21pub use super::wit::pg3::{Error as PgError, *};
22
23use chrono::{Datelike, Timelike};
24
/// Errors produced by this module.
///
/// Covers both failures reported by the Postgres interface itself and
/// failures to convert a returned [`DbValue`] into the requested Rust type.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A [`DbValue`] could not be converted into the requested Rust type.
    ///
    /// The payload is a human-readable description of the mismatch.
    #[error("error value decoding: {0}")]
    Decode(String),
    /// The underlying Postgres error, displayed transparently.
    ///
    /// Thanks to `#[from]`, a [`PgError`] converts into this variant via `?`.
    #[error(transparent)]
    PgError(#[from] PgError),
}
35
/// A type that can be decoded from the database.
///
/// Implementations inspect a borrowed [`DbValue`] and produce an owned value
/// of the implementing type.
pub trait Decode: Sized {
    /// Decode a new value of this type using a [`DbValue`].
    ///
    /// # Errors
    ///
    /// The implementations in this module return [`Error::Decode`] when
    /// `value` is not the variant they expect or carries an invalid payload.
    fn decode(value: &DbValue) -> Result<Self, Error>;
}
41
42impl<T> Decode for Option<T>
43where
44    T: Decode,
45{
46    fn decode(value: &DbValue) -> Result<Self, Error> {
47        match value {
48            DbValue::DbNull => Ok(None),
49            v => Ok(Some(T::decode(v)?)),
50        }
51    }
52}
53
54impl Decode for bool {
55    fn decode(value: &DbValue) -> Result<Self, Error> {
56        match value {
57            DbValue::Boolean(boolean) => Ok(*boolean),
58            _ => Err(Error::Decode(format_decode_err("BOOL", value))),
59        }
60    }
61}
62
63impl Decode for i16 {
64    fn decode(value: &DbValue) -> Result<Self, Error> {
65        match value {
66            DbValue::Int16(n) => Ok(*n),
67            _ => Err(Error::Decode(format_decode_err("SMALLINT", value))),
68        }
69    }
70}
71
72impl Decode for i32 {
73    fn decode(value: &DbValue) -> Result<Self, Error> {
74        match value {
75            DbValue::Int32(n) => Ok(*n),
76            _ => Err(Error::Decode(format_decode_err("INT", value))),
77        }
78    }
79}
80
81impl Decode for i64 {
82    fn decode(value: &DbValue) -> Result<Self, Error> {
83        match value {
84            DbValue::Int64(n) => Ok(*n),
85            _ => Err(Error::Decode(format_decode_err("BIGINT", value))),
86        }
87    }
88}
89
90impl Decode for f32 {
91    fn decode(value: &DbValue) -> Result<Self, Error> {
92        match value {
93            DbValue::Floating32(n) => Ok(*n),
94            _ => Err(Error::Decode(format_decode_err("REAL", value))),
95        }
96    }
97}
98
99impl Decode for f64 {
100    fn decode(value: &DbValue) -> Result<Self, Error> {
101        match value {
102            DbValue::Floating64(n) => Ok(*n),
103            _ => Err(Error::Decode(format_decode_err("DOUBLE PRECISION", value))),
104        }
105    }
106}
107
108impl Decode for Vec<u8> {
109    fn decode(value: &DbValue) -> Result<Self, Error> {
110        match value {
111            DbValue::Binary(n) => Ok(n.to_owned()),
112            _ => Err(Error::Decode(format_decode_err("BYTEA", value))),
113        }
114    }
115}
116
117impl Decode for String {
118    fn decode(value: &DbValue) -> Result<Self, Error> {
119        match value {
120            DbValue::Str(s) => Ok(s.to_owned()),
121            _ => Err(Error::Decode(format_decode_err(
122                "CHAR, VARCHAR, TEXT",
123                value,
124            ))),
125        }
126    }
127}
128
129impl Decode for chrono::NaiveDate {
130    fn decode(value: &DbValue) -> Result<Self, Error> {
131        match value {
132            DbValue::Date((year, month, day)) => {
133                let naive_date =
134                    chrono::NaiveDate::from_ymd_opt(*year, (*month).into(), (*day).into())
135                        .ok_or_else(|| {
136                            Error::Decode(format!(
137                                "invalid date y={}, m={}, d={}",
138                                year, month, day
139                            ))
140                        })?;
141                Ok(naive_date)
142            }
143            _ => Err(Error::Decode(format_decode_err("DATE", value))),
144        }
145    }
146}
147
148impl Decode for chrono::NaiveTime {
149    fn decode(value: &DbValue) -> Result<Self, Error> {
150        match value {
151            DbValue::Time((hour, minute, second, nanosecond)) => {
152                let naive_time = chrono::NaiveTime::from_hms_nano_opt(
153                    (*hour).into(),
154                    (*minute).into(),
155                    (*second).into(),
156                    *nanosecond,
157                )
158                .ok_or_else(|| {
159                    Error::Decode(format!(
160                        "invalid time {}:{}:{}:{}",
161                        hour, minute, second, nanosecond
162                    ))
163                })?;
164                Ok(naive_time)
165            }
166            _ => Err(Error::Decode(format_decode_err("TIME", value))),
167        }
168    }
169}
170
171impl Decode for chrono::NaiveDateTime {
172    fn decode(value: &DbValue) -> Result<Self, Error> {
173        match value {
174            DbValue::Datetime((year, month, day, hour, minute, second, nanosecond)) => {
175                let naive_date =
176                    chrono::NaiveDate::from_ymd_opt(*year, (*month).into(), (*day).into())
177                        .ok_or_else(|| {
178                            Error::Decode(format!(
179                                "invalid date y={}, m={}, d={}",
180                                year, month, day
181                            ))
182                        })?;
183                let naive_time = chrono::NaiveTime::from_hms_nano_opt(
184                    (*hour).into(),
185                    (*minute).into(),
186                    (*second).into(),
187                    *nanosecond,
188                )
189                .ok_or_else(|| {
190                    Error::Decode(format!(
191                        "invalid time {}:{}:{}:{}",
192                        hour, minute, second, nanosecond
193                    ))
194                })?;
195                let dt = chrono::NaiveDateTime::new(naive_date, naive_time);
196                Ok(dt)
197            }
198            _ => Err(Error::Decode(format_decode_err("DATETIME", value))),
199        }
200    }
201}
202
203impl Decode for chrono::Duration {
204    fn decode(value: &DbValue) -> Result<Self, Error> {
205        match value {
206            DbValue::Timestamp(n) => Ok(chrono::Duration::seconds(*n)),
207            _ => Err(Error::Decode(format_decode_err("BIGINT", value))),
208        }
209    }
210}
211
/// Generates `impl From<Type> for ParameterValue` for every `Type => Variant`
/// pair in a comma-separated list.
///
/// Each generated impl simply wraps the value in the named variant, so it
/// only suits variants with a single payload of exactly that type.
macro_rules! impl_parameter_value_conversions {
    ($($source:ty => $variant:ident),*) => {
        $(
            impl From<$source> for ParameterValue {
                fn from(value: $source) -> Self {
                    Self::$variant(value)
                }
            }
        )*
    };
}
223
// Rust types whose values map one-to-one onto a single `ParameterValue`
// variant. Types that need their fields unpacked (the chrono types) or that
// may map to `DbNull` (`Option<T>`) get hand-written `From` impls below.
impl_parameter_value_conversions! {
    i8 => Int8,
    i16 => Int16,
    i32 => Int32,
    i64 => Int64,
    f32 => Floating32,
    f64 => Floating64,
    bool => Boolean,
    String => Str,
    Vec<u8> => Binary
}
235
236impl From<chrono::NaiveDateTime> for ParameterValue {
237    fn from(v: chrono::NaiveDateTime) -> ParameterValue {
238        ParameterValue::Datetime((
239            v.year(),
240            v.month() as u8,
241            v.day() as u8,
242            v.hour() as u8,
243            v.minute() as u8,
244            v.second() as u8,
245            v.nanosecond(),
246        ))
247    }
248}
249
250impl From<chrono::NaiveTime> for ParameterValue {
251    fn from(v: chrono::NaiveTime) -> ParameterValue {
252        ParameterValue::Time((
253            v.hour() as u8,
254            v.minute() as u8,
255            v.second() as u8,
256            v.nanosecond(),
257        ))
258    }
259}
260
261impl From<chrono::NaiveDate> for ParameterValue {
262    fn from(v: chrono::NaiveDate) -> ParameterValue {
263        ParameterValue::Date((v.year(), v.month() as u8, v.day() as u8))
264    }
265}
266
267impl From<chrono::TimeDelta> for ParameterValue {
268    fn from(v: chrono::TimeDelta) -> ParameterValue {
269        ParameterValue::Timestamp(v.num_seconds())
270    }
271}
272
273impl<T: Into<ParameterValue>> From<Option<T>> for ParameterValue {
274    fn from(o: Option<T>) -> ParameterValue {
275        match o {
276            Some(v) => v.into(),
277            None => ParameterValue::DbNull,
278        }
279    }
280}
281
282fn format_decode_err(types: &str, value: &DbValue) -> String {
283    format!("Expected {} from the DB but got {:?}", types, value)
284}
285
/// Unit tests for the [`Decode`] implementations and the
/// `From<_> for ParameterValue` conversions.
///
/// Each decode test covers the matching variant, a mismatching variant, and
/// the `DbNull` → `None` path through `Option<T>`. The temporal tests also
/// cover payloads that chrono rejects.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn boolean() {
        assert!(bool::decode(&DbValue::Boolean(true)).unwrap());
        assert!(bool::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<bool>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn int16() {
        assert_eq!(i16::decode(&DbValue::Int16(0)).unwrap(), 0);
        assert!(i16::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<i16>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn int32() {
        assert_eq!(i32::decode(&DbValue::Int32(0)).unwrap(), 0);
        assert!(i32::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<i32>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn int64() {
        assert_eq!(i64::decode(&DbValue::Int64(0)).unwrap(), 0);
        assert!(i64::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<i64>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn floating32() {
        assert!(f32::decode(&DbValue::Floating32(0.0)).is_ok());
        assert!(f32::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<f32>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn floating64() {
        assert!(f64::decode(&DbValue::Floating64(0.0)).is_ok());
        assert!(f64::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<f64>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn str() {
        assert_eq!(
            String::decode(&DbValue::Str(String::from("foo"))).unwrap(),
            String::from("foo")
        );

        assert!(String::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<String>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn binary() {
        assert_eq!(
            Vec::<u8>::decode(&DbValue::Binary(vec![0, 0])).unwrap(),
            vec![0, 0]
        );
        assert!(Vec::<u8>::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<Vec<u8>>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn date() {
        assert_eq!(
            chrono::NaiveDate::decode(&DbValue::Date((1, 2, 4))).unwrap(),
            chrono::NaiveDate::from_ymd_opt(1, 2, 4).unwrap()
        );
        assert_ne!(
            chrono::NaiveDate::decode(&DbValue::Date((1, 2, 4))).unwrap(),
            chrono::NaiveDate::from_ymd_opt(1, 2, 5).unwrap()
        );
        // A calendar date that does not exist must be a decode error, not a panic.
        assert!(matches!(
            chrono::NaiveDate::decode(&DbValue::Date((2023, 2, 30))),
            Err(Error::Decode(_))
        ));
        assert!(chrono::NaiveDate::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<chrono::NaiveDate>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn time() {
        assert_eq!(
            chrono::NaiveTime::decode(&DbValue::Time((1, 2, 3, 4))).unwrap(),
            chrono::NaiveTime::from_hms_nano_opt(1, 2, 3, 4).unwrap()
        );
        assert_ne!(
            chrono::NaiveTime::decode(&DbValue::Time((1, 2, 3, 4))).unwrap(),
            chrono::NaiveTime::from_hms_nano_opt(1, 2, 4, 5).unwrap()
        );
        // Hour 24 is outside the range chrono accepts.
        assert!(matches!(
            chrono::NaiveTime::decode(&DbValue::Time((24, 0, 0, 0))),
            Err(Error::Decode(_))
        ));
        assert!(chrono::NaiveTime::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<chrono::NaiveTime>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn datetime() {
        let date = chrono::NaiveDate::from_ymd_opt(1, 2, 3).unwrap();
        let mut time = chrono::NaiveTime::from_hms_nano_opt(4, 5, 6, 7).unwrap();
        assert_eq!(
            chrono::NaiveDateTime::decode(&DbValue::Datetime((1, 2, 3, 4, 5, 6, 7))).unwrap(),
            chrono::NaiveDateTime::new(date, time)
        );

        time = chrono::NaiveTime::from_hms_nano_opt(4, 5, 6, 8).unwrap();
        assert_ne!(
            chrono::NaiveDateTime::decode(&DbValue::Datetime((1, 2, 3, 4, 5, 6, 7))).unwrap(),
            chrono::NaiveDateTime::new(date, time)
        );

        // Invalid date part.
        assert!(matches!(
            chrono::NaiveDateTime::decode(&DbValue::Datetime((2023, 2, 30, 4, 5, 6, 7))),
            Err(Error::Decode(_))
        ));
        // Invalid time part.
        assert!(matches!(
            chrono::NaiveDateTime::decode(&DbValue::Datetime((1, 2, 3, 24, 0, 0, 0))),
            Err(Error::Decode(_))
        ));
        assert!(chrono::NaiveDateTime::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<chrono::NaiveDateTime>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn timestamp() {
        assert_eq!(
            chrono::Duration::decode(&DbValue::Timestamp(1)).unwrap(),
            chrono::Duration::seconds(1),
        );
        assert_ne!(
            chrono::Duration::decode(&DbValue::Timestamp(2)).unwrap(),
            chrono::Duration::seconds(1)
        );
        // Only the `Timestamp` variant decodes to a duration.
        assert!(chrono::Duration::decode(&DbValue::Int64(1)).is_err());
        assert!(Option::<chrono::Duration>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn parameter_values() {
        assert!(matches!(
            ParameterValue::from(7_i32),
            ParameterValue::Int32(7)
        ));
        assert!(matches!(
            ParameterValue::from(Some(true)),
            ParameterValue::Boolean(true)
        ));
        assert!(matches!(
            ParameterValue::from(None::<i32>),
            ParameterValue::DbNull
        ));

        let date = chrono::NaiveDate::from_ymd_opt(2024, 2, 29).unwrap();
        let time = chrono::NaiveTime::from_hms_nano_opt(23, 59, 58, 7).unwrap();
        assert!(matches!(
            ParameterValue::from(date),
            ParameterValue::Date((2024, 2, 29))
        ));
        assert!(matches!(
            ParameterValue::from(time),
            ParameterValue::Time((23, 59, 58, 7))
        ));
        assert!(matches!(
            ParameterValue::from(chrono::NaiveDateTime::new(date, time)),
            ParameterValue::Datetime((2024, 2, 29, 23, 59, 58, 7))
        ));
        assert!(matches!(
            ParameterValue::from(chrono::Duration::seconds(90)),
            ParameterValue::Timestamp(90)
        ));
    }
}