pgrx_sql_entity_graph/
extern_args.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10use crate::PositioningRef;
11use proc_macro2::{TokenStream, TokenTree};
12use quote::{format_ident, quote, ToTokens, TokenStreamExt};
13use std::collections::HashSet;
14
15#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)]
16pub enum ExternArgs {
17    CreateOrReplace,
18    Immutable,
19    Strict,
20    Stable,
21    Volatile,
22    Raw,
23    NoGuard,
24    SecurityDefiner,
25    SecurityInvoker,
26    ParallelSafe,
27    ParallelUnsafe,
28    ParallelRestricted,
29    ShouldPanic(String),
30    Schema(String),
31    Name(String),
32    Cost(String),
33    Requires(Vec<PositioningRef>),
34}
35
36impl core::fmt::Display for ExternArgs {
37    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
38        match self {
39            ExternArgs::CreateOrReplace => write!(f, "CREATE OR REPLACE"),
40            ExternArgs::Immutable => write!(f, "IMMUTABLE"),
41            ExternArgs::Strict => write!(f, "STRICT"),
42            ExternArgs::Stable => write!(f, "STABLE"),
43            ExternArgs::Volatile => write!(f, "VOLATILE"),
44            ExternArgs::Raw => Ok(()),
45            ExternArgs::ParallelSafe => write!(f, "PARALLEL SAFE"),
46            ExternArgs::ParallelUnsafe => write!(f, "PARALLEL UNSAFE"),
47            ExternArgs::SecurityDefiner => write!(f, "SECURITY DEFINER"),
48            ExternArgs::SecurityInvoker => write!(f, "SECURITY INVOKER"),
49            ExternArgs::ParallelRestricted => write!(f, "PARALLEL RESTRICTED"),
50            ExternArgs::ShouldPanic(_) => Ok(()),
51            ExternArgs::NoGuard => Ok(()),
52            ExternArgs::Schema(_) => Ok(()),
53            ExternArgs::Name(_) => Ok(()),
54            ExternArgs::Cost(cost) => write!(f, "COST {cost}"),
55            ExternArgs::Requires(_) => Ok(()),
56        }
57    }
58}
59
60impl ToTokens for ExternArgs {
61    fn to_tokens(&self, tokens: &mut TokenStream) {
62        match self {
63            ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
64            ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
65            ExternArgs::Strict => tokens.append(format_ident!("Strict")),
66            ExternArgs::Stable => tokens.append(format_ident!("Stable")),
67            ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
68            ExternArgs::Raw => tokens.append(format_ident!("Raw")),
69            ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
70            ExternArgs::SecurityDefiner => tokens.append(format_ident!("SecurityDefiner")),
71            ExternArgs::SecurityInvoker => tokens.append(format_ident!("SecurityInvoker")),
72            ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
73            ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
74            ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
75            ExternArgs::ShouldPanic(_s) => {
76                tokens.append_all(
77                    quote! {
78                        Error(String::from("#_s"))
79                    }
80                    .to_token_stream(),
81                );
82            }
83            ExternArgs::Schema(_s) => {
84                tokens.append_all(
85                    quote! {
86                        Schema(String::from("#_s"))
87                    }
88                    .to_token_stream(),
89                );
90            }
91            ExternArgs::Name(_s) => {
92                tokens.append_all(
93                    quote! {
94                        Name(String::from("#_s"))
95                    }
96                    .to_token_stream(),
97                );
98            }
99            ExternArgs::Cost(_s) => {
100                tokens.append_all(
101                    quote! {
102                        Cost(String::from("#_s"))
103                    }
104                    .to_token_stream(),
105                );
106            }
107            ExternArgs::Requires(items) => {
108                tokens.append_all(
109                    quote! {
110                        Requires(vec![#(#items),*])
111                    }
112                    .to_token_stream(),
113                );
114            }
115        }
116    }
117}
118
119// This horror-story should be returning result
120#[track_caller]
121pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
122    let mut args = HashSet::<ExternArgs>::new();
123    let mut itr = attr.into_iter();
124    while let Some(t) = itr.next() {
125        match t {
126            TokenTree::Group(g) => {
127                for arg in parse_extern_attributes(g.stream()).into_iter() {
128                    args.insert(arg);
129                }
130            }
131            TokenTree::Ident(i) => {
132                let name = i.to_string();
133                match name.as_str() {
134                    "create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
135                    "immutable" => args.insert(ExternArgs::Immutable),
136                    "strict" => args.insert(ExternArgs::Strict),
137                    "stable" => args.insert(ExternArgs::Stable),
138                    "volatile" => args.insert(ExternArgs::Volatile),
139                    "raw" => args.insert(ExternArgs::Raw),
140                    "no_guard" => args.insert(ExternArgs::NoGuard),
141                    "security_invoker" => args.insert(ExternArgs::SecurityInvoker),
142                    "security_definer" => args.insert(ExternArgs::SecurityDefiner),
143                    "parallel_safe" => args.insert(ExternArgs::ParallelSafe),
144                    "parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
145                    "parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
146                    "error" | "expected" => {
147                        let _punc = itr.next().unwrap();
148                        let literal = itr.next().unwrap();
149                        let message = literal.to_string();
150                        let message = unescape::unescape(&message).expect("failed to unescape");
151
152                        // trim leading/trailing quotes around the literal
153                        let message = message[1..message.len() - 1].to_string();
154                        args.insert(ExternArgs::ShouldPanic(message.to_string()))
155                    }
156                    "schema" => {
157                        let _punc = itr.next().unwrap();
158                        let literal = itr.next().unwrap();
159                        let schema = literal.to_string();
160                        let schema = unescape::unescape(&schema).expect("failed to unescape");
161
162                        // trim leading/trailing quotes around the literal
163                        let schema = schema[1..schema.len() - 1].to_string();
164                        args.insert(ExternArgs::Schema(schema.to_string()))
165                    }
166                    "name" => {
167                        let _punc = itr.next().unwrap();
168                        let literal = itr.next().unwrap();
169                        let name = literal.to_string();
170                        let name = unescape::unescape(&name).expect("failed to unescape");
171
172                        // trim leading/trailing quotes around the literal
173                        let name = name[1..name.len() - 1].to_string();
174                        args.insert(ExternArgs::Name(name.to_string()))
175                    }
176                    // Recognized, but not handled as an extern argument
177                    "sql" => {
178                        let _punc = itr.next().unwrap();
179                        let _value = itr.next().unwrap();
180                        false
181                    }
182                    _ => false,
183                };
184            }
185            TokenTree::Punct(_) => {}
186            TokenTree::Literal(_) => {}
187        }
188    }
189    args
190}
191
#[cfg(test)]
mod tests {
    use crate::{parse_extern_attributes, ExternArgs};

    /// An `error = "..."` attribute containing escaped quotes is parsed into
    /// `ShouldPanic` with the escapes resolved and the outer quotes removed.
    #[test]
    fn parse_args() {
        let source = r#"error = "syntax error at or near \"THIS\"""#;
        let stream = source.parse::<proc_macro2::TokenStream>().unwrap();

        let expected = ExternArgs::ShouldPanic(String::from(r#"syntax error at or near "THIS""#));
        let parsed = parse_extern_attributes(stream);

        assert!(parsed.contains(&expected));
    }
}