cedar_policy_core/ast/
extension.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::SchemaType;
19use crate::evaluator;
20use std::any::Any;
21use std::collections::{BTreeSet, HashMap};
22use std::fmt::Debug;
23use std::panic::{RefUnwindSafe, UnwindSafe};
24use std::sync::Arc;
25
/// Cedar extension.
///
/// An extension can define new types and functions on those types. (Currently,
/// there's nothing preventing an extension from defining new functions on
/// built-in types, either, although we haven't discussed whether we want to
/// allow this long-term.)
pub struct Extension {
    /// Name of the extension
    name: Name,
    /// Extension functions, keyed by each function's own name. These are legal
    /// to call in Cedar expressions.
    functions: HashMap<Name, ExtensionFunction>,
    /// Names of the extension types that support operator overloading
    types_with_operator_overloading: BTreeSet<Name>,
}
40
41impl Extension {
42    /// Create a new `Extension` with the given name and extension functions
43    pub fn new(
44        name: Name,
45        functions: impl IntoIterator<Item = ExtensionFunction>,
46        types_with_operator_overloading: impl IntoIterator<Item = Name>,
47    ) -> Self {
48        Self {
49            name,
50            functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
51            types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(),
52        }
53    }
54
55    /// Get the name of the extension
56    pub fn name(&self) -> &Name {
57        &self.name
58    }
59
60    /// Look up a function by name, or return `None` if the extension doesn't
61    /// provide a function with that name
62    pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
63        self.functions.get(name)
64    }
65
66    /// Iterate over the functions
67    pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
68        self.functions.values()
69    }
70
71    /// Iterate over the extension types that can be produced by any functions
72    /// in this extension
73    pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
74        self.funcs().flat_map(|func| func.ext_types())
75    }
76
77    /// Iterate over extension types with operator overloading
78    pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
79        self.types_with_operator_overloading.iter()
80    }
81}
82
83impl std::fmt::Debug for Extension {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(f, "<extension {}>", self.name())
86    }
87}
88
/// The output of an extension call: either a concrete value or an unknown.
///
/// `ExtensionFunction::call` turns `Known` into a `PartialValue::Value` and
/// `Unknown` into a residual `Expr::unknown`.
#[derive(Debug, Clone)]
pub enum ExtensionOutputValue {
    /// A concrete value from an extension call
    Known(Value),
    /// An unknown returned from an extension call
    Unknown(Unknown),
}
97
98impl<T> From<T> for ExtensionOutputValue
99where
100    T: Into<Value>,
101{
102    fn from(v: T) -> Self {
103        ExtensionOutputValue::Known(v.into())
104    }
105}
106
/// Which syntactic "style" a function call uses
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Function-style, e.g. `foo(a, b)`
    FunctionStyle,
    /// Method-style, e.g. `a.foo(b)`
    MethodStyle,
}
116
// Note: we could use currying to make this a little nicer

