alloy_primitives/
postgres.rs

1//! Support for the [`postgres_types`] crate.
2//!
//! **WARNING**: this module depends entirely on [`postgres_types`], which is not yet stable,
4//! therefore this module is exempt from the semver guarantees of this crate.
5
6use super::{FixedBytes, Sign, Signed};
7use bytes::{BufMut, BytesMut};
8use derive_more::Display;
9use postgres_types::{accepts, to_sql_checked, FromSql, IsNull, ToSql, Type, WrongType};
10use std::{
11    error::Error,
12    iter,
13    str::{from_utf8, FromStr},
14};
15
16/// Converts `FixedBytes` to Postgres Bytea Type.
17impl<const BITS: usize> ToSql for FixedBytes<BITS> {
18    fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
19        out.put_slice(&self[..]);
20        Ok(IsNull::No)
21    }
22
23    accepts!(BYTEA);
24
25    to_sql_checked!();
26}
27
28/// Converts `FixedBytes` From Postgres Bytea Type.
29impl<'a, const BITS: usize> FromSql<'a> for FixedBytes<BITS> {
30    accepts!(BYTEA);
31
32    fn from_sql(_: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
33        Ok(Self::try_from(raw)?)
34    }
35}
36
// The `Signed` conversions below are based on ruint's Postgres support:
// https://github.com/recmo/uint/blob/6c755ad7cd54a0706d20f11f3f63b0d977af0226/src/support/postgres.rs#L22
38
/// Boxed, thread-safe error returned by the `ToSql` implementations in this module.
type BoxedError = Box<dyn Error + Send + Sync + 'static>;
40
/// Returns `a % b`, except that an exact multiple of `b` yields `b` rather than `0`.
///
/// Callers in this module pass `b = 8` to find how many bits occupy the last
/// byte of a bit string. Panics if `b` is zero.
const fn rem_up(a: usize, b: usize) -> usize {
    match a % b {
        0 => b,
        partial => partial,
    }
}
49
/// Returns the length of `x` with every trailing element equal to `value`
/// ignored, i.e. one past the last element that differs from `value`
/// (`0` if there is none).
fn last_idx<T: PartialEq>(x: &[T], value: &T) -> usize {
    let trailing = x.iter().rev().take_while(|item| *item == value).count();
    x.len() - trailing
}

