cedar_policy_core/
extensions.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//! This module contains all of the standard Cedar extensions.
18
19#[cfg(feature = "ipaddr")]
20pub mod ipaddr;
21
22#[cfg(feature = "decimal")]
23pub mod decimal;
24
25#[cfg(feature = "datetime")]
26pub mod datetime;
27pub mod partial_evaluation;
28
29use std::collections::HashMap;
30
31use crate::ast::{Extension, ExtensionFunction, Name};
32use crate::entities::SchemaType;
33use crate::parser::Loc;
34use miette::Diagnostic;
35use thiserror::Error;
36
37use self::extension_function_lookup_errors::FuncDoesNotExistError;
38use self::extension_initialization_errors::{
39    FuncMultiplyDefinedError, MultipleConstructorsSameSignatureError,
40};
41
42lazy_static::lazy_static! {
43    static ref ALL_AVAILABLE_EXTENSION_OBJECTS: Vec<Extension> = vec![
44        #[cfg(feature = "ipaddr")]
45        ipaddr::extension(),
46        #[cfg(feature = "decimal")]
47        decimal::extension(),
48        #[cfg(feature = "datetime")]
49        datetime::extension(),
50        #[cfg(feature = "partial-eval")]
51        partial_evaluation::extension(),
52    ];
53
54    static ref ALL_AVAILABLE_EXTENSIONS : Extensions<'static> = Extensions::build_all_available();
55
56    static ref EXTENSIONS_NONE : Extensions<'static> = Extensions {
57        extensions: &[],
58        functions: HashMap::new(),
59        single_arg_constructors: HashMap::new(),
60    };
61}
62
/// Holds data on all the Extensions which are active for a given evaluation.
///
/// Within this module, values are built by [`Extensions::specific_extensions`]
/// (which validates the extensions), or obtained as shared `'static`
/// references from [`Extensions::all_available`] / [`Extensions::none`].
///
/// This structure is intentionally not `Clone` because we can use it entirely
/// by reference.
#[derive(Debug)]
pub struct Extensions<'a> {
    /// The actual extensions, borrowed for lifetime `'a`.
    extensions: &'a [Extension],
    /// All extension functions, collected from every extension used to
    /// construct this object and keyed by function name. Built ahead of time
    /// so that we know during extension function lookup that at most one
    /// extension function exists for a name. This should also make the lookup
    /// more efficient.
    functions: HashMap<&'a Name, &'a ExtensionFunction>,
    /// Every constructor that takes exactly one argument and declares a return
    /// type, indexed by that return type. Built ahead of time so that we know
    /// each such constructor has a unique return type.
    single_arg_constructors: HashMap<&'a SchemaType, &'a ExtensionFunction>,
}
81
82impl Extensions<'static> {
83    /// Get a new `Extensions` containing data on all the available extensions.
84    fn build_all_available() -> Extensions<'static> {
85        // PANIC SAFETY: Builtin extensions define functions/constructors only once. Also tested by many different test cases.
86        #[allow(clippy::expect_used)]
87        Self::specific_extensions(&ALL_AVAILABLE_EXTENSION_OBJECTS)
88            .expect("Default extensions should never error on initialization")
89    }
90
91    /// An [`Extensions`] object with static lifetime contain all available extensions.
92    pub fn all_available() -> &'static Extensions<'static> {
93        &ALL_AVAILABLE_EXTENSIONS
94    }
95
96    /// Get a new `Extensions` with no extensions enabled.
97    pub fn none() -> &'static Extensions<'static> {
98        &EXTENSIONS_NONE
99    }
100}
101
102impl<'a> Extensions<'a> {
103    /// Obtain the non-empty vector of types supporting operator overloading
104    pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
105        self.extensions
106            .iter()
107            .flat_map(|ext| ext.types_with_operator_overloading())
108    }
109    /// Get a new `Extensions` with these specific extensions enabled.
110    pub fn specific_extensions(
111        extensions: &'a [Extension],
112    ) -> std::result::Result<Extensions<'a>, ExtensionInitializationError> {
113        // Build functions map, ensuring that no functions share the same name.
114        let functions = util::collect_no_duplicates(
115            extensions
116                .iter()
117                .flat_map(|e| e.funcs())
118                .map(|f| (f.name(), f)),
119        )
120        .map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?;
121
122        // Build the constructor map, ensuring that no constructors share a return type
123        let single_arg_constructors = util::collect_no_duplicates(
124            extensions
125                .iter()
126                .flat_map(|e| e.funcs())
127                .filter(|f| f.is_constructor() && f.arg_types().len() == 1)
128                .filter_map(|f| f.return_type().map(|return_type| (return_type, f))),
129        )
130        .map_err(|return_type| MultipleConstructorsSameSignatureError {
131            return_type: Box::new(return_type.clone()),
132        })?;
133
134        Ok(Extensions {
135            extensions,
136            functions,
137            single_arg_constructors,
138        })
139    }
140
141    /// Get the names of all active extensions.
142    pub fn ext_names(&self) -> impl Iterator<Item = &Name> {
143        self.extensions.iter().map(|ext| ext.name())
144    }
145
146    /// Get all extension type names declared by active extensions.
147    ///
148    /// (More specifically, all extension type names such that any function in
149    /// an active extension could produce a value of that extension type.)
150    pub fn ext_types(&self) -> impl Iterator<Item = &Name> {
151        self.extensions.iter().flat_map(|ext| ext.ext_types())
152    }
153
154    /// Get the extension function with the given name, from these extensions.
155    ///
156    /// Returns an error if the function is not defined by any extension
157    pub fn func(
158        &self,
159        name: &Name,
160    ) -> std::result::Result<&ExtensionFunction, ExtensionFunctionLookupError> {
161        self.functions.get(name).copied().ok_or_else(|| {
162            FuncDoesNotExistError {
163                name: name.clone(),
164                source_loc: name.loc().cloned(),
165            }
166            .into()
167        })
168    }
169
170    /// Iterate over all extension functions defined by all of these extensions.
171    ///
172    /// No guarantee that this list won't have duplicates or repeated names.
173    pub(crate) fn all_funcs(&self) -> impl Iterator<Item = &'a ExtensionFunction> {
174        self.extensions.iter().flat_map(|ext| ext.funcs())
175    }
176
177    /// Lookup a single-argument constructor by its return type and argument type.
178    ///
179    /// `None` means no constructor has that signature.
180    pub(crate) fn lookup_single_arg_constructor(
181        &self,
182        return_type: &SchemaType,
183    ) -> Option<&ExtensionFunction> {
184        self.single_arg_constructors.get(return_type).copied()
185    }
186}
187
188/// Errors occurring while initializing extensions. There are internal errors, so
189/// this enum should not become part of the public API unless we publicly expose
190/// user-defined extension function.
191#[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
192pub enum ExtensionInitializationError {
193    /// An extension function was defined by multiple extensions.
194    #[error(transparent)]
195    #[diagnostic(transparent)]
196    FuncMultiplyDefined(#[from] extension_initialization_errors::FuncMultiplyDefinedError),
197
198    /// Two extension constructors (in the same or different extensions) had
199    /// exactly the same type signature.  This is currently not allowed.
200    #[error(transparent)]
201    #[diagnostic(transparent)]
202    MultipleConstructorsSameSignature(
203        #[from] extension_initialization_errors::MultipleConstructorsSameSignatureError,
204    ),
205}
206
207/// Error subtypes for [`ExtensionInitializationError`]
208mod extension_initialization_errors {
209    use crate::{ast::Name, entities::SchemaType};
210    use miette::Diagnostic;
211    use thiserror::Error;
212
213    /// An extension function was defined by multiple extensions.
214    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
215    #[error("extension function `{name}` is defined multiple times")]
216    pub struct FuncMultiplyDefinedError {
217        /// Name of the function that was multiply defined
218        pub(crate) name: Name,
219    }
220
221    /// Two extension constructors (in the same or different extensions) exist
222    /// for one extension type.  This is currently not allowed.
223    #[derive(Diagnostic, Debug, PartialEq, Eq, Clone, Error)]
224    #[error("multiple extension constructors for the same extension type {return_type}")]
225    pub struct MultipleConstructorsSameSignatureError {
226        /// return type of the shared constructor signature
227        pub(crate) return_type: Box<SchemaType>,
228    }
229}
230
231/// Errors thrown when looking up an extension function in [`Extensions`].
232//
233// CAUTION: this type is publicly exported in `cedar-policy`.
234// Don't make fields `pub`, don't make breaking changes, and use caution
235// when adding public methods.
236#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
237pub enum ExtensionFunctionLookupError {
238    /// Tried to call a function that doesn't exist
239    #[error(transparent)]
240    #[diagnostic(transparent)]
241    FuncDoesNotExist(#[from] extension_function_lookup_errors::FuncDoesNotExistError),
242}
243
244impl ExtensionFunctionLookupError {
245    pub(crate) fn source_loc(&self) -> Option<&Loc> {
246        match self {
247            Self::FuncDoesNotExist(e) => e.source_loc.as_ref(),
248        }
249    }
250
251    pub(crate) fn with_maybe_source_loc(self, source_loc: Option<Loc>) -> Self {
252        match self {
253            Self::FuncDoesNotExist(e) => {
254                Self::FuncDoesNotExist(extension_function_lookup_errors::FuncDoesNotExistError {
255                    source_loc,
256                    ..e
257                })
258            }
259        }
260    }
261}
262
263/// Error subtypes for [`ExtensionFunctionLookupError`]
264pub mod extension_function_lookup_errors {
265    use crate::ast::Name;
266    use crate::parser::Loc;
267    use miette::Diagnostic;
268    use thiserror::Error;
269
270    /// Tried to call a function that doesn't exist
271    //
272    // CAUTION: this type is publicly exported in `cedar-policy`.
273    // Don't make fields `pub`, don't make breaking changes, and use caution
274    // when adding public methods.
275    #[derive(Debug, PartialEq, Eq, Clone, Error)]
276    #[error("extension function `{name}` does not exist")]
277    pub struct FuncDoesNotExistError {
278        /// Name of the function that doesn't exist
279        pub(crate) name: Name,
280        /// Source location
281        pub(crate) source_loc: Option<Loc>,
282    }
283
284    impl Diagnostic for FuncDoesNotExistError {
285        impl_diagnostic_from_source_loc_opt_field!(source_loc);
286    }
287}
288
289/// Type alias for convenience
290pub type Result<T> = std::result::Result<T, ExtensionFunctionLookupError>;
291
/// Utilities shared with the `cedar-policy-validator` extensions module.
pub mod util {
    use std::collections::{hash_map::Entry, HashMap};

    /// Build a `HashMap` from key-value pairs, rejecting duplicate keys.
    ///
    /// Accepts anything convertible into an iterator of `(K, V)` pairs: an
    /// `Iterator`, a `Vec`, an array, and so on.
    ///
    /// # Errors
    ///
    /// Returns `Err(key)` for the first key that appears a second time. The
    /// returned key is a clone of the key already stored in the map (the first
    /// occurrence), which compares equal to the duplicate. Pairs after the
    /// first duplicate are not consumed.
    pub fn collect_no_duplicates<K, V>(
        i: impl IntoIterator<Item = (K, V)>,
    ) -> std::result::Result<HashMap<K, V>, K>
    where
        K: Clone + std::hash::Hash + Eq,
    {
        let iter = i.into_iter();
        // Pre-size from the size hint's lower bound. It is exact for many
        // iterators and harmlessly 0 for others (e.g. `flat_map`).
        let mut map = HashMap::with_capacity(iter.size_hint().0);
        for (k, v) in iter {
            match map.entry(k) {
                Entry::Occupied(occupied) => return Err(occupied.key().clone()),
                Entry::Vacant(vacant) => {
                    vacant.insert(v);
                }
            }
        }
        Ok(map)
    }
}
319
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn no_common_extension_function_names() {
        // Our expr display must search for callstyle given a name, so
        // no names can be used for both callstyles.
        //
        // Test that names are all unique for ease of use.
        // This overconstrains our current requirements, but shouldn't change
        // until we identify a strong need.
        let mut seen = HashSet::new();
        for name in Extensions::all_available()
            .extensions
            .iter()
            .flat_map(|e| e.funcs().map(|f| f.name()))
        {
            // Report the offending name instead of just a length mismatch.
            assert!(
                seen.insert(name),
                "extension function name `{}` is defined more than once",
                name
            );
        }
    }
}