// Expands to the boxed trait-object type of an extension function whose
// parameters are the listed types. For example,
// `extension_function_object!(&Value)` is a boxed
// `Fn(&Value) -> evaluator::Result<ExtensionOutputValue>`. The closure must be
// `Sync + Send + 'static`.
macro_rules! extension_function_object {
    ($($tys:ty),*) => {
        Box<dyn Fn($($tys,)*) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>
    };
}
124
125/// Trait object that implements the extension function call accepting any number of arguments.
126pub type ExtensionFunctionObject = extension_function_object!(&[Value]);
127/// Trait object that implements the extension function call accepting exactly 0 arguments
128pub type NullaryExtensionFunctionObject = extension_function_object!();
129/// Trait object that implements the extension function call accepting exactly 1 arguments
130pub type UnaryExtensionFunctionObject = extension_function_object!(&Value);
131/// Trait object that implements the extension function call accepting exactly 2 arguments
132pub type BinaryExtensionFunctionObject = extension_function_object!(&Value, &Value);
133/// Trait object that implements the extension function call accepting exactly 3 arguments
134pub type TernaryExtensionFunctionObject = extension_function_object!(&Value, &Value, &Value);
135
136/// Extension function. These can be called by the given `name` in Ceder
137/// expressions.
138pub struct ExtensionFunction {
139    /// Name of the function
140    name: Name,
141    /// Which `CallStyle` should be used when calling this function
142    style: CallStyle,
143    /// The actual function, which takes an `&[Value]` and returns a `Value`,
144    /// or an evaluation error
145    func: ExtensionFunctionObject,
146    /// The return type of this function, as a `SchemaType`. We require that
147    /// this be constant -- any given extension function must always return a
148    /// value of this `SchemaType`.
149    ///
150    /// `return_type` is `None` if and only if this function represents an
151    /// "unknown" value for partial evaluation. Such a function may only return
152    /// a fully unknown residual and may never return a value.
153    return_type: Option<SchemaType>,
154    /// The argument types that this function expects, as `SchemaType`s.
155    arg_types: Vec<SchemaType>,
156}
157
158impl ExtensionFunction {
159    /// Create a new `ExtensionFunction` taking any number of arguments
160    fn new(
161        name: Name,
162        style: CallStyle,
163        func: ExtensionFunctionObject,
164        return_type: Option<SchemaType>,
165        arg_types: Vec<SchemaType>,
166    ) -> Self {
167        Self {
168            name,
169            func,
170            style,
171            return_type,
172            arg_types,
173        }
174    }
175
176    /// Create a new `ExtensionFunction` taking no arguments
177    pub fn nullary(
178        name: Name,
179        style: CallStyle,
180        func: NullaryExtensionFunctionObject,
181        return_type: SchemaType,
182    ) -> Self {
183        Self::new(
184            name.clone(),
185            style,
186            Box::new(move |args: &[Value]| {
187                if args.is_empty() {
188                    func()
189                } else {
190                    Err(evaluator::EvaluationError::wrong_num_arguments(
191                        name.clone(),
192                        0,
193                        args.len(),
194                        None, // evaluator will add the source location later
195                    ))
196                }
197            }),
198            Some(return_type),
199            vec![],
200        )
201    }
202
203    /// Create a new `ExtensionFunction` to represent a function which is an
204    /// "unknown" in partial evaluation. Please don't use this for anything else.
205    pub fn partial_eval_unknown(
206        name: Name,
207        style: CallStyle,
208        func: UnaryExtensionFunctionObject,
209        arg_type: SchemaType,
210    ) -> Self {
211        Self::new(
212            name.clone(),
213            style,
214            Box::new(move |args: &[Value]| match args.first() {
215                Some(arg) => func(arg),
216                None => Err(evaluator::EvaluationError::wrong_num_arguments(
217                    name.clone(),
218                    1,
219                    args.len(),
220                    None, // evaluator will add the source location later
221                )),
222            }),
223            None,
224            vec![arg_type],
225        )
226    }
227
228    /// Create a new `ExtensionFunction` taking one argument
229    #[allow(clippy::type_complexity)]
230    pub fn unary(
231        name: Name,
232        style: CallStyle,
233        func: UnaryExtensionFunctionObject,
234        return_type: SchemaType,
235        arg_type: SchemaType,
236    ) -> Self {
237        Self::new(
238            name.clone(),
239            style,
240            Box::new(move |args: &[Value]| match &args {
241                &[arg] => func(arg),
242                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
243                    name.clone(),
244                    1,
245                    args.len(),
246                    None, // evaluator will add the source location later
247                )),
248            }),
249            Some(return_type),
250            vec![arg_type],
251        )
252    }
253
254    /// Create a new `ExtensionFunction` taking two arguments
255    #[allow(clippy::type_complexity)]
256    pub fn binary(
257        name: Name,
258        style: CallStyle,
259        func: BinaryExtensionFunctionObject,
260        return_type: SchemaType,
261        arg_types: (SchemaType, SchemaType),
262    ) -> Self {
263        Self::new(
264            name.clone(),
265            style,
266            Box::new(move |args: &[Value]| match &args {
267                &[first, second] => func(first, second),
268                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
269                    name.clone(),
270                    2,
271                    args.len(),
272                    None, // evaluator will add the source location later
273                )),
274            }),
275            Some(return_type),
276            vec![arg_types.0, arg_types.1],
277        )
278    }
279
280    /// Create a new `ExtensionFunction` taking three arguments
281    #[allow(clippy::type_complexity)]
282    pub fn ternary(
283        name: Name,
284        style: CallStyle,
285        func: TernaryExtensionFunctionObject,
286        return_type: SchemaType,
287        arg_types: (SchemaType, SchemaType, SchemaType),
288    ) -> Self {
289        Self::new(
290            name.clone(),
291            style,
292            Box::new(move |args: &[Value]| match &args {
293                &[first, second, third] => func(first, second, third),
294                _ => Err(evaluator::EvaluationError::wrong_num_arguments(
295                    name.clone(),
296                    3,
297                    args.len(),
298                    None, // evaluator will add the source location later
299                )),
300            }),
301            Some(return_type),
302            vec![arg_types.0, arg_types.1, arg_types.2],
303        )
304    }
305
306    /// Get the `Name` of the `ExtensionFunction`
307    pub fn name(&self) -> &Name {
308        &self.name
309    }
310
311    /// Get the `CallStyle` of the `ExtensionFunction`
312    pub fn style(&self) -> CallStyle {
313        self.style
314    }
315
316    /// Get the return type of the `ExtensionFunction`
317    /// `None` is returned exactly when this function represents an "unknown"
318    /// for partial evaluation.
319    pub fn return_type(&self) -> Option<&SchemaType> {
320        self.return_type.as_ref()
321    }
322
323    /// Get the argument types of the `ExtensionFunction`.
324    pub fn arg_types(&self) -> &[SchemaType] {
325        &self.arg_types
326    }
327
328    /// Returns `true` if this function is considered a "constructor".
329    ///
330    /// Currently, the only impact of this is that non-constructors are not
331    /// accessible in the JSON format (entities/json.rs).
332    pub fn is_constructor(&self) -> bool {
333        // return type is an extension type
334        matches!(self.return_type(), Some(SchemaType::Extension { .. }))
335        // no argument is an extension type
336        && !self.arg_types().iter().any(|ty| matches!(ty, SchemaType::Extension { .. }))
337    }
338
339    /// Call the `ExtensionFunction` with the given args
340    pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
341        match (self.func)(args)? {
342            ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
343            ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
344        }
345    }
346
347    /// Iterate over the extension types that could be produced by this
348    /// function, if any
349    pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
350        self.return_type
351            .iter()
352            .flat_map(|ret_ty| ret_ty.contained_ext_types())
353    }
354}
355
356impl std::fmt::Debug for ExtensionFunction {
357    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358        write!(f, "<extension function {}>", self.name())
359    }
360}
361
/// Extension value.
///
/// Anything implementing this trait can be used as a first-class value in
/// Cedar. For instance, the `ipaddr` extension uses this mechanism
/// to implement IPAddr as a Cedar first-class value.
///
/// Types that also implement `Eq`, `Ord`, and `Clone` (and are `'static`)
/// automatically implement `InternalExtensionValue` via the blanket impl in
/// this module.
pub trait ExtensionValue: Debug + Send + Sync + UnwindSafe + RefUnwindSafe {
    /// Get the name of the type of this value.
    ///
    /// Cedar has nominal typing, so two values have the same type iff they
    /// return the same typename here.
    fn typename(&self) -> Name;