/// Drops, in place, every trailing element of `vec` that equals `value`.
fn trim_end_vec<T: PartialEq>(vec: &mut Vec<T>, value: &T) {
    let keep = last_idx(vec, value);
    vec.truncate(keep);
}
57
58/// Error when converting to Postgres types.
59#[derive(Clone, Debug, PartialEq, Eq, Display)]
60pub enum ToSqlError {
61    /// The value is too large for the type.
62    #[display("Signed<{_0}> value too large to fit target type {_1}")]
63    Overflow(usize, Type),
64}
65
66impl core::error::Error for ToSqlError {}
67
68/// Convert to Postgres types.
69///
70/// Compatible [Postgres data types][dt] are:
71///
72/// * `BOOL`, `SMALLINT`, `INTEGER`, `BIGINT` which are 1, 16, 32 and 64 bit signed integers
73///   respectively.
74/// * `OID` which is a 32 bit unsigned integer.
75/// * `DECIMAL` and `NUMERIC`, which are variable length.
76/// * `MONEY` which is a 64 bit integer with two decimals.
77/// * `BYTEA`, `BIT`, `VARBIT` interpreted as a big-endian binary number.
78/// * `CHAR`, `VARCHAR`, `TEXT` as `0x`-prefixed big-endian hex strings.
79/// * `JSON`, `JSONB` as a hex string compatible with the Serde serialization.
80///
81/// # Errors
82///
83/// Returns an error when trying to convert to a value that is too small to fit
84/// the number. Note that this depends on the value, not the type, so a
85/// [`Signed<256>`] can be stored in a `SMALLINT` column, as long as the values
86/// are less than $2^{16}$.
87///
88/// # Implementation details
89///
90/// The Postgres binary formats are used in the wire-protocol and the
91/// the `COPY BINARY` command, but they have very little documentation. You are
92/// pointed to the source code, for example this is the implementation of the
93/// the `NUMERIC` type serializer: [`numeric.c`][numeric].
94///
95/// [dt]:https://www.postgresql.org/docs/9.5/datatype.html
96/// [numeric]: https://github.com/postgres/postgres/blob/05a5a1775c89f6beb326725282e7eea1373cbec8/src/backend/utils/adt/numeric.c#L1082
97impl<const BITS: usize, const LIMBS: usize> ToSql for Signed<BITS, LIMBS> {
98    fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
99        match *ty {
100            // Big-endian simple types
101            // Note `BufMut::put_*` methods write big-endian by default.
102            Type::BOOL => out.put_u8(u8::from(bool::try_from(self.0)?)),
103            Type::INT2 => out.put_i16(self.0.try_into()?),
104            Type::INT4 => out.put_i32(self.0.try_into()?),
105            Type::OID => out.put_u32(self.0.try_into()?),
106            Type::INT8 => out.put_i64(self.0.try_into()?),
107
108            Type::MONEY => {
109                // Like i64, but with two decimals.
110                out.put_i64(
111                    i64::try_from(self.0)?
112                        .checked_mul(100)
113                        .ok_or(ToSqlError::Overflow(BITS, ty.clone()))?,
114                );
115            }
116
117            // Binary strings
118            Type::BYTEA => out.put_slice(&self.0.to_be_bytes_vec()),
119            Type::BIT | Type::VARBIT => {
120                // Bit in little-endian so the first bit is the least significant.
121                // Length must be at least one bit.
122                if BITS == 0 {
123                    if *ty == Type::BIT {
124                        // `bit(0)` is not a valid type, but varbit can be empty.
125                        return Err(Box::new(WrongType::new::<Self>(ty.clone())));
126                    }
127                    out.put_i32(0);
128                } else {
129                    // Bits are output in big-endian order, but padded at the
130                    // least significant end.
131                    let padding = 8 - rem_up(BITS, 8);
132                    out.put_i32(Self::BITS.try_into()?);
133                    let bytes = self.0.as_le_bytes();
134                    let mut bytes = bytes.iter().rev();
135                    let mut shifted = bytes.next().unwrap() << padding;
136                    for byte in bytes {
137                        shifted |= if padding > 0 { byte >> (8 - padding) } else { 0 };
138                        out.put_u8(shifted);
139                        shifted = byte << padding;
140                    }
141                    out.put_u8(shifted);
142                }
143            }
144
145            // Hex strings
146            Type::CHAR | Type::TEXT | Type::VARCHAR => {
147                out.put_slice(format!("{self:#x}").as_bytes());
148            }
149            Type::JSON | Type::JSONB => {
150                if *ty == Type::JSONB {
151                    // Version 1 of JSONB is just plain text JSON.
152                    out.put_u8(1);
153                }
154                out.put_slice(format!("\"{self:#x}\"").as_bytes());
155            }
156
157            // Binary coded decimal types
158            // See <https://github.com/postgres/postgres/blob/05a5a1775c89f6beb326725282e7eea1373cbec8/src/backend/utils/adt/numeric.c#L253>
159            Type::NUMERIC => {
160                // Everything is done in big-endian base 1000 digits.
161                const BASE: u64 = 10000;
162
163                let sign = match self.sign() {
164                    Sign::Positive => 0x0000,
165                    _ => 0x4000,
166                };
167
168                let mut digits: Vec<_> = self.abs().0.to_base_be(BASE).collect();
169                let exponent = digits.len().saturating_sub(1).try_into()?;
170
171                // Trailing zeros are removed.
172                trim_end_vec(&mut digits, &0);
173
174                out.put_i16(digits.len().try_into()?); // Number of digits.
175                out.put_i16(exponent); // Exponent of first digit.
176
177                out.put_i16(sign);
178                out.put_i16(0); // dscale: Number of digits to the right of the decimal point.
179                for digit in digits {
180                    debug_assert!(digit < BASE);
181                    #[allow(clippy::cast_possible_truncation)] // 10000 < i16::MAX
182                    out.put_i16(digit as i16);
183                }
184            }
185
186            // Unsupported types
187            _ => {
188                return Err(Box::new(WrongType::new::<Self>(ty.clone())));
189            }
190        };
191        Ok(IsNull::No)
192    }
193
194    fn accepts(ty: &Type) -> bool {
195        matches!(*ty, |Type::BOOL| Type::CHAR
196            | Type::INT2
197            | Type::INT4
198            | Type::INT8
199            | Type::OID
200            | Type::FLOAT4
201            | Type::FLOAT8
202            | Type::MONEY
203            | Type::NUMERIC
204            | Type::BYTEA
205            | Type::TEXT
206            | Type::VARCHAR
207            | Type::JSON
208            | Type::JSONB
209            | Type::BIT
210            | Type::VARBIT)
211    }
212
213    to_sql_checked!();
214}
215
216/// Error when converting from Postgres types.
217#[derive(Clone, Debug, PartialEq, Eq, Display)]
218pub enum FromSqlError {
219    /// The value is too large for the type.
220    #[display("the value is too large for the Signed type")]
221    Overflow,
222
223    /// The value is not valid for the type.
224    #[display("unexpected data for type {_0}")]
225    ParseError(Type),
226}
227
228impl core::error::Error for FromSqlError {}
229
230impl<'a, const BITS: usize, const LIMBS: usize> FromSql<'a> for Signed<BITS, LIMBS> {
231    fn accepts(ty: &Type) -> bool {
232        <Self as ToSql>::accepts(ty)
233    }
234
235    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
236        Ok(match *ty {
237            Type::BOOL => match raw {
238                [0] => Self::ZERO,
239                [1] => Self::try_from(1)?,
240                _ => return Err(Box::new(FromSqlError::ParseError(ty.clone()))),
241            },
242            Type::INT2 => i16::from_be_bytes(raw.try_into()?).try_into()?,
243            Type::INT4 => i32::from_be_bytes(raw.try_into()?).try_into()?,
244            Type::OID => u32::from_be_bytes(raw.try_into()?).try_into()?,
245            Type::INT8 => i64::from_be_bytes(raw.try_into()?).try_into()?,
246            Type::MONEY => (i64::from_be_bytes(raw.try_into()?) / 100).try_into()?,
247
248            // Binary strings
249            Type::BYTEA => Self::try_from_be_slice(raw).ok_or(FromSqlError::Overflow)?,
250            Type::BIT | Type::VARBIT => {
251                // Parse header
252                if raw.len() < 4 {
253                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
254                }
255                let len: usize = i32::from_be_bytes(raw[..4].try_into()?).try_into()?;
256                let raw = &raw[4..];
257
258                // Shift padding to the other end
259                let padding = 8 - rem_up(len, 8);
260                let mut raw = raw.to_owned();
261                if padding > 0 {
262                    for i in (1..raw.len()).rev() {
263                        raw[i] = (raw[i] >> padding) | (raw[i - 1] << (8 - padding));
264                    }
265                    raw[0] >>= padding;
266                }
267                // Construct from bits
268                Self::try_from_be_slice(&raw).ok_or(FromSqlError::Overflow)?
269            }
270
271            // Hex strings
272            Type::CHAR | Type::TEXT | Type::VARCHAR => Self::from_str(from_utf8(raw)?)?,
273
274            // Hex strings
275            Type::JSON | Type::JSONB => {
276                let raw = if *ty == Type::JSONB {
277                    if raw[0] == 1 {
278                        &raw[1..]
279                    } else {
280                        // Unsupported version
281                        return Err(Box::new(FromSqlError::ParseError(ty.clone())));
282                    }
283                } else {
284                    raw
285                };
286                let str = from_utf8(raw)?;
287                let str = if str.starts_with('"') && str.ends_with('"') {
288                    // Stringified number
289                    &str[1..str.len() - 1]
290                } else {
291                    str
292                };
293                Self::from_str(str)?
294            }
295
296            // Numeric types
297            Type::NUMERIC => {
298                // Parse header
299                if raw.len() < 8 {
300                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
301                }
302                let digits = i16::from_be_bytes(raw[0..2].try_into()?);
303                let exponent = i16::from_be_bytes(raw[2..4].try_into()?);
304                let sign = i16::from_be_bytes(raw[4..6].try_into()?);
305                let dscale = i16::from_be_bytes(raw[6..8].try_into()?);
306                let raw = &raw[8..];
307                #[allow(clippy::cast_sign_loss)] // Signs are checked
308                if digits < 0
309                    || exponent < 0
310                    || dscale != 0
311                    || digits > exponent + 1
312                    || raw.len() != digits as usize * 2
313                {
314                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
315                }
316                let mut error = false;
317                let iter = raw.chunks_exact(2).filter_map(|raw| {
318                    if error {
319                        return None;
320                    }
321                    let digit = i16::from_be_bytes(raw.try_into().unwrap());
322                    if !(0..10000).contains(&digit) {
323                        error = true;
324                        return None;
325                    }
326                    #[allow(clippy::cast_sign_loss)] // Signs are checked
327                    Some(digit as u64)
328                });
329                #[allow(clippy::cast_sign_loss)]
330                // Expression can not be negative due to checks above
331                let iter = iter.chain(iter::repeat(0).take((exponent + 1 - digits) as usize));
332
333                let mut value = Self::from_base_be(10000, iter)?;
334                if sign == 0x4000 {
335                    value = -value;
336                }
337                if error {
338                    return Err(Box::new(FromSqlError::ParseError(ty.clone())));
339                }
340
341                value
342            }
343
344            // Unsupported types
345            _ => return Err(Box::new(WrongType::new::<Self>(ty.clone()))),
346        })
347    }
348}
349
#[cfg(test)]
mod test {
    use super::*;

