pgrx_sql_entity_graph/metadata/
sql_translatable.rs1use std::any::Any;
19use std::ffi::{CStr, CString};
20use std::fmt::Display;
21use thiserror::Error;
22
23use super::return_variant::ReturnsError;
24use super::{FunctionMetadataTypeEntity, Returns};
25
26#[derive(Clone, Copy, Debug, Hash, Ord, PartialOrd, PartialEq, Eq, Error)]
27pub enum ArgumentError {
28 #[error("Cannot use SetOfIterator as an argument")]
29 SetOf,
30 #[error("Cannot use TableIterator as an argument")]
31 Table,
32 #[error("Cannot use bare u8")]
33 BareU8,
34 #[error("SqlMapping::Skip inside Array is not valid")]
35 SkipInArray,
36 #[error("A Datum as an argument means that `sql = \"...\"` must be set in the declaration")]
37 Datum,
38 #[error("`{0}` is not able to be used as a function argument")]
39 NotValidAsArgument(&'static str),
40}
41
42#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
44pub enum SqlMapping {
45 As(String),
47 Composite {
48 array_brackets: bool,
49 },
50 Skip,
52}
53
54impl SqlMapping {
55 pub fn literal(s: &'static str) -> SqlMapping {
56 SqlMapping::As(String::from(s))
57 }
58}
59
60pub unsafe trait SqlTranslatable {
74 fn type_name() -> &'static str {
75 core::any::type_name::<Self>()
76 }
77 fn argument_sql() -> Result<SqlMapping, ArgumentError>;
78 fn return_sql() -> Result<Returns, ReturnsError>;
79 fn variadic() -> bool {
80 false
81 }
82 fn optional() -> bool {
83 false
84 }
85 fn entity() -> FunctionMetadataTypeEntity {
86 FunctionMetadataTypeEntity {
87 type_name: Self::type_name(),
88 argument_sql: Self::argument_sql(),
89 return_sql: Self::return_sql(),
90 variadic: Self::variadic(),
91 optional: Self::optional(),
92 }
93 }
94}
95
96unsafe impl SqlTranslatable for () {
97 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
98 Err(ArgumentError::NotValidAsArgument("()"))
99 }
100
101 fn return_sql() -> Result<Returns, ReturnsError> {
102 Ok(Returns::One(SqlMapping::literal("VOID")))
103 }
104}
105
106unsafe impl<T> SqlTranslatable for Option<T>
107where
108 T: SqlTranslatable,
109{
110 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
111 T::argument_sql()
112 }
113 fn return_sql() -> Result<Returns, ReturnsError> {
114 T::return_sql()
115 }
116 fn optional() -> bool {
117 true
118 }
119}
120
121unsafe impl<T> SqlTranslatable for *mut T
122where
123 T: SqlTranslatable,
124{
125 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
126 T::argument_sql()
127 }
128 fn return_sql() -> Result<Returns, ReturnsError> {
129 T::return_sql()
130 }
131 fn optional() -> bool {
132 T::optional()
133 }
134}
135
136unsafe impl<T, E> SqlTranslatable for Result<T, E>
137where
138 T: SqlTranslatable,
139 E: Any + Display,
140{
141 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
142 T::argument_sql()
143 }
144 fn return_sql() -> Result<Returns, ReturnsError> {
145 T::return_sql()
146 }
147 fn optional() -> bool {
148 true
149 }
150}
151
152unsafe impl<T> SqlTranslatable for Vec<T>
153where
154 T: SqlTranslatable,
155{
156 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
157 match T::type_name() {
158 id if id == u8::type_name() => Ok(SqlMapping::As("bytea".into())),
159 _ => match T::argument_sql() {
160 Ok(SqlMapping::As(val)) => Ok(SqlMapping::As(format!("{val}[]"))),
161 Ok(SqlMapping::Composite { array_brackets: _ }) => {
162 Ok(SqlMapping::Composite { array_brackets: true })
163 }
164 Ok(SqlMapping::Skip) => Ok(SqlMapping::Skip),
165 err @ Err(_) => err,
166 },
167 }
168 }
169
170 fn return_sql() -> Result<Returns, ReturnsError> {
171 match T::type_name() {
172 id if id == u8::type_name() => Ok(Returns::One(SqlMapping::As("bytea".into()))),
173 _ => match T::return_sql() {
174 Ok(Returns::One(SqlMapping::As(val))) => {
175 Ok(Returns::One(SqlMapping::As(format!("{val}[]"))))
176 }
177 Ok(Returns::One(SqlMapping::Composite { array_brackets: _ })) => {
178 Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
179 }
180 Ok(Returns::One(SqlMapping::Skip)) => Ok(Returns::One(SqlMapping::Skip)),
181 Ok(Returns::SetOf(_)) => Err(ReturnsError::SetOfInArray),
182 Ok(Returns::Table(_)) => Err(ReturnsError::TableInArray),
183 err @ Err(_) => err,
184 },
185 }
186 }
187 fn optional() -> bool {
188 T::optional()
189 }
190}
191
192unsafe impl SqlTranslatable for u8 {
193 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
194 Err(ArgumentError::BareU8)
195 }
196 fn return_sql() -> Result<Returns, ReturnsError> {
197 Err(ReturnsError::BareU8)
198 }
199}
200
201unsafe impl SqlTranslatable for i32 {
202 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
203 Ok(SqlMapping::literal("INT"))
204 }
205 fn return_sql() -> Result<Returns, ReturnsError> {
206 Ok(Returns::One(SqlMapping::literal("INT")))
207 }
208}
209
210unsafe impl SqlTranslatable for String {
211 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
212 Ok(SqlMapping::literal("TEXT"))
213 }
214 fn return_sql() -> Result<Returns, ReturnsError> {
215 Ok(Returns::One(SqlMapping::literal("TEXT")))
216 }
217}
218
219unsafe impl<T> SqlTranslatable for &T
220where
221 T: ?Sized + SqlTranslatable,
222{
223 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
224 T::argument_sql()
225 }
226 fn return_sql() -> Result<Returns, ReturnsError> {
227 T::return_sql()
228 }
229}
230
231unsafe impl SqlTranslatable for str {
232 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
233 Ok(SqlMapping::literal("TEXT"))
234 }
235 fn return_sql() -> Result<Returns, ReturnsError> {
236 Ok(Returns::One(SqlMapping::literal("TEXT")))
237 }
238}
239
240unsafe impl SqlTranslatable for [u8] {
241 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
242 Ok(SqlMapping::literal("bytea"))
243 }
244 fn return_sql() -> Result<Returns, ReturnsError> {
245 Ok(Returns::One(SqlMapping::literal("bytea")))
246 }
247}
248
249unsafe impl SqlTranslatable for i8 {
250 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
251 Ok(SqlMapping::As(String::from("\"char\"")))
252 }
253 fn return_sql() -> Result<Returns, ReturnsError> {
254 Ok(Returns::One(SqlMapping::As(String::from("\"char\""))))
255 }
256}
257
258unsafe impl SqlTranslatable for i16 {
259 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
260 Ok(SqlMapping::literal("smallint"))
261 }
262 fn return_sql() -> Result<Returns, ReturnsError> {
263 Ok(Returns::One(SqlMapping::literal("smallint")))
264 }
265}
266
267unsafe impl SqlTranslatable for i64 {
268 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
269 Ok(SqlMapping::literal("bigint"))
270 }
271 fn return_sql() -> Result<Returns, ReturnsError> {
272 Ok(Returns::One(SqlMapping::literal("bigint")))
273 }
274}
275
276unsafe impl SqlTranslatable for bool {
277 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
278 Ok(SqlMapping::literal("bool"))
279 }
280 fn return_sql() -> Result<Returns, ReturnsError> {
281 Ok(Returns::One(SqlMapping::literal("bool")))
282 }
283}
284
285unsafe impl SqlTranslatable for char {
286 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
287 Ok(SqlMapping::literal("varchar"))
288 }
289 fn return_sql() -> Result<Returns, ReturnsError> {
290 Ok(Returns::One(SqlMapping::literal("varchar")))
291 }
292}
293
294unsafe impl SqlTranslatable for f32 {
295 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
296 Ok(SqlMapping::literal("real"))
297 }
298 fn return_sql() -> Result<Returns, ReturnsError> {
299 Ok(Returns::One(SqlMapping::literal("real")))
300 }
301}
302
303unsafe impl SqlTranslatable for f64 {
304 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
305 Ok(SqlMapping::literal("double precision"))
306 }
307 fn return_sql() -> Result<Returns, ReturnsError> {
308 Ok(Returns::One(SqlMapping::literal("double precision")))
309 }
310}
311
312unsafe impl SqlTranslatable for CString {
313 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
314 Ok(SqlMapping::literal("cstring"))
315 }
316 fn return_sql() -> Result<Returns, ReturnsError> {
317 Ok(Returns::One(SqlMapping::literal("cstring")))
318 }
319}
320
321unsafe impl SqlTranslatable for CStr {
322 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
323 Ok(SqlMapping::literal("cstring"))
324 }
325 fn return_sql() -> Result<Returns, ReturnsError> {
326 Ok(Returns::One(SqlMapping::literal("cstring")))
327 }
328}