    /// Whether this value supports operator overloading
    fn supports_operator_overloading(&self) -> bool;
}
377
378impl<V: ExtensionValue> StaticallyTyped for V {
379    fn type_of(&self) -> Type {
380        Type::Extension {
381            name: self.typename(),
382        }
383    }
384}
385
#[derive(Debug, Clone)]
/// Object container for extension values.
///
/// An extension value must be representable by a [`RestrictedExpr`].
/// Specifically, it will be a function call `func` on `args`.
/// Note that `func` may not be the constructor. A counterexample is that a
/// `datetime` is represented by an `offset` method call.
/// Nevertheless, an invariant is that `eval(<func>(<args>)) == value`.
pub struct RepresentableExtensionValue {
    /// Name of the extension function whose call on `args` represents `value`
    pub(crate) func: Name,
    /// Arguments passed to `func`
    pub(crate) args: Vec<RestrictedExpr>,
    /// The underlying extension value
    pub(crate) value: Arc<dyn InternalExtensionValue>,
}
398
399impl RepresentableExtensionValue {
400    /// Create a new [`RepresentableExtensionValue`]
401    pub fn new(
402        value: Arc<dyn InternalExtensionValue + Send + Sync>,
403        func: Name,
404        args: Vec<RestrictedExpr>,
405    ) -> Self {
406        Self { value, func, args }
407    }
408
409    /// Get the internal value
410    pub fn value(&self) -> &(dyn InternalExtensionValue) {
411        self.value.as_ref()
412    }
413
414    /// Get the typename of this extension value
415    pub fn typename(&self) -> Name {
416        self.value.typename()
417    }
418
419    /// If this value supports operator overloading
420    pub(crate) fn supports_operator_overloading(&self) -> bool {
421        self.value.supports_operator_overloading()
422    }
423}
424
425impl From<RepresentableExtensionValue> for RestrictedExpr {
426    fn from(val: RepresentableExtensionValue) -> Self {
427        RestrictedExpr::call_extension_fn(val.func, val.args)
428    }
429}
430
431impl StaticallyTyped for RepresentableExtensionValue {
432    fn type_of(&self) -> Type {
433        self.value.type_of()
434    }
435}
436
437impl PartialEq for RepresentableExtensionValue {
438    fn eq(&self, other: &Self) -> bool {
439        // Values that are equal are equal regardless of which arguments made them
440        self.value.as_ref() == other.value.as_ref()
441    }
442}
443
444impl Eq for RepresentableExtensionValue {}
445
446impl PartialOrd for RepresentableExtensionValue {
447    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
448        Some(self.cmp(other))
449    }
450}
451
452impl Ord for RepresentableExtensionValue {
453    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
454        self.value.cmp(&other.value)
455    }
456}
457
/// Extensions provide a type implementing `ExtensionValue`, `Eq`, and `Ord`.
/// We automatically implement `InternalExtensionValue` for that type (with the
/// impl below).  Internally, we use `dyn InternalExtensionValue` instead of
/// `dyn ExtensionValue`.
///
/// You might wonder why we don't just have `ExtensionValue: Eq + Ord` and use
/// `dyn ExtensionValue` everywhere.  The answer is that the Rust compiler
/// doesn't let you because of
/// [object safety](https://doc.rust-lang.org/reference/items/traits.html#object-safety).
/// So instead we have this workaround where we define our own `equals_extvalue`
/// method that compares not against `&Self` but against `&dyn InternalExtensionValue`,
/// and likewise for `cmp_extvalue`.
pub trait InternalExtensionValue: ExtensionValue {
    /// convert to an `Any`, so the concrete type can be recovered by downcasting
    fn as_any(&self) -> &dyn Any;
    /// this will be the basis for `PartialEq` on `InternalExtensionValue`; but
    /// note the `&dyn` (normal `PartialEq` doesn't have the `dyn`)
    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
    /// this will be the basis for `Ord` on `InternalExtensionValue`; but note
    /// the `&dyn` (normal `Ord` doesn't have the `dyn`)
    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
}
480
481impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync + Clone> InternalExtensionValue for V {
482    fn as_any(&self) -> &dyn Any {
483        self
484    }
485
486    fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
487        other
488            .as_any()
489            .downcast_ref::<V>()
490            .map(|v| self == v)
491            .unwrap_or(false) // if the downcast failed, values are different types, so equality is false
492    }
493
494    fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
495        other
496            .as_any()
497            .downcast_ref::<V>()
498            .map(|v| self.cmp(v))
499            .unwrap_or_else(|| {
500                // downcast failed, so values are different types.
501                // we fall back on the total ordering on typenames.
502                self.typename().cmp(&other.typename())
503            })
504    }
505}
506
507impl PartialEq for dyn InternalExtensionValue {
508    fn eq(&self, other: &Self) -> bool {
509        self.equals_extvalue(other)
510    }
511}
512
513impl Eq for dyn InternalExtensionValue {}
514
515impl PartialOrd for dyn InternalExtensionValue {
516    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
517        Some(self.cmp(other))
518    }
519}
520
521impl Ord for dyn InternalExtensionValue {
522    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
523        self.cmp_extvalue(other)
524    }
525}
526
527impl StaticallyTyped for dyn InternalExtensionValue {
528    fn type_of(&self) -> Type {
529        Type::Extension {
530            name: self.typename(),
531        }
532    }
533}