pgrx_sql_entity_graph/extension_sql/
mod.rs1pub mod entity;
20
21use crate::positioning_ref::PositioningRef;
22
23use crate::enrich::{CodeEnrichment, ToEntityGraphTokens, ToRustCodeTokens};
24use proc_macro2::{Ident, TokenStream as TokenStream2};
25use quote::{format_ident, quote, ToTokens, TokenStreamExt};
26use syn::parse::{Parse, ParseStream};
27use syn::punctuated::Punctuated;
28use syn::{LitStr, Token};
29
30#[derive(Debug, Clone)]
55pub struct ExtensionSqlFile {
56 pub path: LitStr,
57 pub attrs: Punctuated<ExtensionSqlAttribute, Token![,]>,
58}
59
60impl ToEntityGraphTokens for ExtensionSqlFile {
61 fn to_entity_graph_tokens(&self) -> TokenStream2 {
62 let path = &self.path;
63 let mut name = None;
64 let mut bootstrap = false;
65 let mut finalize = false;
66 let mut requires = vec![];
67 let mut creates = vec![];
68 for attr in &self.attrs {
69 match attr {
70 ExtensionSqlAttribute::Creates(items) => {
71 creates.append(&mut items.iter().map(|x| x.to_token_stream()).collect());
72 }
73 ExtensionSqlAttribute::Requires(items) => {
74 requires.append(&mut items.iter().map(|x| x.to_token_stream()).collect());
75 }
76 ExtensionSqlAttribute::Bootstrap => {
77 bootstrap = true;
78 }
79 ExtensionSqlAttribute::Finalize => {
80 finalize = true;
81 }
82 ExtensionSqlAttribute::Name(found_name) => {
83 name = Some(found_name.value());
84 }
85 }
86 }
87 let name = name.unwrap_or(
88 std::path::PathBuf::from(path.value())
89 .file_stem()
90 .expect("No file name for extension_sql_file!()")
91 .to_str()
92 .expect("No UTF-8 file name for extension_sql_file!()")
93 .to_string(),
94 );
95 let requires_iter = requires.iter();
96 let creates_iter = creates.iter();
97 let sql_graph_entity_fn_name = format_ident!("__pgrx_internals_sql_{}", name.clone());
98 quote! {
99 #[no_mangle]
100 #[doc(hidden)]
101 #[allow(unknown_lints, clippy::no_mangle_with_rust_abi)]
102 pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity {
103 extern crate alloc;
104 use alloc::vec::Vec;
105 use alloc::vec;
106 let submission = ::pgrx::pgrx_sql_entity_graph::ExtensionSqlEntity {
107 sql: include_str!(#path),
108 module_path: module_path!(),
109 full_path: concat!(file!(), ':', line!()),
110 file: file!(),
111 line: line!(),
112 name: #name,
113 bootstrap: #bootstrap,
114 finalize: #finalize,
115 requires: vec![#(#requires_iter),*],
116 creates: vec![#(#creates_iter),*],
117 };
118 ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity::CustomSql(submission)
119 }
120 }
121 }
122}
123
124impl ToRustCodeTokens for ExtensionSqlFile {}
125
126impl Parse for CodeEnrichment<ExtensionSqlFile> {
127 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
128 let path = input.parse()?;
129 let _after_sql_comma: Option<Token![,]> = input.parse()?;
130 let attrs = input.parse_terminated(ExtensionSqlAttribute::parse, Token![,])?;
131 Ok(CodeEnrichment(ExtensionSqlFile { path, attrs }))
132 }
133}
134
135#[derive(Debug, Clone)]
160pub struct ExtensionSql {
161 pub sql: LitStr,
162 pub name: LitStr,
163 pub attrs: Punctuated<ExtensionSqlAttribute, Token![,]>,
164}
165
166impl ToEntityGraphTokens for ExtensionSql {
167 fn to_entity_graph_tokens(&self) -> TokenStream2 {
168 let sql = &self.sql;
169 let mut bootstrap = false;
170 let mut finalize = false;
171 let mut creates = vec![];
172 let mut requires = vec![];
173 for attr in &self.attrs {
174 match attr {
175 ExtensionSqlAttribute::Requires(items) => {
176 requires.append(&mut items.iter().map(|x| x.to_token_stream()).collect());
177 }
178 ExtensionSqlAttribute::Creates(items) => {
179 creates.append(&mut items.iter().map(|x| x.to_token_stream()).collect());
180 }
181 ExtensionSqlAttribute::Bootstrap => {
182 bootstrap = true;
183 }
184 ExtensionSqlAttribute::Finalize => {
185 finalize = true;
186 }
187 ExtensionSqlAttribute::Name(_found_name) => (), }
189 }
190 let requires_iter = requires.iter();
191 let creates_iter = creates.iter();
192 let name = &self.name;
193
194 let sql_graph_entity_fn_name = format_ident!("__pgrx_internals_sql_{}", name.value());
195 quote! {
196 #[no_mangle]
197 #[doc(hidden)]
198 #[allow(unknown_lints, clippy::no_mangle_with_rust_abi)]
199 pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity {
200 extern crate alloc;
201 use alloc::vec::Vec;
202 use alloc::vec;
203 let submission = ::pgrx::pgrx_sql_entity_graph::ExtensionSqlEntity {
204 sql: #sql,
205 module_path: module_path!(),
206 full_path: concat!(file!(), ':', line!()),
207 file: file!(),
208 line: line!(),
209 name: #name,
210 bootstrap: #bootstrap,
211 finalize: #finalize,
212 requires: vec![#(#requires_iter),*],
213 creates: vec![#(#creates_iter),*],
214 };
215 ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity::CustomSql(submission)
216 }
217 }
218 }
219}
220
221impl ToRustCodeTokens for ExtensionSql {}
222
223impl Parse for CodeEnrichment<ExtensionSql> {
224 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
225 let sql = input.parse()?;
226 let _after_sql_comma: Option<Token![,]> = input.parse()?;
227 let attrs = input.parse_terminated(ExtensionSqlAttribute::parse, Token![,])?;
228 let name = attrs.iter().rev().find_map(|attr| match attr {
229 ExtensionSqlAttribute::Name(found_name) => Some(found_name.clone()),
230 _ => None,
231 });
232 let name =
233 name.ok_or_else(|| syn::Error::new(input.span(), "expected `name` to be set"))?;
234 Ok(CodeEnrichment(ExtensionSql { sql, attrs, name }))
235 }
236}
237
238impl ToTokens for ExtensionSql {
239 fn to_tokens(&self, tokens: &mut TokenStream2) {
240 tokens.append_all(self.to_entity_graph_tokens())
241 }
242}
243
244#[derive(Debug, Clone)]
245pub enum ExtensionSqlAttribute {
246 Requires(Punctuated<PositioningRef, Token![,]>),
247 Creates(Punctuated<SqlDeclared, Token![,]>),
248 Bootstrap,
249 Finalize,
250 Name(LitStr),
251}
252
253impl Parse for ExtensionSqlAttribute {
254 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
255 let ident: Ident = input.parse()?;
256 let found = match ident.to_string().as_str() {
257 "creates" => {
258 let _eq: syn::token::Eq = input.parse()?;
259 let content;
260 let _bracket = syn::bracketed!(content in input);
261 Self::Creates(content.parse_terminated(SqlDeclared::parse, Token![,])?)
262 }
263 "requires" => {
264 let _eq: syn::token::Eq = input.parse()?;
265 let content;
266 let _bracket = syn::bracketed!(content in input);
267 Self::Requires(content.parse_terminated(PositioningRef::parse, Token![,])?)
268 }
269 "bootstrap" => Self::Bootstrap,
270 "finalize" => Self::Finalize,
271 "name" => {
272 let _eq: syn::token::Eq = input.parse()?;
273 Self::Name(input.parse()?)
274 }
275 other => {
276 return Err(syn::Error::new(
277 ident.span(),
278 format!("Unknown extension_sql attribute: {other}"),
279 ))
280 }
281 };
282 Ok(found)
283 }
284}
285
/// An item listed inside a `creates = [...]` attribute.
///
/// Each variant stores the declared identifier as a string whose path segments are
/// joined with `::`.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub enum SqlDeclared {
    /// Written as `Type(path)`.
    Type(String),
    /// Written as `Enum(path)`.
    Enum(String),
    /// Written as `Function(path)`.
    Function(String),
}
292
293impl ToEntityGraphTokens for SqlDeclared {
294 fn to_entity_graph_tokens(&self) -> TokenStream2 {
295 let (variant, identifier) = match &self {
296 SqlDeclared::Type(val) => ("Type", val),
297 SqlDeclared::Enum(val) => ("Enum", val),
298 SqlDeclared::Function(val) => ("Function", val),
299 };
300 let identifier_split = identifier.split("::").collect::<Vec<_>>();
301 let identifier = if identifier_split.len() == 1 {
302 let identifier_infer =
303 Ident::new(identifier_split.last().unwrap(), proc_macro2::Span::call_site());
304 quote! { concat!(module_path!(), "::", stringify!(#identifier_infer)) }
305 } else {
306 quote! { stringify!(#identifier) }
307 };
308 quote! {
309 ::pgrx::pgrx_sql_entity_graph::SqlDeclaredEntity::build(#variant, #identifier).unwrap()
310 }
311 }
312}
313
314impl ToRustCodeTokens for SqlDeclared {}
315
316impl Parse for SqlDeclared {
317 fn parse(input: ParseStream) -> syn::Result<Self> {
318 let variant: Ident = input.parse()?;
319 let content;
320 let _bracket: syn::token::Paren = syn::parenthesized!(content in input);
321 let identifier_path: syn::Path = content.parse()?;
322 let identifier_str = {
323 let mut identifier_segments = Vec::new();
324 for segment in identifier_path.segments {
325 identifier_segments.push(segment.ident.to_string())
326 }
327 identifier_segments.join("::")
328 };
329 let this = match variant.to_string().as_str() {
330 "Type" => SqlDeclared::Type(identifier_str),
331 "Enum" => SqlDeclared::Enum(identifier_str),
332 "Function" => SqlDeclared::Function(identifier_str),
333 _ => return Err(syn::Error::new(
334 variant.span(),
335 "SQL declared entities must be `Type(ident)`, `Enum(ident)`, or `Function(ident)`",
336 )),
337 };
338 Ok(this)
339 }
340}
341
342impl ToTokens for SqlDeclared {
343 fn to_tokens(&self, tokens: &mut TokenStream2) {
344 tokens.append_all(self.to_entity_graph_tokens())
345 }
346}