pgrx_sql_entity_graph/metadata/
sql_translatable.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
A trait denoting that a type can possibly be mapped to an SQL type.
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
> to the `pgrx` framework and highly subject to change between versions. While you may use this, please do so with caution.
16
17*/
18use std::any::Any;
19use std::ffi::{CStr, CString};
20use std::fmt::Display;
21use thiserror::Error;
22
23use super::return_variant::ReturnsError;
24use super::{FunctionMetadataTypeEntity, Returns};
25
26#[derive(Clone, Copy, Debug, Hash, Ord, PartialOrd, PartialEq, Eq, Error)]
27pub enum ArgumentError {
28    #[error("Cannot use SetOfIterator as an argument")]
29    SetOf,
30    #[error("Cannot use TableIterator as an argument")]
31    Table,
32    #[error("Cannot use bare u8")]
33    BareU8,
34    #[error("SqlMapping::Skip inside Array is not valid")]
35    SkipInArray,
36    #[error("A Datum as an argument means that `sql = \"...\"` must be set in the declaration")]
37    Datum,
38    #[error("`{0}` is not able to be used as a function argument")]
39    NotValidAsArgument(&'static str),
40}
41
/// Describes the ways a Rust type can be mapped into SQL.
#[derive(Clone, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum SqlMapping {
    /// An explicit mapping provided by PGRX; the string is the SQL type name (e.g. `"INT"`).
    As(String),
    /// A composite type. `array_brackets` is set when the composite is wrapped in a `Vec`.
    Composite {
        array_brackets: bool,
    },
    /// A type which does not actually appear in SQL.
    Skip,
}

