// pgrx_sql_entity_graph/extension_sql/entity.rs

//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
//LICENSE
//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
//LICENSE
//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
//LICENSE
//LICENSE All rights reserved.
//LICENSE
//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
/*!

`pgrx::extension_sql!()`-related entities for Rust-to-SQL translation.

> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
> to the `pgrx` framework and highly subject to change between versions. While you may use this, please do so with caution.

*/
19use crate::extension_sql::SqlDeclared;
20use crate::pgrx_sql::PgrxSql;
21use crate::positioning_ref::PositioningRef;
22use crate::to_sql::ToSql;
23use crate::{SqlGraphEntity, SqlGraphIdentifier};
24
25use std::fmt::Display;
26
/// The output of an [`ExtensionSql`](crate::ExtensionSql) from `quote::ToTokens::to_tokens`.
///
/// Describes one `extension_sql!()` block: its raw SQL, where it was declared, and the
/// entities it creates and requires. Its [`ToSql`] impl renders the SQL preceded by
/// comment lines carrying that metadata.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ExtensionSqlEntity {
    /// Presumably the Rust module path of the declaration (not read in this file — confirm).
    pub module_path: &'static str,
    /// Presumably the fully qualified Rust path of the block (not read in this file — confirm).
    pub full_path: &'static str,
    /// The raw SQL text; `to_sql` emits it verbatim after the comment header.
    pub sql: &'static str,
    /// Source file of the declaration; written into the `-- file:line` header comment.
    pub file: &'static str,
    /// Source line of the declaration; written into the `-- file:line` header comment.
    pub line: u32,
    /// Name of the block; returned by `rust_identifier` and used in `dot_identifier`.
    pub name: &'static str,
    /// When `true`, `to_sql` emits a `-- bootstrap` marker line.
    pub bootstrap: bool,
    /// When `true`, `to_sql` emits a `-- finalize` marker line.
    pub finalize: bool,
    /// Entities this block depends on; listed under `-- requires:` by `to_sql`.
    pub requires: Vec<PositioningRef>,
    /// Entities this block declares it creates; listed under `-- creates:` by `to_sql`
    /// and searched by `has_sql_declared_entity`.
    pub creates: Vec<SqlDeclaredEntity>,
}
41
42impl ExtensionSqlEntity {
43    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> Option<&SqlDeclaredEntity> {
44        self.creates.iter().find(|created| created.has_sql_declared_entity(identifier))
45    }
46}
47
48impl From<ExtensionSqlEntity> for SqlGraphEntity {
49    fn from(val: ExtensionSqlEntity) -> Self {
50        SqlGraphEntity::CustomSql(val)
51    }
52}
53
54impl SqlGraphIdentifier for ExtensionSqlEntity {
55    fn dot_identifier(&self) -> String {
56        format!("sql {}", self.name)
57    }
58    fn rust_identifier(&self) -> String {
59        self.name.to_string()
60    }
61
62    fn file(&self) -> Option<&'static str> {
63        Some(self.file)
64    }
65
66    fn line(&self) -> Option<u32> {
67        Some(self.line)
68    }
69}
70
71impl ToSql for ExtensionSqlEntity {
72    fn to_sql(&self, _context: &PgrxSql) -> eyre::Result<String> {
73        let ExtensionSqlEntity { file, line, sql, creates, requires, .. } = self;
74        let creates = if !creates.is_empty() {
75            let joined = creates.iter().map(|i| format!("--   {i}")).collect::<Vec<_>>().join("\n");
76            format!(
77                "\
78                -- creates:\n\
79                {joined}\n\n"
80            )
81        } else {
82            "".to_string()
83        };
84        let requires = if !requires.is_empty() {
85            let joined =
86                requires.iter().map(|i| format!("--   {i}")).collect::<Vec<_>>().join("\n");
87            format!(
88                "\
89               -- requires:\n\
90                {joined}\n\n"
91            )
92        } else {
93            "".to_string()
94        };
95        let sql = format!(
96            "\n\
97                -- {file}:{line}\n\
98                {bootstrap}\
99                {creates}\
100                {requires}\
101                {finalize}\
102                {sql}\
103                ",
104            bootstrap = if self.bootstrap { "-- bootstrap\n" } else { "" },
105            finalize = if self.finalize { "-- finalize\n" } else { "" },
106        );
107        Ok(sql)
108    }
109}
110
/// Name spellings recorded for one entity declared in an `extension_sql!()` block.
///
/// Populated by `SqlDeclaredEntity::build`. Apart from `sql`, every field is a spelling
/// that `SqlDeclaredEntity::has_sql_declared_entity` compares identifiers against.
// Field order is significant: the derived `Ord`/`PartialOrd`/`Hash` follow it.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct SqlDeclaredEntityData {
    /// Last `::`-separated segment of `name`; returned by `SqlDeclaredEntity::sql`.
    sql: String,
    /// The full name the entity was built from; shown by its `Display` impl.
    name: String,
    /// Spelling for `Option<name>`.
    option: String,
    /// Spelling for `Vec<name>`.
    vec: String,
    /// Spelling for `Vec<Option<name>>`.
    vec_option: String,
    /// Spelling for `Option<Vec<name>>`.
    option_vec: String,
    /// Spelling for an optional `Vec` of optional `name`.
    option_vec_option: String,
    /// Spelling for `Array<name>`.
    array: String,
    /// Spelling for an optional `Array` of `name`.
    option_array: String,
    /// Spelling for `Varlena<name>`.
    varlena: String,
    /// Fully qualified `pgrx::pgbox::PgBox` spellings wrapping `name`.
    pg_box: Vec<String>,
}
/// An entity that an `extension_sql!()` block declares it creates, tagged by kind.
///
/// Every kind carries the same [`SqlDeclaredEntityData`]; the kind must agree with the
/// `SqlDeclared` kind for `has_sql_declared_entity` to report a match.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub enum SqlDeclaredEntity {
    /// Declared with `Type(...)`.
    Type(SqlDeclaredEntityData),
    /// Declared with `Enum(...)`.
    Enum(SqlDeclaredEntityData),
    /// Declared with `Function(...)`.
    Function(SqlDeclaredEntityData),
}
131
132impl Display for SqlDeclaredEntity {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        match self {
135            SqlDeclaredEntity::Type(data) => {
136                write!(f, "Type({})", data.name)
137            }
138            SqlDeclaredEntity::Enum(data) => {
139                write!(f, "Enum({})", data.name)
140            }
141            SqlDeclaredEntity::Function(data) => {
142                write!(f, "Function({})", data.name)
143            }
144        }
145    }
146}
147
148impl SqlDeclaredEntity {
149    pub fn build(variant: impl AsRef<str>, name: impl AsRef<str>) -> eyre::Result<Self> {
150        let name = name.as_ref();
151        let data = SqlDeclaredEntityData {
152            sql: name
153                .split("::")
154                .last()
155                .ok_or_else(|| eyre::eyre!("Did not get SQL for `{}`", name))?
156                .to_string(),
157            name: name.to_string(),
158            option: format!("Option<{name}>"),
159            vec: format!("Vec<{name}>"),
160            vec_option: format!("Vec<Option<{name}>>"),
161            option_vec: format!("Option<Vec<{name}>>"),
162            option_vec_option: format!("Option<Vec<Option<{name}>>"),
163            array: format!("Array<{name}>"),
164            option_array: format!("Option<{name}>"),
165            varlena: format!("Varlena<{name}>"),
166            pg_box: vec![
167                format!("pgrx::pgbox::PgBox<{}>", name),
168                format!("pgrx::pgbox::PgBox<{}, pgrx::pgbox::AllocatedByRust>", name),
169                format!("pgrx::pgbox::PgBox<{}, pgrx::pgbox::AllocatedByPostgres>", name),
170            ],
171        };
172        let retval = match variant.as_ref() {
173            "Type" => Self::Type(data),
174            "Enum" => Self::Enum(data),
175            "Function" => Self::Function(data),
176            _ => {
177                return Err(eyre::eyre!(
178                    "Can only declare `Type(Ident)`, `Enum(Ident)` or `Function(Ident)`"
179                ))
180            }
181        };
182        Ok(retval)
183    }
184    pub fn sql(&self) -> String {
185        match self {
186            SqlDeclaredEntity::Type(data) => data.sql.clone(),
187            SqlDeclaredEntity::Enum(data) => data.sql.clone(),
188            SqlDeclaredEntity::Function(data) => data.sql.clone(),
189        }
190    }
191
192    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> bool {
193        match (&identifier, &self) {
194            (SqlDeclared::Type(ident_name), &SqlDeclaredEntity::Type(data))
195            | (SqlDeclared::Enum(ident_name), &SqlDeclaredEntity::Enum(data))
196            | (SqlDeclared::Function(ident_name), &SqlDeclaredEntity::Function(data)) => {
197                let matches = |identifier_name: &str| {
198                    identifier_name == data.name
199                        || identifier_name == data.option
200                        || identifier_name == data.vec
201                        || identifier_name == data.vec_option
202                        || identifier_name == data.option_vec
203                        || identifier_name == data.option_vec_option
204                        || identifier_name == data.array
205                        || identifier_name == data.option_array
206                        || identifier_name == data.varlena
207                };
208                if matches(ident_name) || data.pg_box.contains(ident_name) {
209                    return true;
210                }
211                // there are cases where the identifier is
212                // `core::option::Option<Foo>` while the data stores
213                // `Option<Foo>` check again for this
214                let Some(generics_start) = ident_name.find('<') else { return false };
215                let Some(qual_end) = ident_name[..generics_start].rfind("::") else { return false };
216                matches(&ident_name[qual_end + 2..])
217            }
218            _ => false,
219        }
220    }
221}