cedar_policy_core/
extensions.rs1#[cfg(feature = "ipaddr")]
20pub mod ipaddr;
21
22#[cfg(feature = "decimal")]
23pub mod decimal;
24
25#[cfg(feature = "datetime")]
26pub mod datetime;
27pub mod partial_evaluation;
28
29use std::collections::HashMap;
30
31use crate::ast::{Extension, ExtensionFunction, Name};
32use crate::entities::SchemaType;
33use crate::parser::Loc;
34use miette::Diagnostic;
35use thiserror::Error;
36
37use self::extension_function_lookup_errors::FuncDoesNotExistError;
38use self::extension_initialization_errors::{
39 FuncMultiplyDefinedError, MultipleConstructorsSameSignatureError,
40};
41
42lazy_static::lazy_static! {
43 static ref ALL_AVAILABLE_EXTENSION_OBJECTS: Vec<Extension> = vec![
44 #[cfg(feature = "ipaddr")]
45 ipaddr::extension(),
46 #[cfg(feature = "decimal")]
47 decimal::extension(),
48 #[cfg(feature = "datetime")]
49 datetime::extension(),
50 #[cfg(feature = "partial-eval")]
51 partial_evaluation::extension(),
52 ];
53
54 static ref ALL_AVAILABLE_EXTENSIONS : Extensions<'static> = Extensions::build_all_available();
55
56 static ref EXTENSIONS_NONE : Extensions<'static> = Extensions {
57 extensions: &[],
58 functions: HashMap::new(),
59 single_arg_constructors: HashMap::new(),
60 };
61}
62
63#[derive(Debug)]
68pub struct Extensions<'a> {
69 extensions: &'a [Extension],
71 functions: HashMap<&'a Name, &'a ExtensionFunction>,
76 single_arg_constructors: HashMap<&'a SchemaType, &'a ExtensionFunction>,
80}
81
82impl Extensions<'static> {
83 fn build_all_available() -> Extensions<'static> {
85 #[allow(clippy::expect_used)]
87 Self::specific_extensions(&ALL_AVAILABLE_EXTENSION_OBJECTS)
88 .expect("Default extensions should never error on initialization")
89 }
90
91 pub fn all_available() -> &'static Extensions<'static> {
93 &ALL_AVAILABLE_EXTENSIONS
94 }
95
96 pub fn none() -> &'static Extensions<'static> {
98 &EXTENSIONS_NONE
99 }
100}
101
102impl<'a> Extensions<'a> {
103 pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
105 self.extensions
106 .iter()
107 .flat_map(|ext| ext.types_with_operator_overloading())
108 }
109 pub fn specific_extensions(
111 extensions: &'a [Extension],
112 ) -> std::result::Result<Extensions<'a>, ExtensionInitializationError> {
113 let functions = util::collect_no_duplicates(
115 extensions
116 .iter()
117 .flat_map(|e| e.funcs())
118 .map(|f| (f.name(), f)),
119 )
120 .map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?;
121
122 let single_arg_constructors = util::collect_no_duplicates(
124 extensions
125 .iter()
126 .flat_map(|e| e.funcs())
127 .filter(|f| f.is_constructor() && f.arg_types().len() == 1)
128 .filter_map(|f| f.return_type().map(|return_type| (return_type, f))),
129 )
130 .map_err(|return_type| MultipleConstructorsSameSignatureError {
131 return_type: Box::new(return_type.clone()),
132 })?;
133
134 Ok(Extensions {
135 extensions,
136 functions,
137 single_arg_constructors,
138 })
139 }
140
141 pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
143 self.extensions.iter().map(|ext| ext.name())
144 }
145
146 pub fn ext_types(&self) -> impl Iterator<Item = &Name> {
151 self.extensions.iter().flat_map(|ext| ext.ext_types())
152 }
153
154 pub fn func(
158 &self,
159 name: &Name,
160 ) -> std::result::Result<&ExtensionFunction, ExtensionFunctionLookupError> {
161 self.functions.get(name).copied().ok_or_else(|| {
162 FuncDoesNotExistError {
163 name: name.clone(),
164 source_loc: name.loc().cloned(),
165 }
166 .into()
167 })
168 }
169
170 pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
174 self.extensions.iter().flat_map(|ext| ext.funcs())
175 }
176
177 pub(crate) fn lookup_single_arg_constructor(
181 &self,
182 return_type: &SchemaType,
183 ) -> Option<&ExtensionFunction> {
184 self.single_arg_constructors.get(return_type).copied()
185 }
186}
187
188#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
192pub enum ExtensionInitializationError {
193 #[error(transparent)]
195 #[diagnostic(transparent)]
196 FuncMultiplyDefined(#[from] extension_initialization_errors::FuncMultiplyDefinedError),
197
198 #[error(transparent)]
201 #[diagnostic(transparent)]
202 MultipleConstructorsSameSignature(
203 #[from] extension_initialization_errors::MultipleConstructorsSameSignatureError,
204 ),
205}
206
207mod extension_initialization_errors {
209 use crate::{ast::Name, entities::SchemaType};
210 use miette::Diagnostic;
211 use thiserror::Error;
212
213 #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
215 #[error("extension function `{name}` is defined multiple times")]
216 pub struct FuncMultiplyDefinedError {
217 pub(crate) name: Name,
219 }
220
221 #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
224 #[error("multiple extension constructors for the same extension type {return_type}")]
225 pub struct MultipleConstructorsSameSignatureError {
226 pub(crate) return_type: Box<SchemaType>,
228 }
229}
230
231#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
237pub enum ExtensionFunctionLookupError {
238 #[error(transparent)]
240 #[diagnostic(transparent)]
241 FuncDoesNotExist(#[from] extension_function_lookup_errors::FuncDoesNotExistError),
242}
243
244impl ExtensionFunctionLookupError {
245 pub(crate) fn source_loc(&self) -> Option<&Loc> {
246 match self {
247 Self::FuncDoesNotExist(e) => e.source_loc.as_ref(),
248 }
249 }
250
251 pub(crate) fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
252 match self {
253 Self::FuncDoesNotExist(e) => {
254 Self::FuncDoesNotExist(extension_function_lookup_errors::FuncDoesNotExistError {
255 source_loc,
256 ..e
257 })
258 }
259 }
260 }
261}
262
263pub mod extension_function_lookup_errors {
265 use crate::ast::Name;
266 use crate::parser::Loc;
267 use miette::Diagnostic;
268 use thiserror::Error;
269
270 #[derive(Debug, PartialEq, Eq, Clone, Error)]
276 #[error("extension function `{name}` does not exist")]
277 pub struct FuncDoesNotExistError {
278 pub(crate) name: Name,
280 pub(crate) source_loc: Option<Loc>,
282 }
283
284 impl Diagnostic for FuncDoesNotExistError {
285 impl_diagnostic_from_source_loc_opt_field!(source_loc);
286 }
287}
288
289pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
291
/// Small helpers used while building extension registries.
pub mod util {
    use std::collections::{hash_map::Entry, HashMap};

    /// Collect `(key, value)` pairs into a `HashMap`, rejecting repeated keys.
    ///
    /// # Errors
    ///
    /// When a pair repeats a key already collected, returns `Err` with a clone of
    /// the key stored in the map. Iteration stops there, so later pairs are not
    /// consumed.
    pub fn collect_no_duplicates<K, V>(
        mut i: impl Iterator<Item = (K, V)>,
    ) -> std::result::Result<HashMap<K, V>, K>
    where
        K: Clone + std::hash::Hash + Eq,
    {
        let (lower_bound, _) = i.size_hint();
        i.try_fold(
            HashMap::with_capacity(lower_bound),
            |mut seen, (key, value)| {
                match seen.entry(key) {
                    // `entry` consumed `key`, so report a clone of the equal key
                    // already held by the map.
                    Entry::Occupied(existing) => return Err(existing.key().clone()),
                    Entry::Vacant(slot) => {
                        slot.insert(value);
                    }
                }
                Ok(seen)
            },
        )
    }
}
319
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;

    /// Across all extensions enabled in this build, no two functions share a name.
    #[test]
    fn no_common_extension_function_names() {
        let registry = Extensions::all_available();
        let names: Vec<Name> = registry
            .all_funcs()
            .map(|func| func.name().clone())
            .collect();
        let distinct: HashSet<&Name> = names.iter().collect();
        assert_eq!(names.len(), distinct.len());
    }
}