impl SqlMapping {
    /// Builds an [`SqlMapping::As`] by copying a static SQL type name into an owned `String`.
    pub fn literal(s: &'static str) -> Self {
        Self::As(s.to_owned())
    }
}
59
60/**
61A value which can be represented in SQL
62
63# Safety
64
65By implementing this, you assert you are not lying to either Postgres or Rust in doing so.
66This trait asserts a safe translation exists between values of this type from Rust to SQL,
67or from SQL into Rust. If you are mistaken about how this works, either the Postgres C API
68or the Rust handling in PGRX may emit undefined behavior.
69
70It cannot be made private or sealed due to details of the structure of the PGRX framework.
71Nonetheless, if you are not confident the translation is valid: do not implement this trait.
72*/
73pub unsafe trait SqlTranslatable {
74    fn type_name() -> &'static str {
75        core::any::type_name::<Self>()
76    }
77    fn argument_sql() -> Result<SqlMapping, ArgumentError>;
78    fn return_sql() -> Result<Returns, ReturnsError>;
79    fn variadic() -> bool {
80        false
81    }
82    fn optional() -> bool {
83        false
84    }
85    fn entity() -> FunctionMetadataTypeEntity {
86        FunctionMetadataTypeEntity {
87            type_name: Self::type_name(),
88            argument_sql: Self::argument_sql(),
89            return_sql: Self::return_sql(),
90            variadic: Self::variadic(),
91            optional: Self::optional(),
92        }
93    }
94}
95
96unsafe impl SqlTranslatable for () {
97    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
98        Err(ArgumentError::NotValidAsArgument("()"))
99    }
100
101    fn return_sql() -> Result<Returns, ReturnsError> {
102        Ok(Returns::One(SqlMapping::literal("VOID")))
103    }
104}
105
106unsafe impl<T> SqlTranslatable for Option<T>
107where
108    T: SqlTranslatable,
109{
110    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
111        T::argument_sql()
112    }
113    fn return_sql() -> Result<Returns, ReturnsError> {
114        T::return_sql()
115    }
116    fn optional() -> bool {
117        true
118    }
119}
120
121unsafe impl<T> SqlTranslatable for *mut T
122where
123    T: SqlTranslatable,
124{
125    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
126        T::argument_sql()
127    }
128    fn return_sql() -> Result<Returns, ReturnsError> {
129        T::return_sql()
130    }
131    fn optional() -> bool {
132        T::optional()
133    }
134}
135
136unsafe impl<T, E> SqlTranslatable for Result<T, E>
137where
138    T: SqlTranslatable,
139    E: Any + Display,
140{
141    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
142        T::argument_sql()
143    }
144    fn return_sql() -> Result<Returns, ReturnsError> {
145        T::return_sql()
146    }
147    fn optional() -> bool {
148        true
149    }
150}
151
152unsafe impl<T> SqlTranslatable for Vec<T>
153where
154    T: SqlTranslatable,
155{
156    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
157        match T::type_name() {
158            id if id == u8::type_name() => Ok(SqlMapping::As("bytea".into())),
159            _ => match T::argument_sql() {
160                Ok(SqlMapping::As(val)) => Ok(SqlMapping::As(format!("{val}[]"))),
161                Ok(SqlMapping::Composite { array_brackets: _ }) => {
162                    Ok(SqlMapping::Composite { array_brackets: true })
163                }
164                Ok(SqlMapping::Skip) => Ok(SqlMapping::Skip),
165                err @ Err(_) => err,
166            },
167        }
168    }
169
170    fn return_sql() -> Result<Returns, ReturnsError> {
171        match T::type_name() {
172            id if id == u8::type_name() => Ok(Returns::One(SqlMapping::As("bytea".into()))),
173            _ => match T::return_sql() {
174                Ok(Returns::One(SqlMapping::As(val))) => {
175                    Ok(Returns::One(SqlMapping::As(format!("{val}[]"))))
176                }
177                Ok(Returns::One(SqlMapping::Composite { array_brackets: _ })) => {
178                    Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
179                }
180                Ok(Returns::One(SqlMapping::Skip)) => Ok(Returns::One(SqlMapping::Skip)),
181                Ok(Returns::SetOf(_)) => Err(ReturnsError::SetOfInArray),
182                Ok(Returns::Table(_)) => Err(ReturnsError::TableInArray),
183                err @ Err(_) => err,
184            },
185        }
186    }
187    fn optional() -> bool {
188        T::optional()
189    }
190}
191
192unsafe impl SqlTranslatable for u8 {
193    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
194        Err(ArgumentError::BareU8)
195    }
196    fn return_sql() -> Result<Returns, ReturnsError> {
197        Err(ReturnsError::BareU8)
198    }
199}
200
201unsafe impl SqlTranslatable for i32 {
202    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
203        Ok(SqlMapping::literal("INT"))
204    }
205    fn return_sql() -> Result<Returns, ReturnsError> {
206        Ok(Returns::One(SqlMapping::literal("INT")))
207    }
208}
209
210unsafe impl SqlTranslatable for String {
211    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
212        Ok(SqlMapping::literal("TEXT"))
213    }
214    fn return_sql() -> Result<Returns, ReturnsError> {
215        Ok(Returns::One(SqlMapping::literal("TEXT")))
216    }
217}
218
219unsafe impl<T> SqlTranslatable for &T
220where
221    T: ?Sized + SqlTranslatable,
222{
223    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
224        T::argument_sql()
225    }
226    fn return_sql() -> Result<Returns, ReturnsError> {
227        T::return_sql()
228    }
229}
230
231unsafe impl SqlTranslatable for str {
232    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
233        Ok(SqlMapping::literal("TEXT"))
234    }
235    fn return_sql() -> Result<Returns, ReturnsError> {
236        Ok(Returns::One(SqlMapping::literal("TEXT")))
237    }
238}
239
240unsafe impl SqlTranslatable for [u8] {
241    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
242        Ok(SqlMapping::literal("bytea"))
243    }
244    fn return_sql() -> Result<Returns, ReturnsError> {
245        Ok(Returns::One(SqlMapping::literal("bytea")))
246    }
247}
248
249unsafe impl SqlTranslatable for i8 {
250    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
251        Ok(SqlMapping::As(String::from("\"char\"")))
252    }
253    fn return_sql() -> Result<Returns, ReturnsError> {
254        Ok(Returns::One(SqlMapping::As(String::from("\"char\""))))
255    }
256}
257
258unsafe impl SqlTranslatable for i16 {
259    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
260        Ok(SqlMapping::literal("smallint"))
261    }
262    fn return_sql() -> Result<Returns, ReturnsError> {
263        Ok(Returns::One(SqlMapping::literal("smallint")))
264    }
265}
266
267unsafe impl SqlTranslatable for i64 {
268    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
269        Ok(SqlMapping::literal("bigint"))
270    }
271    fn return_sql() -> Result<Returns, ReturnsError> {
272        Ok(Returns::One(SqlMapping::literal("bigint")))
273    }
274}
275
276unsafe impl SqlTranslatable for bool {
277    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
278        Ok(SqlMapping::literal("bool"))
279    }
280    fn return_sql() -> Result<Returns, ReturnsError> {
281        Ok(Returns::One(SqlMapping::literal("bool")))
282    }
283}
284
285unsafe impl SqlTranslatable for char {
286    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
287        Ok(SqlMapping::literal("varchar"))
288    }
289    fn return_sql() -> Result<Returns, ReturnsError> {
290        Ok(Returns::One(SqlMapping::literal("varchar")))
291    }
292}
293
294unsafe impl SqlTranslatable for f32 {
295    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
296        Ok(SqlMapping::literal("real"))
297    }
298    fn return_sql() -> Result<Returns, ReturnsError> {
299        Ok(Returns::One(SqlMapping::literal("real")))
300    }
301}
302
303unsafe impl SqlTranslatable for f64 {
304    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
305        Ok(SqlMapping::literal("double precision"))
306    }
307    fn return_sql() -> Result<Returns, ReturnsError> {
308        Ok(Returns::One(SqlMapping::literal("double precision")))
309    }
310}
311
312unsafe impl SqlTranslatable for CString {
313    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
314        Ok(SqlMapping::literal("cstring"))
315    }
316    fn return_sql() -> Result<Returns, ReturnsError> {
317        Ok(Returns::One(SqlMapping::literal("cstring")))
318    }
319}
320
321unsafe impl SqlTranslatable for CStr {
322    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
323        Ok(SqlMapping::literal("cstring"))
324    }
325    fn return_sql() -> Result<Returns, ReturnsError> {
326        Ok(Returns::One(SqlMapping::literal("cstring")))
327    }
328}