pgrx_sql_entity_graph/pg_extern/entity/
mod.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`#[pg_extern]` related entities for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17*/
18mod argument;
19mod cast;
20mod operator;
21mod returning;
22
23pub use argument::PgExternArgumentEntity;
24pub use cast::PgCastEntity;
25pub use operator::PgOperatorEntity;
26pub use returning::{PgExternReturnEntity, PgExternReturnEntityIteratedItem};
27
28use crate::fmt;
29use crate::metadata::{Returns, SqlMapping};
30use crate::pgrx_sql::PgrxSql;
31use crate::to_sql::entity::ToSqlConfigEntity;
32use crate::to_sql::ToSql;
33use crate::{ExternArgs, SqlGraphEntity, SqlGraphIdentifier, TypeMatch};
34
35use eyre::{eyre, WrapErr};
36
37/// The output of a [`PgExtern`](crate::pg_extern::PgExtern) from `quote::ToTokens::to_tokens`.
38#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
39pub struct PgExternEntity {
40    pub name: &'static str,
41    pub unaliased_name: &'static str,
42    pub module_path: &'static str,
43    pub full_path: &'static str,
44    pub metadata: crate::metadata::FunctionMetadataEntity,
45    pub fn_args: Vec<PgExternArgumentEntity>,
46    pub fn_return: PgExternReturnEntity,
47    pub schema: Option<&'static str>,
48    pub file: &'static str,
49    pub line: u32,
50    pub extern_attrs: Vec<ExternArgs>,
51    pub search_path: Option<Vec<&'static str>>,
52    pub operator: Option<PgOperatorEntity>,
53    pub cast: Option<PgCastEntity>,
54    pub to_sql_config: ToSqlConfigEntity,
55}
56
57impl From<PgExternEntity> for SqlGraphEntity {
58    fn from(val: PgExternEntity) -> Self {
59        SqlGraphEntity::Function(val)
60    }
61}
62
63impl SqlGraphIdentifier for PgExternEntity {
64    fn dot_identifier(&self) -> String {
65        format!("fn {}", self.name)
66    }
67    fn rust_identifier(&self) -> String {
68        self.full_path.to_string()
69    }
70
71    fn file(&self) -> Option<&'static str> {
72        Some(self.file)
73    }
74
75    fn line(&self) -> Option<u32> {
76        Some(self.line)
77    }
78}
79
80impl ToSql for PgExternEntity {
81    fn to_sql(&self, context: &PgrxSql) -> eyre::Result<String> {
82        let self_index = context.externs[self];
83        let mut extern_attrs = self.extern_attrs.clone();
84        // if we already have a STRICT marker we do not need to add it
85        // presume we can upgrade, then disprove it
86        let mut strict_upgrade = !extern_attrs.iter().any(|i| i == &ExternArgs::Strict);
87        if strict_upgrade {
88            // It may be possible to infer a `STRICT` marker though.
89            // But we can only do that if the user hasn't used `Option<T>` or `pgrx::Internal`
90            for arg in &self.metadata.arguments {
91                if arg.optional {
92                    strict_upgrade = false;
93                }
94            }
95        }
96
97        if strict_upgrade {
98            extern_attrs.push(ExternArgs::Strict);
99        }
100        extern_attrs.sort();
101        extern_attrs.dedup();
102
103        let module_pathname = &context.get_module_pathname();
104        let schema = self
105            .schema
106            .map(|schema| format!("{schema}."))
107            .unwrap_or_else(|| context.schema_prefix_for(&self_index));
108        let arguments = if !self.fn_args.is_empty() {
109            let mut args = Vec::new();
110            let metadata_without_arg_skips = &self
111                .metadata
112                .arguments
113                .iter()
114                .filter(|v| v.argument_sql != Ok(SqlMapping::Skip))
115                .collect::<Vec<_>>();
116            for (idx, arg) in self.fn_args.iter().enumerate() {
117                let graph_index = context
118                    .graph
119                    .neighbors_undirected(self_index)
120                    .find(|neighbor| context.graph[*neighbor].type_matches(&arg.used_ty))
121                    .ok_or_else(|| eyre!("Could not find arg type in graph. Got: {:?}", arg))?;
122                let needs_comma = idx < (metadata_without_arg_skips.len().saturating_sub(1));
123                let metadata_argument = &self.metadata.arguments[idx];
124                match metadata_argument.argument_sql {
125                    Ok(SqlMapping::As(ref argument_sql)) => {
126                        let buf = format!("\
127                                            \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
128                                        ",
129                                            pattern = arg.pattern,
130                                            schema_prefix = context.schema_prefix_for(&graph_index),
131                                            // First try to match on [`TypeId`] since it's most reliable.
132                                            sql_type = argument_sql,
133                                            default = if let Some(def) = arg.used_ty.default { format!(" DEFAULT {def}") } else { String::from("") },
134                                            variadic = if metadata_argument.variadic { "VARIADIC " } else { "" },
135                                            maybe_comma = if needs_comma { ", " } else { " " },
136                                            type_name = metadata_argument.type_name,
137                                    );
138                        args.push(buf);
139                    }
140                    Ok(SqlMapping::Composite { array_brackets }) => {
141                        let sql = self.fn_args[idx]
142                            .used_ty
143                            .composite_type
144                            .map(|v| fmt::with_array_brackets(v.into(), array_brackets))
145                            .ok_or_else(|| {
146                                eyre!(
147                                    "Macro expansion time suggested a composite_type!() in return"
148                                )
149                            })?;
150                        let buf = format!("\
151                            \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
152                        ",
153                            pattern = arg.pattern,
154                            schema_prefix = context.schema_prefix_for(&graph_index),
155                            // First try to match on [`TypeId`] since it's most reliable.
156                            sql_type = sql,
157                            default = if let Some(def) = arg.used_ty.default { format!(" DEFAULT {def}") } else { String::from("") },
158                            variadic = if metadata_argument.variadic { "VARIADIC " } else { "" },
159                            maybe_comma = if needs_comma { ", " } else { " " },
160                            type_name = metadata_argument.type_name,
161                    );
162                        args.push(buf);
163                    }
164                    Ok(SqlMapping::Skip) => (),
165                    Err(err) => return Err(err).wrap_err("While mapping argument"),
166                }
167            }
168            String::from("\n") + &args.join("\n") + "\n"
169        } else {
170            Default::default()
171        };
172
173        let returns = match &self.fn_return {
174            PgExternReturnEntity::None => String::from("RETURNS void"),
175            PgExternReturnEntity::Type { ty } => {
176                let graph_index = context
177                    .graph
178                    .neighbors_undirected(self_index)
179                    .find(|neighbor| context.graph[*neighbor].type_matches(ty))
180                    .ok_or_else(|| eyre!("Could not find return type in graph."))?;
181                let metadata_retval = self.metadata.retval.clone();
182                let sql_type = match metadata_retval.return_sql {
183                    Ok(Returns::One(SqlMapping::As(ref sql))) => sql.clone(),
184                    Ok(Returns::One(SqlMapping::Composite { array_brackets })) => fmt::with_array_brackets(ty.composite_type.unwrap().into(), array_brackets),
185                    Ok(other) => return Err(eyre!("Got non-plain mapped/composite return variant SQL in what macro-expansion thought was a type, got: {other:?}")),
186                    Err(err) => return Err(err).wrap_err("Error mapping return SQL")
187                };
188                format!(
189                    "RETURNS {schema_prefix}{sql_type} /* {full_path} */",
190                    schema_prefix = context.schema_prefix_for(&graph_index),
191                    full_path = ty.full_path
192                )
193            }
194            PgExternReturnEntity::SetOf { ty, .. } => {
195                let graph_index = context
196                    .graph
197                    .neighbors_undirected(self_index)
198                    .find(|neighbor| context.graph[*neighbor].type_matches(ty))
199                    .ok_or_else(|| eyre!("Could not find return type in graph."))?;
200                let metadata_retval = self.metadata.retval.clone();
201                let sql_type = match metadata_retval.return_sql {
202                        Ok(Returns::SetOf(SqlMapping::As(ref sql))) => sql.clone(),
203                        Ok(Returns::SetOf(SqlMapping::Composite { array_brackets })) => fmt::with_array_brackets(ty.composite_type.unwrap().into(), array_brackets),
204                        Ok(_other) => return Err(eyre!("Got non-setof mapped/composite return variant SQL in what macro-expansion thought was a setof")),
205                        Err(err) => return Err(err).wrap_err("Error mapping return SQL"),
206                    };
207                format!(
208                    "RETURNS SETOF {schema_prefix}{sql_type} /* {full_path} */",
209                    schema_prefix = context.schema_prefix_for(&graph_index),
210                    full_path = ty.full_path
211                )
212            }
213            PgExternReturnEntity::Iterated { tys: table_items, .. } => {
214                let mut items = String::new();
215                let metadata_retval = self.metadata.retval.clone();
216                let metadata_retval_sqls: Vec<String> = match metadata_retval.return_sql {
217                        Ok(Returns::Table(variants)) => {
218                            variants.iter().enumerate().map(|(idx, variant)| {
219                                match variant {
220                                    SqlMapping::As(sql) => sql.clone(),
221                                    SqlMapping::Composite { array_brackets } => {
222                                        let composite = table_items[idx].ty.composite_type.unwrap();
223                                        fmt::with_array_brackets(composite.into(), *array_brackets)
224                                    },
225                                    SqlMapping::Skip => todo!(),
226                                }
227                            }).collect()
228                        },
229                        Ok(_other) => return Err(eyre!("Got non-table return variant SQL in what macro-expansion thought was a table")),
230                        Err(err) => return Err(err).wrap_err("Error mapping return SQL"),
231                    };
232
233                for (idx, returning::PgExternReturnEntityIteratedItem { ty, name: col_name }) in
234                    table_items.iter().enumerate()
235                {
236                    let graph_index =
237                        context.graph.neighbors_undirected(self_index).find(|neighbor| {
238                            context.graph[*neighbor].id_or_name_matches(&ty.ty_id, ty.ty_source)
239                        });
240
241                    let needs_comma = idx < (table_items.len() - 1);
242                    let item = format!(
243                        "\n\t{col_name} {schema_prefix}{ty_resolved}{needs_comma} /* {ty_name} */",
244                        col_name = col_name.expect(
245                            "An iterator of tuples should have `named!()` macro declarations."
246                        ),
247                        schema_prefix = if let Some(graph_index) = graph_index {
248                            context.schema_prefix_for(&graph_index)
249                        } else {
250                            "".into()
251                        },
252                        ty_resolved = metadata_retval_sqls[idx],
253                        needs_comma = if needs_comma { ", " } else { " " },
254                        ty_name = ty.full_path
255                    );
256                    items.push_str(&item);
257                }
258                format!("RETURNS TABLE ({items}\n)")
259            }
260            PgExternReturnEntity::Trigger => String::from("RETURNS trigger"),
261        };
262        let PgExternEntity { name, module_path, file, line, .. } = self;
263
264        let fn_sql = format!(
265            "\
266                CREATE {or_replace} FUNCTION {schema}\"{name}\"({arguments}) {returns}\n\
267                {extern_attrs}\
268                {search_path}\
269                LANGUAGE c /* Rust */\n\
270                AS '{module_pathname}', '{unaliased_name}_wrapper';\
271            ",
272            or_replace =
273                if extern_attrs.contains(&ExternArgs::CreateOrReplace) { "OR REPLACE" } else { "" },
274            search_path = if let Some(search_path) = &self.search_path {
275                let retval = format!("SET search_path TO {}", search_path.join(", "));
276                retval + "\n"
277            } else {
278                Default::default()
279            },
280            extern_attrs = if extern_attrs.is_empty() {
281                String::default()
282            } else {
283                let mut retval = extern_attrs
284                    .iter()
285                    .filter(|attr| **attr != ExternArgs::CreateOrReplace)
286                    .map(|attr| attr.to_string().to_uppercase())
287                    .collect::<Vec<_>>()
288                    .join(" ");
289                retval.push('\n');
290                retval
291            },
292            unaliased_name = self.unaliased_name,
293        );
294
295        let requires = {
296            let requires_attrs = self
297                .extern_attrs
298                .iter()
299                .filter_map(|x| match x {
300                    ExternArgs::Requires(requirements) => Some(requirements),
301                    _ => None,
302                })
303                .flatten()
304                .collect::<Vec<_>>();
305            if !requires_attrs.is_empty() {
306                format!(
307                    "-- requires:\n{}\n",
308                    requires_attrs
309                        .iter()
310                        .map(|i| format!("--   {i}"))
311                        .collect::<Vec<_>>()
312                        .join("\n")
313                )
314            } else {
315                "".to_string()
316            }
317        };
318
319        let mut ext_sql = format!(
320            "\n\
321            -- {file}:{line}\n\
322            -- {module_path}::{name}\n\
323            {requires}\
324            {fn_sql}"
325        );
326
327        if let Some(op) = &self.operator {
328            let mut optionals = vec![];
329            if let Some(it) = op.commutator {
330                optionals.push(format!("\tCOMMUTATOR = {it}"));
331            };
332            if let Some(it) = op.negator {
333                optionals.push(format!("\tNEGATOR = {it}"));
334            };
335            if let Some(it) = op.restrict {
336                optionals.push(format!("\tRESTRICT = {it}"));
337            };
338            if let Some(it) = op.join {
339                optionals.push(format!("\tJOIN = {it}"));
340            };
341            if op.hashes {
342                optionals.push(String::from("\tHASHES"));
343            };
344            if op.merges {
345                optionals.push(String::from("\tMERGES"));
346            };
347
348            let left_arg =
349                self.metadata.arguments.first().ok_or_else(|| {
350                    eyre!("Did not find `left_arg` for operator `{}`.", self.name)
351                })?;
352            let left_fn_arg = self
353                .fn_args
354                .first()
355                .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
356            let left_arg_graph_index = context
357                .graph
358                .neighbors_undirected(self_index)
359                .find(|neighbor| {
360                    context.graph[*neighbor]
361                        .id_or_name_matches(&left_fn_arg.used_ty.ty_id, left_arg.type_name)
362                })
363                .ok_or_else(|| {
364                    eyre!("Could not find left arg type in graph. Got: {:?}", left_arg)
365                })?;
366            let left_arg_sql = match left_arg.argument_sql {
367                Ok(SqlMapping::As(ref sql)) => sql.clone(),
368                Ok(SqlMapping::Composite { array_brackets }) => {
369                    if array_brackets {
370                        let composite_type = self.fn_args[0].used_ty.composite_type
371                            .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?;
372                        format!("{composite_type}[]")
373                    } else {
374                        self.fn_args[0].used_ty.composite_type
375                            .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?.to_string()
376                    }
377                }
378                Ok(SqlMapping::Skip) => {
379                    return Err(eyre!(
380                        "Found an skipped SQL type in an operator, this is not valid"
381                    ))
382                }
383                Err(err) => return Err(err.into()),
384            };
385
386            let right_arg =
387                self.metadata.arguments.get(1).ok_or_else(|| {
388                    eyre!("Did not find `left_arg` for operator `{}`.", self.name)
389                })?;
390            let right_fn_arg = self
391                .fn_args
392                .get(1)
393                .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
394            let right_arg_graph_index = context
395                .graph
396                .neighbors_undirected(self_index)
397                .find(|neighbor| {
398                    context.graph[*neighbor]
399                        .id_or_name_matches(&right_fn_arg.used_ty.ty_id, right_arg.type_name)
400                })
401                .ok_or_else(|| {
402                    eyre!("Could not find right arg type in graph. Got: {:?}", right_arg)
403                })?;
404            let right_arg_sql = match right_arg.argument_sql {
405                Ok(SqlMapping::As(ref sql)) => sql.clone(),
406                Ok(SqlMapping::Composite { array_brackets }) => {
407                    if array_brackets {
408                        let composite_type = self.fn_args[1].used_ty.composite_type
409                            .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?;
410                        format!("{composite_type}[]")
411                    } else {
412                        self.fn_args[0].used_ty.composite_type
413                            .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?.to_string()
414                    }
415                }
416                Ok(SqlMapping::Skip) => {
417                    return Err(eyre!(
418                        "Found an skipped SQL type in an operator, this is not valid"
419                    ))
420                }
421                Err(err) => return Err(err.into()),
422            };
423
424            let schema = self
425                .schema
426                .map(|schema| format!("{schema}."))
427                .unwrap_or_else(|| context.schema_prefix_for(&self_index));
428
429            let operator_sql = format!("\n\n\
430                                                    -- {file}:{line}\n\
431                                                    -- {module_path}::{name}\n\
432                                                    CREATE OPERATOR {schema}{opname} (\n\
433                                                        \tPROCEDURE={schema}\"{name}\",\n\
434                                                        \tLEFTARG={schema_prefix_left}{left_arg_sql}, /* {left_name} */\n\
435                                                        \tRIGHTARG={schema_prefix_right}{right_arg_sql}{maybe_comma} /* {right_name} */\n\
436                                                        {optionals}\
437                                                    );\
438                                                    ",
439                                                    opname = op.opname.unwrap(),
440                                                    left_name = left_arg.type_name,
441                                                    right_name = right_arg.type_name,
442                                                    schema_prefix_left = context.schema_prefix_for(&left_arg_graph_index),
443                                                    schema_prefix_right = context.schema_prefix_for(&right_arg_graph_index),
444                                                    maybe_comma = if !optionals.is_empty() { "," } else { "" },
445                                                    optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" } else { "".to_string() },
446                                            );
447            ext_sql += &operator_sql
448        };
449        if let Some(cast) = &self.cast {
450            let target_arg = &self.metadata.retval;
451            let target_fn_arg = &self.fn_return;
452            let target_arg_graph_index = context
453                .graph
454                .neighbors_undirected(self_index)
455                .find(|neighbor| match (&context.graph[*neighbor], target_fn_arg) {
456                    (SqlGraphEntity::Type(ty), PgExternReturnEntity::Type { ty: rty }) => {
457                        ty.id_matches(&rty.ty_id)
458                    }
459                    (SqlGraphEntity::Enum(en), PgExternReturnEntity::Type { ty: rty }) => {
460                        en.id_matches(&rty.ty_id)
461                    }
462                    (SqlGraphEntity::BuiltinType(defined), _) => defined == target_arg.type_name,
463                    _ => false,
464                })
465                .ok_or_else(|| {
466                    eyre!("Could not find source type in graph. Got: {:?}", target_arg)
467                })?;
468            let target_arg_sql = match target_arg.argument_sql {
469                Ok(SqlMapping::As(ref sql)) => sql.clone(),
470                Ok(SqlMapping::Composite { array_brackets }) => {
471                    if array_brackets {
472                        let composite_type = self.fn_args[0].used_ty.composite_type
473                            .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?;
474                        format!("{composite_type}[]")
475                    } else {
476                        self.fn_args[0].used_ty.composite_type
477                            .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?.to_string()
478                    }
479                }
480                Ok(SqlMapping::Skip) => {
481                    return Err(eyre!("Found an skipped SQL type in a cast, this is not valid"))
482                }
483                Err(err) => return Err(err.into()),
484            };
485            if self.metadata.arguments.len() != 1 {
486                return Err(eyre!(
487                    "PG cast function ({}) must have exactly one argument, got {}",
488                    self.name,
489                    self.metadata.arguments.len()
490                ));
491            }
492            if self.fn_args.len() != 1 {
493                return Err(eyre!(
494                    "PG cast function ({}) must have exactly one argument, got {}",
495                    self.name,
496                    self.fn_args.len()
497                ));
498            }
499            let source_arg = self
500                .metadata
501                .arguments
502                .first()
503                .ok_or_else(|| eyre!("Did not find source type for cast `{}`.", self.name))?;
504            let source_fn_arg = self
505                .fn_args
506                .first()
507                .ok_or_else(|| eyre!("Did not find source type for cast `{}`.", self.name))?;
508            let source_arg_graph_index = context
509                .graph
510                .neighbors_undirected(self_index)
511                .find(|neighbor| {
512                    context.graph[*neighbor]
513                        .id_or_name_matches(&source_fn_arg.used_ty.ty_id, source_arg.type_name)
514                })
515                .ok_or_else(|| {
516                    eyre!("Could not find source type in graph. Got: {:?}", source_arg)
517                })?;
518            let source_arg_sql = match source_arg.argument_sql {
519                Ok(SqlMapping::As(ref sql)) => sql.clone(),
520                Ok(SqlMapping::Composite { array_brackets }) => {
521                    if array_brackets {
522                        let composite_type = self.fn_args[0].used_ty.composite_type
523                            .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?;
524                        format!("{composite_type}[]")
525                    } else {
526                        self.fn_args[0].used_ty.composite_type
527                            .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgrx::composite_type!()`"))?.to_string()
528                    }
529                }
530                Ok(SqlMapping::Skip) => {
531                    return Err(eyre!("Found an skipped SQL type in a cast, this is not valid"))
532                }
533                Err(err) => return Err(err.into()),
534            };
535            let optional = match cast {
536                PgCastEntity::Default => String::from(""),
537                PgCastEntity::Assignment => String::from(" AS ASSIGNMENT"),
538                PgCastEntity::Implicit => String::from(" AS IMPLICIT"),
539            };
540
541            let cast_sql = format!("\n\n\
542                                                    -- {file}:{line}\n\
543                                                    -- {module_path}::{name}\n\
544                                                    CREATE CAST (\n\
545                                                        \t{schema_prefix_source}{source_arg_sql} /* {source_name} */\n\
546                                                        \tAS\n\
547                                                        \t{schema_prefix_target}{target_arg_sql} /* {target_name} */\n\
548                                                    )\n\
549                                                    WITH FUNCTION {function_name}{optional};\
550                                                    ",
551                                                    file = self.file,
552                                                    line = self.line,
553                                                    name = self.name,
554                                                    module_path = self.module_path,
555                                                    schema_prefix_source = context.schema_prefix_for(&source_arg_graph_index),
556                                                    source_name = source_arg.type_name,
557                                                    schema_prefix_target = context.schema_prefix_for(&target_arg_graph_index),
558                                                    target_name = target_arg.type_name,
559                                                    function_name = self.name,
560                                            );
561            ext_sql += &cast_sql
562        };
563        Ok(ext_sql)
564    }
565}