1use super::{FixedBytes, Sign, Signed};
7use bytes::{BufMut, BytesMut};
8use derive_more::Display;
9use postgres_types::{accepts, to_sql_checked, FromSql, IsNull, ToSql, Type, WrongType};
10use std::{
11 error::Error,
12 iter,
13 str::{from_utf8, FromStr},
14};
15
16impl<const BITS: usize> ToSql for FixedBytes<BITS> {
18 fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
19 out.put_slice(&self[..]);
20 Ok(IsNull::No)
21 }
22
23 accepts!(BYTEA);
24
25 to_sql_checked!();
26}
27
28impl<'a, const BITS: usize> FromSql<'a> for FixedBytes<BITS> {
30 accepts!(BYTEA);
31
32 fn from_sql(_: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
33 Ok(Self::try_from(raw)?)
34 }
35}
36
/// Thread-safe boxed error returned by the `to_sql` implementations in this module.
type BoxedError = Box<dyn Error + Send + Sync + 'static>;
40
/// Returns `a % b`, except that a zero remainder is reported as `b`.
///
/// For bit strings this gives the number of meaningful bits in the last byte when
/// `b == 8` (a length that is a multiple of 8, including 0, yields a full byte).
///
/// # Panics
///
/// Panics if `b == 0`.
const fn rem_up(a: usize, b: usize) -> usize {
    match a % b {
        0 => b,
        rem => rem,
    }
}
49
/// Length of `x` once every trailing element equal to `value` is removed.
///
/// Returns `0` when `x` is empty or consists solely of `value`.
fn last_idx<T: PartialEq>(x: &[T], value: &T) -> usize {
    let trailing = x.iter().rev().take_while(|item| *item == value).count();
    x.len() - trailing
}
53
54fn trim_end_vec<T: PartialEq>(vec: &mut Vec<T>, value: &T) {
55 vec.truncate(last_idx(vec, value));
56}
57
58#[derive(Clone, Debug, PartialEq, Eq, Display)]
60pub enum ToSqlError {
61 #[display("Signed<{_0}> value too large to fit target type {_1}")]
63 Overflow(usize, Type),
64}
65
66impl core::error::Error for ToSqlError {}
67
68impl<const BITS: usize, const LIMBS: usize> ToSql for Signed<BITS, LIMBS> {
98 fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError> {
99 match *ty {
100 Type::BOOL => out.put_u8(u8::from(bool::try_from(self.0)?)),
103 Type::INT2 => out.put_i16(self.0.try_into()?),
104 Type::INT4 => out.put_i32(self.0.try_into()?),
105 Type::OID => out.put_u32(self.0.try_into()?),
106 Type::INT8 => out.put_i64(self.0.try_into()?),
107
108 Type::MONEY => {
109 out.put_i64(
111 i64::try_from(self.0)?
112 .checked_mul(100)
113 .ok_or(ToSqlError::Overflow(BITS, ty.clone()))?,
114 );
115 }
116
117 Type::BYTEA => out.put_slice(&self.0.to_be_bytes_vec()),
119 Type::BIT | Type::VARBIT => {
120 if BITS == 0 {
123 if *ty == Type::BIT {
124 return Err(Box::new(WrongType::new::<Self>(ty.clone())));
126 }
127 out.put_i32(0);
128 } else {
129 let padding = 8 - rem_up(BITS, 8);
132 out.put_i32(Self::BITS.try_into()?);
133 let bytes = self.0.as_le_bytes();
134 let mut bytes = bytes.iter().rev();
135 let mut shifted = bytes.next().unwrap() << padding;
136 for byte in bytes {
137 shifted |= if padding > 0 { byte >> (8 - padding) } else { 0 };
138 out.put_u8(shifted);
139 shifted = byte << padding;
140 }
141 out.put_u8(shifted);
142 }
143 }
144
145 Type::CHAR | Type::TEXT | Type::VARCHAR => {
147 out.put_slice(format!("{self:#x}").as_bytes());
148 }
149 Type::JSON | Type::JSONB => {
150 if *ty == Type::JSONB {
151 out.put_u8(1);
153 }
154 out.put_slice(format!("\"{self:#x}\"").as_bytes());
155 }
156
157 Type::NUMERIC => {
160 const BASE: u64 = 10000;
162
163 let sign = match self.sign() {
164 Sign::Positive => 0x0000,
165 _ => 0x4000,
166 };
167
168 let mut digits: Vec<_> = self.abs().0.to_base_be(BASE).collect();
169 let exponent = digits.len().saturating_sub(1).try_into()?;
170
171 trim_end_vec(&mut digits, &0);
173
174 out.put_i16(digits.len().try_into()?); out.put_i16(exponent); out.put_i16(sign);
178 out.put_i16(0); for digit in digits {
180 debug_assert!(digit < BASE);
181 #[allow(clippy::cast_possible_truncation)] out.put_i16(digit as i16);
183 }
184 }
185
186 _ => {
188 return Err(Box::new(WrongType::new::<Self>(ty.clone())));
189 }
190 };
191 Ok(IsNull::No)
192 }
193
194 fn accepts(ty: &Type) -> bool {
195 matches!(*ty, |Type::BOOL| Type::CHAR
196 | Type::INT2
197 | Type::INT4
198 | Type::INT8
199 | Type::OID
200 | Type::FLOAT4
201 | Type::FLOAT8
202 | Type::MONEY
203 | Type::NUMERIC
204 | Type::BYTEA
205 | Type::TEXT
206 | Type::VARCHAR
207 | Type::JSON
208 | Type::JSONB
209 | Type::BIT
210 | Type::VARBIT)
211 }
212
213 to_sql_checked!();
214}
215
216#[derive(Clone, Debug, PartialEq, Eq, Display)]
218pub enum FromSqlError {
219 #[display("the value is too large for the Signed type")]
221 Overflow,
222
223 #[display("unexpected data for type {_0}")]
225 ParseError(Type),
226}
227
228impl core::error::Error for FromSqlError {}
229
230impl<'a, const BITS: usize, const LIMBS: usize> FromSql<'a> for Signed<BITS, LIMBS> {
231 fn accepts(ty: &Type) -> bool {
232 <Self as ToSql>::accepts(ty)
233 }
234
235 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
236 Ok(match *ty {
237 Type::BOOL => match raw {
238 [0] => Self::ZERO,
239 [1] => Self::try_from(1)?,
240 _ => return Err(Box::new(FromSqlError::ParseError(ty.clone()))),
241 },
242 Type::INT2 => i16::from_be_bytes(raw.try_into()?).try_into()?,
243 Type::INT4 => i32::from_be_bytes(raw.try_into()?).try_into()?,
244 Type::OID => u32::from_be_bytes(raw.try_into()?).try_into()?,
245 Type::INT8 => i64::from_be_bytes(raw.try_into()?).try_into()?,
246 Type::MONEY => (i64::from_be_bytes(raw.try_into()?) / 100).try_into()?,
247
248 Type::BYTEA => Self::try_from_be_slice(raw).ok_or(FromSqlError::Overflow)?,
250 Type::BIT | Type::VARBIT => {
251 if raw.len() < 4 {
253 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
254 }
255 let len: usize = i32::from_be_bytes(raw[..4].try_into()?).try_into()?;
256 let raw = &raw[4..];
257
258 let padding = 8 - rem_up(len, 8);
260 let mut raw = raw.to_owned();
261 if padding > 0 {
262 for i in (1..raw.len()).rev() {
263 raw[i] = (raw[i] >> padding) | (raw[i - 1] << (8 - padding));
264 }
265 raw[0] >>= padding;
266 }
267 Self::try_from_be_slice(&raw).ok_or(FromSqlError::Overflow)?
269 }
270
271 Type::CHAR | Type::TEXT | Type::VARCHAR => Self::from_str(from_utf8(raw)?)?,
273
274 Type::JSON | Type::JSONB => {
276 let raw = if *ty == Type::JSONB {
277 if raw[0] == 1 {
278 &raw[1..]
279 } else {
280 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
282 }
283 } else {
284 raw
285 };
286 let str = from_utf8(raw)?;
287 let str = if str.starts_with('"') && str.ends_with('"') {
288 &str[1..str.len() - 1]
290 } else {
291 str
292 };
293 Self::from_str(str)?
294 }
295
296 Type::NUMERIC => {
298 if raw.len() < 8 {
300 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
301 }
302 let digits = i16::from_be_bytes(raw[0..2].try_into()?);
303 let exponent = i16::from_be_bytes(raw[2..4].try_into()?);
304 let sign = i16::from_be_bytes(raw[4..6].try_into()?);
305 let dscale = i16::from_be_bytes(raw[6..8].try_into()?);
306 let raw = &raw[8..];
307 #[allow(clippy::cast_sign_loss)] if digits < 0
309 || exponent < 0
310 || dscale != 0
311 || digits > exponent + 1
312 || raw.len() != digits as usize * 2
313 {
314 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
315 }
316 let mut error = false;
317 let iter = raw.chunks_exact(2).filter_map(|raw| {
318 if error {
319 return None;
320 }
321 let digit = i16::from_be_bytes(raw.try_into().unwrap());
322 if !(0..10000).contains(&digit) {
323 error = true;
324 return None;
325 }
326 #[allow(clippy::cast_sign_loss)] Some(digit as u64)
328 });
329 #[allow(clippy::cast_sign_loss)]
330 let iter = iter.chain(iter::repeat(0).take((exponent + 1 - digits) as usize));
332
333 let mut value = Self::from_base_be(10000, iter)?;
334 if sign == 0x4000 {
335 value = -value;
336 }
337 if error {
338 return Err(Box::new(FromSqlError::ParseError(ty.clone())));
339 }
340
341 value
342 }
343
344 _ => return Err(Box::new(WrongType::new::<Self>(ty.clone()))),
346 })
347 }
348}
349
#[cfg(test)]
mod test {
    use super::*;

