1use crate::PositioningRef;
11use proc_macro2::{TokenStream, TokenTree};
12use quote::{format_ident, quote, ToTokens, TokenStreamExt};
13use std::collections::HashSet;
14
/// An individual extern attribute argument.
///
/// The per-variant notes below describe what is visible in this file: which identifier
/// `parse_extern_attributes` maps to the variant, and what the `Display` impl writes for it.
#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord)]
pub enum ExternArgs {
    /// Parsed from `create_or_replace`; displays as `CREATE OR REPLACE`.
    CreateOrReplace,
    /// Parsed from `immutable`; displays as `IMMUTABLE`.
    Immutable,
    /// Parsed from `strict`; displays as `STRICT`.
    Strict,
    /// Parsed from `stable`; displays as `STABLE`.
    Stable,
    /// Parsed from `volatile`; displays as `VOLATILE`.
    Volatile,
    /// Parsed from `raw`; displays as nothing.
    Raw,
    /// Parsed from `no_guard`; displays as nothing.
    NoGuard,
    /// Parsed from `security_definer`; displays as `SECURITY DEFINER`.
    SecurityDefiner,
    /// Parsed from `security_invoker`; displays as `SECURITY INVOKER`.
    SecurityInvoker,
    /// Parsed from `parallel_safe`; displays as `PARALLEL SAFE`.
    ParallelSafe,
    /// Parsed from `parallel_unsafe`; displays as `PARALLEL UNSAFE`.
    ParallelUnsafe,
    /// Parsed from `parallel_restricted`; displays as `PARALLEL RESTRICTED`.
    ParallelRestricted,
    /// Parsed from `error = "..."` or `expected = "..."`; holds the message with its
    /// surrounding quotes removed. Displays as nothing.
    ShouldPanic(String),
    /// Parsed from `schema = "..."`; holds the value with its surrounding quotes removed.
    /// Displays as nothing.
    Schema(String),
    /// Parsed from `name = "..."`; holds the value with its surrounding quotes removed.
    /// Displays as nothing.
    Name(String),
    /// Displays as `COST {cost}`, where `cost` is the held string.
    /// Not produced by `parse_extern_attributes` in this file.
    Cost(String),
    /// Holds a list of [`PositioningRef`]s. Displays as nothing.
    /// Not produced by `parse_extern_attributes` in this file.
    Requires(Vec<PositioningRef>),
}
35
36impl core::fmt::Display for ExternArgs {
37 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
38 match self {
39 ExternArgs::CreateOrReplace => write!(f, "CREATE OR REPLACE"),
40 ExternArgs::Immutable => write!(f, "IMMUTABLE"),
41 ExternArgs::Strict => write!(f, "STRICT"),
42 ExternArgs::Stable => write!(f, "STABLE"),
43 ExternArgs::Volatile => write!(f, "VOLATILE"),
44 ExternArgs::Raw => Ok(()),
45 ExternArgs::ParallelSafe => write!(f, "PARALLEL SAFE"),
46 ExternArgs::ParallelUnsafe => write!(f, "PARALLEL UNSAFE"),
47 ExternArgs::SecurityDefiner => write!(f, "SECURITY DEFINER"),
48 ExternArgs::SecurityInvoker => write!(f, "SECURITY INVOKER"),
49 ExternArgs::ParallelRestricted => write!(f, "PARALLEL RESTRICTED"),
50 ExternArgs::ShouldPanic(_) => Ok(()),
51 ExternArgs::NoGuard => Ok(()),
52 ExternArgs::Schema(_) => Ok(()),
53 ExternArgs::Name(_) => Ok(()),
54 ExternArgs::Cost(cost) => write!(f, "COST {cost}"),
55 ExternArgs::Requires(_) => Ok(()),
56 }
57 }
58}
59
60impl ToTokens for ExternArgs {
61 fn to_tokens(&self, tokens: &mut TokenStream) {
62 match self {
63 ExternArgs::CreateOrReplace => tokens.append(format_ident!("CreateOrReplace")),
64 ExternArgs::Immutable => tokens.append(format_ident!("Immutable")),
65 ExternArgs::Strict => tokens.append(format_ident!("Strict")),
66 ExternArgs::Stable => tokens.append(format_ident!("Stable")),
67 ExternArgs::Volatile => tokens.append(format_ident!("Volatile")),
68 ExternArgs::Raw => tokens.append(format_ident!("Raw")),
69 ExternArgs::NoGuard => tokens.append(format_ident!("NoGuard")),
70 ExternArgs::SecurityDefiner => tokens.append(format_ident!("SecurityDefiner")),
71 ExternArgs::SecurityInvoker => tokens.append(format_ident!("SecurityInvoker")),
72 ExternArgs::ParallelSafe => tokens.append(format_ident!("ParallelSafe")),
73 ExternArgs::ParallelUnsafe => tokens.append(format_ident!("ParallelUnsafe")),
74 ExternArgs::ParallelRestricted => tokens.append(format_ident!("ParallelRestricted")),
75 ExternArgs::ShouldPanic(_s) => {
76 tokens.append_all(
77 quote! {
78 Error(String::from("#_s"))
79 }
80 .to_token_stream(),
81 );
82 }
83 ExternArgs::Schema(_s) => {
84 tokens.append_all(
85 quote! {
86 Schema(String::from("#_s"))
87 }
88 .to_token_stream(),
89 );
90 }
91 ExternArgs::Name(_s) => {
92 tokens.append_all(
93 quote! {
94 Name(String::from("#_s"))
95 }
96 .to_token_stream(),
97 );
98 }
99 ExternArgs::Cost(_s) => {
100 tokens.append_all(
101 quote! {
102 Cost(String::from("#_s"))
103 }
104 .to_token_stream(),
105 );
106 }
107 ExternArgs::Requires(items) => {
108 tokens.append_all(
109 quote! {
110 Requires(vec![#(#items),*])
111 }
112 .to_token_stream(),
113 );
114 }
115 }
116 }
117}
118
119#[track_caller]
121pub fn parse_extern_attributes(attr: TokenStream) -> HashSet<ExternArgs> {
122 let mut args = HashSet::<ExternArgs>::new();
123 let mut itr = attr.into_iter();
124 while let Some(t) = itr.next() {
125 match t {
126 TokenTree::Group(g) => {
127 for arg in parse_extern_attributes(g.stream()).into_iter() {
128 args.insert(arg);
129 }
130 }
131 TokenTree::Ident(i) => {
132 let name = i.to_string();
133 match name.as_str() {
134 "create_or_replace" => args.insert(ExternArgs::CreateOrReplace),
135 "immutable" => args.insert(ExternArgs::Immutable),
136 "strict" => args.insert(ExternArgs::Strict),
137 "stable" => args.insert(ExternArgs::Stable),
138 "volatile" => args.insert(ExternArgs::Volatile),
139 "raw" => args.insert(ExternArgs::Raw),
140 "no_guard" => args.insert(ExternArgs::NoGuard),
141 "security_invoker" => args.insert(ExternArgs::SecurityInvoker),
142 "security_definer" => args.insert(ExternArgs::SecurityDefiner),
143 "parallel_safe" => args.insert(ExternArgs::ParallelSafe),
144 "parallel_unsafe" => args.insert(ExternArgs::ParallelUnsafe),
145 "parallel_restricted" => args.insert(ExternArgs::ParallelRestricted),
146 "error" | "expected" => {
147 let _punc = itr.next().unwrap();
148 let literal = itr.next().unwrap();
149 let message = literal.to_string();
150 let message = unescape::unescape(&message).expect("failed to unescape");
151
152 let message = message[1..message.len() - 1].to_string();
154 args.insert(ExternArgs::ShouldPanic(message.to_string()))
155 }
156 "schema" => {
157 let _punc = itr.next().unwrap();
158 let literal = itr.next().unwrap();
159 let schema = literal.to_string();
160 let schema = unescape::unescape(&schema).expect("failed to unescape");
161
162 let schema = schema[1..schema.len() - 1].to_string();
164 args.insert(ExternArgs::Schema(schema.to_string()))
165 }
166 "name" => {
167 let _punc = itr.next().unwrap();
168 let literal = itr.next().unwrap();
169 let name = literal.to_string();
170 let name = unescape::unescape(&name).expect("failed to unescape");
171
172 let name = name[1..name.len() - 1].to_string();
174 args.insert(ExternArgs::Name(name.to_string()))
175 }
176 "sql" => {
178 let _punc = itr.next().unwrap();
179 let _value = itr.next().unwrap();
180 false
181 }
182 _ => false,
183 };
184 }
185 TokenTree::Punct(_) => {}
186 TokenTree::Literal(_) => {}
187 }
188 }
189 args
190}
191
#[cfg(test)]
mod tests {
    use crate::{parse_extern_attributes, ExternArgs};

    /// `error = "..."` must yield `ShouldPanic` carrying the message with the outer quotes
    /// stripped and the inner escaped quotes resolved.
    #[test]
    fn parse_args() {
        let attribute: proc_macro2::TokenStream =
            r#"error = "syntax error at or near \"THIS\"""#.parse().unwrap();
        let expected = ExternArgs::ShouldPanic(String::from(r#"syntax error at or near "THIS""#));

        let parsed = parse_extern_attributes(attribute);

        assert!(parsed.contains(&expected));
    }
}