    use crate::I256;

    /// `NUMERIC` wire encoding of `1`.
    const ONE_NUMERIC: [u8; 10] = [
        0x00, 0x01, // ndigits: 1
        0x00, 0x00, // weight: 0
        0x00, 0x00, // sign: 0x0000 (positive)
        0x00, 0x00, // scale: 0
        0x00, 0x01, // digit: 1
    ];

    /// `NUMERIC` wire encoding of `-1`; differs from `ONE_NUMERIC` only in the sign word.
    const MINUS_ONE_NUMERIC: [u8; 10] = [
        0x00, 0x01, // ndigits: 1
        0x00, 0x00, // weight: 0
        0x40, 0x00, // sign: 0x4000 (negative)
        0x00, 0x00, // scale: 0
        0x00, 0x01, // digit: 1
    ];

    /// Encodes `value` as `NUMERIC` and returns the bytes written.
    fn encode_numeric(value: I256) -> Vec<u8> {
        let mut buf = BytesMut::with_capacity(64);
        value.to_sql(&Type::NUMERIC, &mut buf).unwrap();
        buf.to_vec()
    }

    /// Decodes `raw` as a `NUMERIC` value.
    fn decode_numeric(raw: &[u8]) -> I256 {
        I256::from_sql(&Type::NUMERIC, raw).unwrap()
    }

    #[test]
    fn positive_i256_from_sql() {
        assert_eq!(decode_numeric(&ONE_NUMERIC), I256::ONE);
    }

    #[test]
    fn positive_i256_to_sql() {
        assert_eq!(encode_numeric(I256::ONE), ONE_NUMERIC);
    }

    #[test]
    fn negative_i256_from_sql() {
        assert_eq!(decode_numeric(&MINUS_ONE_NUMERIC), I256::MINUS_ONE);
    }

    #[test]
    fn negative_i256_to_sql() {
        assert_eq!(encode_numeric(I256::MINUS_ONE), MINUS_ONE_NUMERIC);
    }
}