    use crate::I256;

    /// Binary NUMERIC payload for magnitude 1 with the given high sign byte:
    /// ndigits = 1, weight = 0, sign = `sign_hi` << 8, dscale = 0, digits = [1].
    const fn numeric_one(sign_hi: u8) -> [u8; 10] {
        [0x00, 0x01, 0x00, 0x00, sign_hi, 0x00, 0x00, 0x00, 0x00, 0x01]
    }

    /// Encodes `value` as NUMERIC and returns the written bytes.
    fn encode_numeric(value: I256) -> Vec<u8> {
        let mut buf = BytesMut::with_capacity(64);
        value.to_sql(&Type::NUMERIC, &mut buf).unwrap();
        buf.to_vec()
    }

    #[test]
    fn positive_i256_from_sql() {
        let decoded = I256::from_sql(&Type::NUMERIC, &numeric_one(0x00)).unwrap();
        assert_eq!(decoded, I256::ONE);
    }

    #[test]
    fn positive_i256_to_sql() {
        assert_eq!(encode_numeric(I256::ONE), numeric_one(0x00));
    }

    #[test]
    fn negative_i256_from_sql() {
        let decoded = I256::from_sql(&Type::NUMERIC, &numeric_one(0x40)).unwrap();
        assert_eq!(decoded, I256::MINUS_ONE);
    }

    #[test]
    fn negative_i256_to_sql() {
        assert_eq!(encode_numeric(I256::MINUS_ONE), numeric_one(0x40));
    }
}