cedar_policy_core/est/expr.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use super::FromJsonError;
18use crate::ast::{self, BoundedDisplay, EntityUID};
19use crate::entities::json::{
20    err::EscapeKind, err::JsonDeserializationError, err::JsonDeserializationErrorContext,
21    CedarValueJson, FnAndArg,
22};
23use crate::expr_builder::ExprBuilder;
24use crate::extensions::Extensions;
25use crate::jsonvalue::JsonValueWithNoDuplicateKeys;
26use crate::parser::cst_to_ast;
27use crate::parser::err::ParseErrors;
28use crate::parser::Node;
29use crate::parser::{cst, Loc};
30use itertools::Itertools;
31use serde::{de::Visitor, Deserialize, Serialize};
32use serde_with::serde_as;
33use smol_str::{SmolStr, ToSmolStr};
34use std::collections::{btree_map, BTreeMap, HashMap};
35use std::sync::Arc;
36
/// Serde JSON structure for a Cedar expression in the EST format
///
/// Serialized `untagged`, i.e. exactly as the JSON of whichever variant is
/// present. `Deserialize` is implemented by hand rather than derived: an object
/// whose single key is a known extension function/method name is parsed as an
/// `ExtFuncCall`, and any other single-key object as an `ExprNoExt`, with no
/// backtracking between the two.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum Expr {
    /// Any Cedar expression other than an extension function call.
    ExprNoExt(ExprNoExt),
    /// Extension function call, where the (single) JSON key is the name of an
    /// extension function or method.
    ExtFuncCall(ExtFuncCall),
}
49
50// Manual implementation of `Deserialize` is more efficient than the derived
51// implementation with `serde(untagged)`. In particular, if the key is valid for
52// `ExprNoExt` but there is a deserialization problem within the corresponding
53// value, the derived implementation would backtrack and try to deserialize as
54// `ExtFuncCall` with that key as the extension function name, but this manual
55// implementation instead eagerly errors out, taking advantage of the fact that
56// none of the keys for `ExprNoExt` are valid extension function names.
57//
58// See #1284.
59impl<'de> Deserialize<'de> for Expr {
60    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
61    where
62        D: serde::Deserializer<'de>,
63    {
64        struct ExprVisitor;
65        impl<'de> Visitor<'de> for ExprVisitor {
66            type Value = Expr;
67            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68                formatter.write_str("JSON object representing an expression")
69            }
70
71            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
72            where
73                A: serde::de::MapAccess<'de>,
74            {
75                let (k, v): (SmolStr, JsonValueWithNoDuplicateKeys) = match map.next_entry()? {
76                    None => {
77                        return Err(serde::de::Error::custom(
78                            "empty map is not a valid expression",
79                        ))
80                    }
81                    Some((k, v)) => (k, v),
82                };
83                match map.next_key()? {
84                    None => (),
85                    Some(k2) => {
86                        let k2: SmolStr = k2;
87                        return Err(serde::de::Error::custom(format!("JSON object representing an `Expr` should have only one key, but found two keys: `{k}` and `{k2}`")));
88                    }
89                };
90                if cst_to_ast::is_known_extension_func_str(&k) {
91                    // `k` is the name of an extension function or method. We assume that
92                    // no such keys are valid keys for `ExprNoExt`, so we must parse as an
93                    // `ExtFuncCall`.
94                    let obj = serde_json::json!({ k: v });
95                    let extfunccall =
96                        serde_json::from_value(obj).map_err(serde::de::Error::custom)?;
97                    Ok(Expr::ExtFuncCall(extfunccall))
98                } else {
99                    // not a valid extension function or method, so we expect it
100                    // to work for `ExprNoExt`.
101                    let obj = serde_json::json!({ k: v });
102                    let exprnoext =
103                        serde_json::from_value(obj).map_err(serde::de::Error::custom)?;
104                    Ok(Expr::ExprNoExt(exprnoext))
105                }
106            }
107        }
108
109        deserializer.deserialize_map(ExprVisitor)
110    }
111}
112
/// One element of a `like` pattern literal in the EST format.
///
/// When converted to an AST pattern, a `Literal` is expanded into one
/// character element per `char` of its string.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum PatternElem {
    /// The wildcard asterisk
    Wildcard,
    /// A string without any wildcards
    Literal(SmolStr),
}
123
124impl From<&[PatternElem]> for crate::ast::Pattern {
125    fn from(value: &[PatternElem]) -> Self {
126        let mut elems = Vec::new();
127        for elem in value {
128            match elem {
129                PatternElem::Wildcard => {
130                    elems.push(crate::ast::PatternElem::Wildcard);
131                }
132                PatternElem::Literal(s) => {
133                    elems.extend(s.chars().map(crate::ast::PatternElem::Char));
134                }
135            }
136        }
137        Self::from(elems)
138    }
139}
140
141impl From<crate::ast::PatternElem> for PatternElem {
142    fn from(value: crate::ast::PatternElem) -> Self {
143        match value {
144            crate::ast::PatternElem::Wildcard => Self::Wildcard,
145            crate::ast::PatternElem::Char(c) => Self::Literal(c.to_smolstr()),
146        }
147    }
148}
149
150impl From<crate::ast::Pattern> for Vec<PatternElem> {
151    fn from(value: crate::ast::Pattern) -> Self {
152        value.iter().map(|elem| (*elem).into()).collect()
153    }
154}
155
/// Serde JSON structure for any Cedar expression other than an extension
/// function call, in the EST format.
///
/// Uses serde's default (externally tagged) enum representation: each variant
/// is a single-key JSON object whose key is the variant name, or the
/// `#[serde(rename = ...)]` value where one is given. Unknown fields inside
/// struct-like variants are rejected (`deny_unknown_fields`).
#[serde_as]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub enum ExprNoExt {
    /// Literal value (including anything that's legal to express in the
    /// attribute-value JSON format)
    Value(CedarValueJson),
    /// Variable: `principal`, `action`, `resource`, or `context`
    Var(ast::Var),
    /// Template slot (exposed to TypeScript as a plain string)
    Slot(#[cfg_attr(feature = "wasm", tsify(type = "string"))] ast::SlotId),
    /// Logical negation `!`
    #[serde(rename = "!")]
    Not {
        /// Argument
        arg: Arc<Expr>,
    },
    /// Unary `-` (arithmetic negation); JSON key `neg`
    #[serde(rename = "neg")]
    Neg {
        /// Argument
        arg: Arc<Expr>,
    },
    /// `==`
    #[serde(rename = "==")]
    Eq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `!=`
    #[serde(rename = "!=")]
    NotEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `in`
    #[serde(rename = "in")]
    In {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `<`
    #[serde(rename = "<")]
    Less {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `<=`
    #[serde(rename = "<=")]
    LessEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `>`
    #[serde(rename = ">")]
    Greater {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `>=`
    #[serde(rename = ">=")]
    GreaterEq {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `&&`
    #[serde(rename = "&&")]
    And {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `||`
    #[serde(rename = "||")]
    Or {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `+`
    #[serde(rename = "+")]
    Add {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// Binary `-` (subtraction); unary negation is `Neg`
    #[serde(rename = "-")]
    Sub {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `*`
    #[serde(rename = "*")]
    Mul {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Right-hand argument
        right: Arc<Expr>,
    },
    /// `contains()`
    #[serde(rename = "contains")]
    Contains {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `containsAll()`
    #[serde(rename = "containsAll")]
    ContainsAll {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `containsAny()`
    #[serde(rename = "containsAny")]
    ContainsAny {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `isEmpty()`
    #[serde(rename = "isEmpty")]
    IsEmpty {
        /// Argument (receiver)
        arg: Arc<Expr>,
    },
    /// `getTag()`
    #[serde(rename = "getTag")]
    GetTag {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// `hasTag()`
    #[serde(rename = "hasTag")]
    HasTag {
        /// Left-hand argument (receiver)
        left: Arc<Expr>,
        /// Right-hand argument (inside the `()`)
        right: Arc<Expr>,
    },
    /// Get-attribute, `left.attr`; JSON key `.`
    #[serde(rename = ".")]
    GetAttr {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Attribute name
        attr: SmolStr,
    },
    /// `left has attr`
    #[serde(rename = "has")]
    HasAttr {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Attribute name
        attr: SmolStr,
    },
    /// `left like pattern`
    #[serde(rename = "like")]
    Like {
        /// Left-hand argument
        left: Arc<Expr>,
        /// Pattern
        pattern: Vec<PatternElem>,
    },
    /// `<entity> is <entity_type>`, optionally followed by
    /// `in <entity_or_entity_set>`
    #[serde(rename = "is")]
    Is {
        /// Left-hand entity argument
        left: Arc<Expr>,
        /// Entity type
        entity_type: SmolStr,
        /// Optional entity or entity set; JSON key `in`, omitted from the
        /// serialized output when `None`
        #[serde(skip_serializing_if = "Option::is_none")]
        #[serde(rename = "in")]
        in_expr: Option<Arc<Expr>>,
    },
    /// Ternary `if … then … else …`; JSON key `if-then-else`
    #[serde(rename = "if-then-else")]
    If {
        /// Condition (JSON key `if`)
        #[serde(rename = "if")]
        cond_expr: Arc<Expr>,
        /// `then` expression (JSON key `then`)
        #[serde(rename = "then")]
        then_expr: Arc<Expr>,
        /// `else` expression (JSON key `else`)
        #[serde(rename = "else")]
        else_expr: Arc<Expr>,
    },
    /// Set literal, whose elements may be arbitrary expressions
    /// (which is why we need this case specifically and can't just
    /// use `ExprNoExt::Value`)
    Set(Vec<Expr>),
    /// Record literal, whose elements may be arbitrary expressions
    /// (which is why we need this case specifically and can't just
    /// use `ExprNoExt::Value`). Duplicate keys are rejected during
    /// deserialization (`MapPreventDuplicates`).
    Record(
        #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
        #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Expr>"))]
        BTreeMap<SmolStr, Expr>,
    ),
}
387
/// Serde JSON structure for an extension function call in the EST format
///
/// Because `call` is `#[serde(flatten)]`ed, this serializes as the map itself:
/// a JSON object mapping the function name to its argument array, i.e.
/// `{ "<function name>": [<arg>, ...] }`.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
pub struct ExtFuncCall {
    /// maps the name of the function to a JSON list/array of the arguments.
    /// Note that for method calls, the method receiver is the first argument.
    /// For example, for `a.isInRange(b)`, the first argument is `a` and the
    /// second argument is `b`.
    ///
    /// INVARIANT: This map should always have exactly one k-v pair (not more or
    /// less), but we make it a map in order to get the correct JSON structure
    /// we want.
    ///
    /// NOTE(review): the derived `Deserialize` for this struct does not itself
    /// enforce the one-entry invariant (it only rejects duplicate keys via
    /// `MapPreventDuplicates`); `Expr`'s hand-written deserializer rejects
    /// empty and multi-key objects before delegating here.
    #[serde(flatten)]
    #[serde_as(as = "serde_with::MapPreventDuplicates<_,_>")]
    #[cfg_attr(feature = "wasm", tsify(type = "Record<string, Array<Expr>>"))]
    call: HashMap<SmolStr, Vec<Expr>>,
}
407
/// Construct an [`Expr`].
///
/// Zero-sized implementation of [`ExprBuilder`] that produces EST expressions.
/// It carries no builder data (`Data = ()`) and does not retain source
/// locations.
#[derive(Clone, Debug)]
pub struct Builder;
411
412impl ExprBuilder for Builder {
413    type Expr = Expr;
414
415    type Data = ();
416
417    fn with_data(_data: Self::Data) -> Self {
418        Self
419    }
420
421    fn with_maybe_source_loc(self, _: Option<&Loc>) -> Self {
422        self
423    }
424
425    fn loc(&self) -> Option<&Loc> {
426        None
427    }
428
429    fn data(&self) -> &Self::Data {
430        &()
431    }
432
433    /// literal
434    fn val(self, lit: impl Into<ast::Literal>) -> Expr {
435        Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::from_lit(lit.into())))
436    }
437
438    /// principal, action, resource, context
439    fn var(self, var: ast::Var) -> Expr {
440        Expr::ExprNoExt(ExprNoExt::Var(var))
441    }
442
443    /// Template slots
444    fn slot(self, slot: ast::SlotId) -> Expr {
445        Expr::ExprNoExt(ExprNoExt::Slot(slot))
446    }
447
448    /// An extension call with one arg, which is the name of the unknown
449    fn unknown(self, u: ast::Unknown) -> Expr {
450        Expr::ExtFuncCall(ExtFuncCall {
451            call: HashMap::from([("unknown".to_smolstr(), vec![Builder::new().val(u.name)])]),
452        })
453    }
454
455    /// `!`
456    fn not(self, e: Expr) -> Expr {
457        Expr::ExprNoExt(ExprNoExt::Not { arg: Arc::new(e) })
458    }
459
460    /// `-`
461    fn neg(self, e: Expr) -> Expr {
462        Expr::ExprNoExt(ExprNoExt::Neg { arg: Arc::new(e) })
463    }
464
465    /// `==`
466    fn is_eq(self, left: Expr, right: Expr) -> Expr {
467        Expr::ExprNoExt(ExprNoExt::Eq {
468            left: Arc::new(left),
469            right: Arc::new(right),
470        })
471    }
472
473    /// `!=`
474    fn noteq(self, left: Expr, right: Expr) -> Expr {
475        Expr::ExprNoExt(ExprNoExt::NotEq {
476            left: Arc::new(left),
477            right: Arc::new(right),
478        })
479    }
480
481    /// `in`
482    fn is_in(self, left: Expr, right: Expr) -> Expr {
483        Expr::ExprNoExt(ExprNoExt::In {
484            left: Arc::new(left),
485            right: Arc::new(right),
486        })
487    }
488
489    /// `<`
490    fn less(self, left: Expr, right: Expr) -> Expr {
491        Expr::ExprNoExt(ExprNoExt::Less {
492            left: Arc::new(left),
493            right: Arc::new(right),
494        })
495    }
496
497    /// `<=`
498    fn lesseq(self, left: Expr, right: Expr) -> Expr {
499        Expr::ExprNoExt(ExprNoExt::LessEq {
500            left: Arc::new(left),
501            right: Arc::new(right),
502        })
503    }
504
505    /// `>`
506    fn greater(self, left: Expr, right: Expr) -> Expr {
507        Expr::ExprNoExt(ExprNoExt::Greater {
508            left: Arc::new(left),
509            right: Arc::new(right),
510        })
511    }
512
513    /// `>=`
514    fn greatereq(self, left: Expr, right: Expr) -> Expr {
515        Expr::ExprNoExt(ExprNoExt::GreaterEq {
516            left: Arc::new(left),
517            right: Arc::new(right),
518        })
519    }
520
521    /// `&&`
522    fn and(self, left: Expr, right: Expr) -> Expr {
523        Expr::ExprNoExt(ExprNoExt::And {
524            left: Arc::new(left),
525            right: Arc::new(right),
526        })
527    }
528
529    /// `||`
530    fn or(self, left: Expr, right: Expr) -> Expr {
531        Expr::ExprNoExt(ExprNoExt::Or {
532            left: Arc::new(left),
533            right: Arc::new(right),
534        })
535    }
536
537    /// `+`
538    fn add(self, left: Expr, right: Expr) -> Expr {
539        Expr::ExprNoExt(ExprNoExt::Add {
540            left: Arc::new(left),
541            right: Arc::new(right),
542        })
543    }
544
545    /// `-`
546    fn sub(self, left: Expr, right: Expr) -> Expr {
547        Expr::ExprNoExt(ExprNoExt::Sub {
548            left: Arc::new(left),
549            right: Arc::new(right),
550        })
551    }
552
553    /// `*`
554    fn mul(self, left: Expr, right: Expr) -> Expr {
555        Expr::ExprNoExt(ExprNoExt::Mul {
556            left: Arc::new(left),
557            right: Arc::new(right),
558        })
559    }
560
561    /// `left.contains(right)`
562    fn contains(self, left: Expr, right: Expr) -> Expr {
563        Expr::ExprNoExt(ExprNoExt::Contains {
564            left: Arc::new(left),
565            right: Arc::new(right),
566        })
567    }
568
569    /// `left.containsAll(right)`
570    fn contains_all(self, left: Expr, right: Expr) -> Expr {
571        Expr::ExprNoExt(ExprNoExt::ContainsAll {
572            left: Arc::new(left),
573            right: Arc::new(right),
574        })
575    }
576
577    /// `left.containsAny(right)`
578    fn contains_any(self, left: Expr, right: Expr) -> Expr {
579        Expr::ExprNoExt(ExprNoExt::ContainsAny {
580            left: Arc::new(left),
581            right: Arc::new(right),
582        })
583    }
584
585    /// `arg.isEmpty()`
586    fn is_empty(self, expr: Expr) -> Expr {
587        Expr::ExprNoExt(ExprNoExt::IsEmpty {
588            arg: Arc::new(expr),
589        })
590    }
591
592    /// `left.getTag(right)`
593    fn get_tag(self, expr: Expr, tag: Expr) -> Expr {
594        Expr::ExprNoExt(ExprNoExt::GetTag {
595            left: Arc::new(expr),
596            right: Arc::new(tag),
597        })
598    }
599
600    /// `left.hasTag(right)`
601    fn has_tag(self, expr: Expr, tag: Expr) -> Expr {
602        Expr::ExprNoExt(ExprNoExt::HasTag {
603            left: Arc::new(expr),
604            right: Arc::new(tag),
605        })
606    }
607
608    /// `left.attr`
609    fn get_attr(self, expr: Expr, attr: SmolStr) -> Expr {
610        Expr::ExprNoExt(ExprNoExt::GetAttr {
611            left: Arc::new(expr),
612            attr,
613        })
614    }
615
616    /// `left has attr`
617    fn has_attr(self, expr: Expr, attr: SmolStr) -> Expr {
618        Expr::ExprNoExt(ExprNoExt::HasAttr {
619            left: Arc::new(expr),
620            attr,
621        })
622    }
623
624    /// `left like pattern`
625    fn like(self, expr: Expr, pattern: ast::Pattern) -> Expr {
626        Expr::ExprNoExt(ExprNoExt::Like {
627            left: Arc::new(expr),
628            pattern: pattern.into(),
629        })
630    }
631
632    /// `left is entity_type`
633    fn is_entity_type(self, left: Expr, entity_type: ast::EntityType) -> Expr {
634        Expr::ExprNoExt(ExprNoExt::Is {
635            left: Arc::new(left),
636            entity_type: entity_type.to_smolstr(),
637            in_expr: None,
638        })
639    }
640
641    /// `left is entity_type in entity`
642    fn is_in_entity_type(self, left: Expr, entity_type: ast::EntityType, entity: Expr) -> Expr {
643        Expr::ExprNoExt(ExprNoExt::Is {
644            left: Arc::new(left),
645            entity_type: entity_type.to_smolstr(),
646            in_expr: Some(Arc::new(entity)),
647        })
648    }
649
650    /// `if cond_expr then then_expr else else_expr`
651    fn ite(self, cond_expr: Expr, then_expr: Expr, else_expr: Expr) -> Expr {
652        Expr::ExprNoExt(ExprNoExt::If {
653            cond_expr: Arc::new(cond_expr),
654            then_expr: Arc::new(then_expr),
655            else_expr: Arc::new(else_expr),
656        })
657    }
658
659    /// e.g. [1+2, !(context has department)]
660    fn set(self, elements: impl IntoIterator<Item = Expr>) -> Expr {
661        Expr::ExprNoExt(ExprNoExt::Set(elements.into_iter().collect()))
662    }
663
664    /// e.g. {foo: 1+2, bar: !(context has department)}
665    fn record(
666        self,
667        map: impl IntoIterator<Item = (SmolStr, Expr)>,
668    ) -> Result<Expr, ast::ExpressionConstructionError> {
669        let mut dedup_map = BTreeMap::new();
670        for (k, v) in map {
671            match dedup_map.entry(k) {
672                btree_map::Entry::Occupied(oentry) => {
673                    return Err(ast::expression_construction_errors::DuplicateKeyError {
674                        key: oentry.key().clone(),
675                        context: "in record literal",
676                    }
677                    .into());
678                }
679                btree_map::Entry::Vacant(ventry) => {
680                    ventry.insert(v);
681                }
682            }
683        }
684        Ok(Expr::ExprNoExt(ExprNoExt::Record(dedup_map)))
685    }
686
687    /// extension function call, including method calls
688    fn call_extension_fn(self, fn_name: ast::Name, args: impl IntoIterator<Item = Expr>) -> Expr {
689        Expr::ExtFuncCall(ExtFuncCall {
690            call: HashMap::from([(fn_name.to_smolstr(), args.into_iter().collect())]),
691        })
692    }
693}
694
695impl Expr {
696    /// Consume the `Expr`, producing a string literal if it was a string literal, otherwise returns the literal in the `Err` variant.
697    pub fn into_string_literal(self) -> Result<SmolStr, Self> {
698        match self {
699            Expr::ExprNoExt(ExprNoExt::Value(CedarValueJson::String(s))) => Ok(s),
700            _ => Err(self),
701        }
702    }
703
704    /// Substitute entity literals
705    pub fn sub_entity_literals(
706        self,
707        mapping: &BTreeMap<EntityUID, EntityUID>,
708    ) -> Result<Self, JsonDeserializationError> {
709        match self.clone() {
710            Expr::ExprNoExt(e) => match e {
711                ExprNoExt::Value(v) => Ok(Expr::ExprNoExt(ExprNoExt::Value(
712                    v.sub_entity_literals(mapping)?,
713                ))),
714                ExprNoExt::Var(_) => Ok(self),
715                ExprNoExt::Slot(_) => Ok(self),
716                ExprNoExt::Not { arg } => Ok(Expr::ExprNoExt(ExprNoExt::Not {
717                    arg: Arc::new(Arc::unwrap_or_clone(arg).sub_entity_literals(mapping)?),
718                })),
719                ExprNoExt::Neg { arg } => Ok(Expr::ExprNoExt(ExprNoExt::Neg {
720                    arg: Arc::new(Arc::unwrap_or_clone(arg).sub_entity_literals(mapping)?),
721                })),
722                ExprNoExt::Eq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Eq {
723                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
724                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
725                })),
726                ExprNoExt::NotEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::NotEq {
727                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
728                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
729                })),
730                ExprNoExt::In { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::In {
731                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
732                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
733                })),
734                ExprNoExt::Less { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Less {
735                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
736                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
737                })),
738                ExprNoExt::LessEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::LessEq {
739                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
740                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
741                })),
742                ExprNoExt::Greater { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Greater {
743                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
744                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
745                })),
746                ExprNoExt::GreaterEq { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::GreaterEq {
747                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
748                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
749                })),
750                ExprNoExt::And { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::And {
751                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
752                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
753                })),
754                ExprNoExt::Or { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Or {
755                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
756                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
757                })),
758                ExprNoExt::Add { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Add {
759                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
760                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
761                })),
762                ExprNoExt::Sub { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Sub {
763                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
764                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
765                })),
766                ExprNoExt::Mul { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Mul {
767                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
768                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
769                })),
770                ExprNoExt::Contains { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::Contains {
771                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
772                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
773                })),
774                ExprNoExt::ContainsAll { left, right } => {
775                    Ok(Expr::ExprNoExt(ExprNoExt::ContainsAll {
776                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
777                        right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
778                    }))
779                }
780                ExprNoExt::ContainsAny { left, right } => {
781                    Ok(Expr::ExprNoExt(ExprNoExt::ContainsAny {
782                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
783                        right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
784                    }))
785                }
786                ExprNoExt::IsEmpty { arg } => Ok(Expr::ExprNoExt(ExprNoExt::IsEmpty {
787                    arg: Arc::new(Arc::unwrap_or_clone(arg).sub_entity_literals(mapping)?),
788                })),
789                ExprNoExt::GetTag { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::GetTag {
790                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
791                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
792                })),
793                ExprNoExt::HasTag { left, right } => Ok(Expr::ExprNoExt(ExprNoExt::HasTag {
794                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
795                    right: Arc::new(Arc::unwrap_or_clone(right).sub_entity_literals(mapping)?),
796                })),
797                ExprNoExt::GetAttr { left, attr } => Ok(Expr::ExprNoExt(ExprNoExt::GetAttr {
798                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
799                    attr,
800                })),
801                ExprNoExt::HasAttr { left, attr } => Ok(Expr::ExprNoExt(ExprNoExt::HasAttr {
802                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
803                    attr,
804                })),
805                ExprNoExt::Like { left, pattern } => Ok(Expr::ExprNoExt(ExprNoExt::Like {
806                    left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
807                    pattern,
808                })),
809                ExprNoExt::Is {
810                    left,
811                    entity_type,
812                    in_expr,
813                } => match in_expr {
814                    Some(in_expr) => Ok(Expr::ExprNoExt(ExprNoExt::Is {
815                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
816                        entity_type,
817                        in_expr: Some(Arc::new(
818                            Arc::unwrap_or_clone(in_expr).sub_entity_literals(mapping)?,
819                        )),
820                    })),
821                    None => Ok(Expr::ExprNoExt(ExprNoExt::Is {
822                        left: Arc::new(Arc::unwrap_or_clone(left).sub_entity_literals(mapping)?),
823                        entity_type,
824                        in_expr: None,
825                    })),
826                },
827                ExprNoExt::If {
828                    cond_expr,
829                    then_expr,
830                    else_expr,
831                } => Ok(Expr::ExprNoExt(ExprNoExt::If {
832                    cond_expr: Arc::new(
833                        Arc::unwrap_or_clone(cond_expr).sub_entity_literals(mapping)?,
834                    ),
835                    then_expr: Arc::new(
836                        Arc::unwrap_or_clone(then_expr).sub_entity_literals(mapping)?,
837                    ),
838                    else_expr: Arc::new(
839                        Arc::unwrap_or_clone(else_expr).sub_entity_literals(mapping)?,
840                    ),
841                })),
842                ExprNoExt::Set(v) => {
843                    let mut new_v = vec![];
844                    for e in v {
845                        new_v.push(e.sub_entity_literals(mapping)?);
846                    }
847                    Ok(Expr::ExprNoExt(ExprNoExt::Set(new_v)))
848                }
849                ExprNoExt::Record(m) => {
850                    let mut new_m = BTreeMap::new();
851                    for (k, v) in m {
852                        new_m.insert(k, v.sub_entity_literals(mapping)?);
853                    }
854                    Ok(Expr::ExprNoExt(ExprNoExt::Record(new_m)))
855                }
856            },
857            Expr::ExtFuncCall(e_fn_call) => {
858                let mut new_m = HashMap::new();
859                for (k, v) in e_fn_call.call {
860                    let mut new_v = vec![];
861                    for e in v {
862                        new_v.push(e.sub_entity_literals(mapping)?);
863                    }
864                    new_m.insert(k, new_v);
865                }
866                Ok(Expr::ExtFuncCall(ExtFuncCall { call: new_m }))
867            }
868        }
869    }
870}
871
872impl Expr {
873    /// Attempt to convert this `est::Expr` into an `ast::Expr`
874    ///
875    /// `id`: the ID of the policy this `Expr` belongs to, used only for reporting errors
876    pub fn try_into_ast(self, id: ast::PolicyID) -> Result<ast::Expr, FromJsonError> {
877        match self {
878            Expr::ExprNoExt(ExprNoExt::Value(jsonvalue)) => jsonvalue
879                .into_expr(|| JsonDeserializationErrorContext::Policy { id: id.clone() })
880                .map(Into::into)
881                .map_err(Into::into),
882            Expr::ExprNoExt(ExprNoExt::Var(var)) => Ok(ast::Expr::var(var)),
883            Expr::ExprNoExt(ExprNoExt::Slot(slot)) => Ok(ast::Expr::slot(slot)),
884            Expr::ExprNoExt(ExprNoExt::Not { arg }) => {
885                Ok(ast::Expr::not(Arc::unwrap_or_clone(arg).try_into_ast(id)?))
886            }
887            Expr::ExprNoExt(ExprNoExt::Neg { arg }) => {
888                Ok(ast::Expr::neg(Arc::unwrap_or_clone(arg).try_into_ast(id)?))
889            }
890            Expr::ExprNoExt(ExprNoExt::Eq { left, right }) => Ok(ast::Expr::is_eq(
891                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
892                Arc::unwrap_or_clone(right).try_into_ast(id)?,
893            )),
894            Expr::ExprNoExt(ExprNoExt::NotEq { left, right }) => Ok(ast::Expr::noteq(
895                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
896                Arc::unwrap_or_clone(right).try_into_ast(id)?,
897            )),
898            Expr::ExprNoExt(ExprNoExt::In { left, right }) => Ok(ast::Expr::is_in(
899                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
900                Arc::unwrap_or_clone(right).try_into_ast(id)?,
901            )),
902            Expr::ExprNoExt(ExprNoExt::Less { left, right }) => Ok(ast::Expr::less(
903                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
904                Arc::unwrap_or_clone(right).try_into_ast(id)?,
905            )),
906            Expr::ExprNoExt(ExprNoExt::LessEq { left, right }) => Ok(ast::Expr::lesseq(
907                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
908                Arc::unwrap_or_clone(right).try_into_ast(id)?,
909            )),
910            Expr::ExprNoExt(ExprNoExt::Greater { left, right }) => Ok(ast::Expr::greater(
911                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
912                Arc::unwrap_or_clone(right).try_into_ast(id)?,
913            )),
914            Expr::ExprNoExt(ExprNoExt::GreaterEq { left, right }) => Ok(ast::Expr::greatereq(
915                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
916                Arc::unwrap_or_clone(right).try_into_ast(id)?,
917            )),
918            Expr::ExprNoExt(ExprNoExt::And { left, right }) => Ok(ast::Expr::and(
919                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
920                Arc::unwrap_or_clone(right).try_into_ast(id)?,
921            )),
922            Expr::ExprNoExt(ExprNoExt::Or { left, right }) => Ok(ast::Expr::or(
923                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
924                Arc::unwrap_or_clone(right).try_into_ast(id)?,
925            )),
926            Expr::ExprNoExt(ExprNoExt::Add { left, right }) => Ok(ast::Expr::add(
927                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
928                Arc::unwrap_or_clone(right).try_into_ast(id)?,
929            )),
930            Expr::ExprNoExt(ExprNoExt::Sub { left, right }) => Ok(ast::Expr::sub(
931                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
932                Arc::unwrap_or_clone(right).try_into_ast(id)?,
933            )),
934            Expr::ExprNoExt(ExprNoExt::Mul { left, right }) => Ok(ast::Expr::mul(
935                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
936                Arc::unwrap_or_clone(right).try_into_ast(id)?,
937            )),
938            Expr::ExprNoExt(ExprNoExt::Contains { left, right }) => Ok(ast::Expr::contains(
939                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
940                Arc::unwrap_or_clone(right).try_into_ast(id)?,
941            )),
942            Expr::ExprNoExt(ExprNoExt::ContainsAll { left, right }) => Ok(ast::Expr::contains_all(
943                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
944                Arc::unwrap_or_clone(right).try_into_ast(id)?,
945            )),
946            Expr::ExprNoExt(ExprNoExt::ContainsAny { left, right }) => Ok(ast::Expr::contains_any(
947                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
948                Arc::unwrap_or_clone(right).try_into_ast(id)?,
949            )),
950            Expr::ExprNoExt(ExprNoExt::IsEmpty { arg }) => Ok(ast::Expr::is_empty(
951                Arc::unwrap_or_clone(arg).try_into_ast(id)?,
952            )),
953            Expr::ExprNoExt(ExprNoExt::GetTag { left, right }) => Ok(ast::Expr::get_tag(
954                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
955                Arc::unwrap_or_clone(right).try_into_ast(id)?,
956            )),
957            Expr::ExprNoExt(ExprNoExt::HasTag { left, right }) => Ok(ast::Expr::has_tag(
958                Arc::unwrap_or_clone(left).try_into_ast(id.clone())?,
959                Arc::unwrap_or_clone(right).try_into_ast(id)?,
960            )),
961            Expr::ExprNoExt(ExprNoExt::GetAttr { left, attr }) => Ok(ast::Expr::get_attr(
962                Arc::unwrap_or_clone(left).try_into_ast(id)?,
963                attr,
964            )),
965            Expr::ExprNoExt(ExprNoExt::HasAttr { left, attr }) => Ok(ast::Expr::has_attr(
966                Arc::unwrap_or_clone(left).try_into_ast(id)?,
967                attr,
968            )),
969            Expr::ExprNoExt(ExprNoExt::Like { left, pattern }) => Ok(ast::Expr::like(
970                Arc::unwrap_or_clone(left).try_into_ast(id)?,
971                crate::ast::Pattern::from(pattern.as_slice()),
972            )),
973            Expr::ExprNoExt(ExprNoExt::Is {
974                left,
975                entity_type,
976                in_expr,
977            }) => ast::EntityType::from_normalized_str(entity_type.as_str())
978                .map_err(FromJsonError::InvalidEntityType)
979                .and_then(|entity_type_name| {
980                    let left: ast::Expr = Arc::unwrap_or_clone(left).try_into_ast(id.clone())?;
981                    let is_expr = ast::Expr::is_entity_type(left.clone(), entity_type_name);
982                    match in_expr {
983                        // The AST doesn't have an `... is ... in ..` node, so
984                        // we represent it as a conjunction of `is` and `in`.
985                        Some(in_expr) => Ok(ast::Expr::and(
986                            is_expr,
987                            ast::Expr::is_in(left, Arc::unwrap_or_clone(in_expr).try_into_ast(id)?),
988                        )),
989                        None => Ok(is_expr),
990                    }
991                }),
992            Expr::ExprNoExt(ExprNoExt::If {
993                cond_expr,
994                then_expr,
995                else_expr,
996            }) => Ok(ast::Expr::ite(
997                Arc::unwrap_or_clone(cond_expr).try_into_ast(id.clone())?,
998                Arc::unwrap_or_clone(then_expr).try_into_ast(id.clone())?,
999                Arc::unwrap_or_clone(else_expr).try_into_ast(id)?,
1000            )),
1001            Expr::ExprNoExt(ExprNoExt::Set(elements)) => Ok(ast::Expr::set(
1002                elements
1003                    .into_iter()
1004                    .map(|el| el.try_into_ast(id.clone()))
1005                    .collect::<Result<Vec<_>, FromJsonError>>()?,
1006            )),
1007            Expr::ExprNoExt(ExprNoExt::Record(map)) => {
1008                // PANIC SAFETY: can't have duplicate keys here because the input was already a HashMap
1009                #[allow(clippy::expect_used)]
1010                Ok(ast::Expr::record(
1011                    map.into_iter()
1012                        .map(|(k, v)| Ok((k, v.try_into_ast(id.clone())?)))
1013                        .collect::<Result<HashMap<SmolStr, _>, FromJsonError>>()?,
1014                )
1015                .expect("can't have duplicate keys here because the input was already a HashMap"))
1016            }
1017            Expr::ExtFuncCall(ExtFuncCall { call }) => {
1018                match call.len() {
1019                    0 => Err(FromJsonError::MissingOperator),
1020                    1 => {
1021                        // PANIC SAFETY checked that `call.len() == 1`
1022                        #[allow(clippy::expect_used)]
1023                        let (fn_name, args) = call
1024                            .into_iter()
1025                            .next()
1026                            .expect("already checked that len was 1");
1027                        let fn_name: ast::Name = fn_name.parse().map_err(|errs| {
1028                            JsonDeserializationError::parse_escape(
1029                                EscapeKind::Extension,
1030                                fn_name,
1031                                errs,
1032                            )
1033                        })?;
1034                        if !cst_to_ast::is_known_extension_func_name(&fn_name) {
1035                            return Err(FromJsonError::UnknownExtensionFunction(fn_name));
1036                        }
1037                        Ok(ast::Expr::call_extension_fn(
1038                            fn_name,
1039                            args.into_iter()
1040                                .map(|arg| arg.try_into_ast(id.clone()))
1041                                .collect::<Result<_, _>>()?,
1042                        ))
1043                    }
1044                    _ => Err(FromJsonError::MultipleOperators {
1045                        ops: call.into_keys().collect(),
1046                    }),
1047                }
1048            }
1049        }
1050    }
1051}
1052
1053// PANIC SAFETY: See comment on `unwrap`
1054#[allow(clippy::fallible_impl_from)]
1055impl<T: Clone> From<ast::Expr<T>> for Expr {
1056    fn from(expr: ast::Expr<T>) -> Expr {
1057        match expr.into_expr_kind() {
1058            ast::ExprKind::Lit(lit) => lit.into(),
1059            ast::ExprKind::Var(var) => var.into(),
1060            ast::ExprKind::Slot(slot) => slot.into(),
1061            ast::ExprKind::Unknown(u) => Builder::new().unknown(u),
1062            ast::ExprKind::If {
1063                test_expr,
1064                then_expr,
1065                else_expr,
1066            } => Builder::new().ite(
1067                Arc::unwrap_or_clone(test_expr).into(),
1068                Arc::unwrap_or_clone(then_expr).into(),
1069                Arc::unwrap_or_clone(else_expr).into(),
1070            ),
1071            ast::ExprKind::And { left, right } => Builder::new().and(
1072                Arc::unwrap_or_clone(left).into(),
1073                Arc::unwrap_or_clone(right).into(),
1074            ),
1075            ast::ExprKind::Or { left, right } => Builder::new().or(
1076                Arc::unwrap_or_clone(left).into(),
1077                Arc::unwrap_or_clone(right).into(),
1078            ),
1079            ast::ExprKind::UnaryApp { op, arg } => {
1080                let arg = Arc::unwrap_or_clone(arg).into();
1081                Builder::new().unary_app(op, arg)
1082            }
1083            ast::ExprKind::BinaryApp { op, arg1, arg2 } => {
1084                let arg1 = Arc::unwrap_or_clone(arg1).into();
1085                let arg2 = Arc::unwrap_or_clone(arg2).into();
1086                Builder::new().binary_app(op, arg1, arg2)
1087            }
1088            ast::ExprKind::ExtensionFunctionApp { fn_name, args } => {
1089                let args = Arc::unwrap_or_clone(args).into_iter().map(Into::into);
1090                Builder::new().call_extension_fn(fn_name, args)
1091            }
1092            ast::ExprKind::GetAttr { expr, attr } => {
1093                Builder::new().get_attr(Arc::unwrap_or_clone(expr).into(), attr)
1094            }
1095            ast::ExprKind::HasAttr { expr, attr } => {
1096                Builder::new().has_attr(Arc::unwrap_or_clone(expr).into(), attr)
1097            }
1098            ast::ExprKind::Like { expr, pattern } => {
1099                Builder::new().like(Arc::unwrap_or_clone(expr).into(), pattern)
1100            }
1101            ast::ExprKind::Is { expr, entity_type } => {
1102                Builder::new().is_entity_type(Arc::unwrap_or_clone(expr).into(), entity_type)
1103            }
1104            ast::ExprKind::Set(set) => {
1105                Builder::new().set(Arc::unwrap_or_clone(set).into_iter().map(Into::into))
1106            }
1107            // PANIC SAFETY: `map` is a map, so it will not have duplicates keys, so the `record` constructor cannot error.
1108            #[allow(clippy::unwrap_used)]
1109            ast::ExprKind::Record(map) => Builder::new()
1110                .record(
1111                    Arc::unwrap_or_clone(map)
1112                        .into_iter()
1113                        .map(|(k, v)| (k, v.into())),
1114                )
1115                .unwrap(),
1116        }
1117    }
1118}
1119
1120impl From<ast::Literal> for Expr {
1121    fn from(lit: ast::Literal) -> Expr {
1122        Builder::new().val(lit)
1123    }
1124}
1125
1126impl From<ast::Var> for Expr {
1127    fn from(var: ast::Var) -> Expr {
1128        Builder::new().var(var)
1129    }
1130}
1131
1132impl From<ast::SlotId> for Expr {
1133    fn from(slot: ast::SlotId) -> Expr {
1134        Builder::new().slot(slot)
1135    }
1136}
1137
1138impl TryFrom<&Node<Option<cst::Expr>>> for Expr {
1139    type Error = ParseErrors;
1140    fn try_from(e: &Node<Option<cst::Expr>>) -> Result<Expr, ParseErrors> {
1141        e.to_expr::<Builder>()
1142    }
1143}
1144
1145impl std::fmt::Display for Expr {
1146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1147        match self {
1148            Self::ExprNoExt(e) => write!(f, "{e}"),
1149            Self::ExtFuncCall(e) => write!(f, "{e}"),
1150        }
1151    }
1152}
1153
1154impl BoundedDisplay for Expr {
1155    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
1156        match self {
1157            Self::ExprNoExt(e) => BoundedDisplay::fmt(e, f, n),
1158            Self::ExtFuncCall(e) => BoundedDisplay::fmt(e, f, n),
1159        }
1160    }
1161}
1162
1163fn display_cedarvaluejson(
1164    f: &mut impl std::fmt::Write,
1165    v: &CedarValueJson,
1166    n: Option<usize>,
1167) -> std::fmt::Result {
1168    match v {
1169        // Add parentheses around negative numeric literals otherwise
1170        // round-tripping fuzzer fails for expressions like `(-1)["a"]`.
1171        CedarValueJson::Long(i) if *i < 0 => write!(f, "({i})"),
1172        CedarValueJson::Long(i) => write!(f, "{i}"),
1173        CedarValueJson::Bool(b) => write!(f, "{b}"),
1174        CedarValueJson::String(s) => write!(f, "\"{}\"", s.escape_debug()),
1175        CedarValueJson::EntityEscape { __entity } => {
1176            match ast::EntityUID::try_from(__entity.clone()) {
1177                Ok(euid) => write!(f, "{euid}"),
1178                Err(e) => write!(f, "(invalid entity uid: {})", e),
1179            }
1180        }
1181        CedarValueJson::ExprEscape { __expr } => write!(f, "({__expr})"),
1182        CedarValueJson::ExtnEscape {
1183            __extn: FnAndArg { ext_fn, arg },
1184        } => {
1185            // search for the name and callstyle
1186            let style = Extensions::all_available().all_funcs().find_map(|f| {
1187                if &f.name().to_string() == ext_fn {
1188                    Some(f.style())
1189                } else {
1190                    None
1191                }
1192            });
1193            match style {
1194                Some(ast::CallStyle::MethodStyle) => {
1195                    display_cedarvaluejson(f, arg, n)?;
1196                    write!(f, ".{ext_fn}()")?;
1197                    Ok(())
1198                }
1199                Some(ast::CallStyle::FunctionStyle) | None => {
1200                    write!(f, "{ext_fn}(")?;
1201                    display_cedarvaluejson(f, arg, n)?;
1202                    write!(f, ")")?;
1203                    Ok(())
1204                }
1205            }
1206        }
1207        CedarValueJson::Set(v) => {
1208            match n {
1209                Some(n) if v.len() > n => {
1210                    // truncate to n elements
1211                    write!(f, "[")?;
1212                    for val in v.iter().take(n) {
1213                        display_cedarvaluejson(f, val, Some(n))?;
1214                        write!(f, ", ")?;
1215                    }
1216                    write!(f, "..]")?;
1217                    Ok(())
1218                }
1219                _ => {
1220                    // no truncation
1221                    write!(f, "[")?;
1222                    for (i, val) in v.iter().enumerate() {
1223                        display_cedarvaluejson(f, val, n)?;
1224                        if i < v.len() - 1 {
1225                            write!(f, ", ")?;
1226                        }
1227                    }
1228                    write!(f, "]")?;
1229                    Ok(())
1230                }
1231            }
1232        }
1233        CedarValueJson::Record(r) => {
1234            match n {
1235                Some(n) if r.len() > n => {
1236                    // truncate to n key-value pairs
1237                    write!(f, "{{")?;
1238                    for (k, v) in r.iter().take(n) {
1239                        write!(f, "\"{}\": ", k.escape_debug())?;
1240                        display_cedarvaluejson(f, v, Some(n))?;
1241                        write!(f, ", ")?;
1242                    }
1243                    write!(f, "..}}")?;
1244                    Ok(())
1245                }
1246                _ => {
1247                    // no truncation
1248                    write!(f, "{{")?;
1249                    for (i, (k, v)) in r.iter().enumerate() {
1250                        write!(f, "\"{}\": ", k.escape_debug())?;
1251                        display_cedarvaluejson(f, v, n)?;
1252                        if i < r.len() - 1 {
1253                            write!(f, ", ")?;
1254                        }
1255                    }
1256                    write!(f, "}}")?;
1257                    Ok(())
1258                }
1259            }
1260        }
1261        CedarValueJson::Null => {
1262            write!(f, "null")?;
1263            Ok(())
1264        }
1265    }
1266}
1267
1268impl std::fmt::Display for ExprNoExt {
1269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1270        BoundedDisplay::fmt_unbounded(self, f)
1271    }
1272}
1273
1274impl BoundedDisplay for ExprNoExt {
1275    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
1276        match &self {
1277            ExprNoExt::Value(v) => display_cedarvaluejson(f, v, n),
1278            ExprNoExt::Var(v) => write!(f, "{v}"),
1279            ExprNoExt::Slot(id) => write!(f, "{id}"),
1280            ExprNoExt::Not { arg } => {
1281                write!(f, "!")?;
1282                maybe_with_parens(f, arg, n)
1283            }
1284            ExprNoExt::Neg { arg } => {
1285                // Always add parentheses instead of calling
1286                // `maybe_with_parens`.
1287                // This makes sure that we always get a negation operation back
1288                // (as opposed to e.g., a negative number) when parsing the
1289                // printed form, thus preserving the round-tripping property.
1290                write!(f, "-({arg})")
1291            }
1292            ExprNoExt::Eq { left, right } => {
1293                maybe_with_parens(f, left, n)?;
1294                write!(f, " == ")?;
1295                maybe_with_parens(f, right, n)
1296            }
1297            ExprNoExt::NotEq { left, right } => {
1298                maybe_with_parens(f, left, n)?;
1299                write!(f, " != ")?;
1300                maybe_with_parens(f, right, n)
1301            }
1302            ExprNoExt::In { left, right } => {
1303                maybe_with_parens(f, left, n)?;
1304                write!(f, " in ")?;
1305                maybe_with_parens(f, right, n)
1306            }
1307            ExprNoExt::Less { left, right } => {
1308                maybe_with_parens(f, left, n)?;
1309                write!(f, " < ")?;
1310                maybe_with_parens(f, right, n)
1311            }
1312            ExprNoExt::LessEq { left, right } => {
1313                maybe_with_parens(f, left, n)?;
1314                write!(f, " <= ")?;
1315                maybe_with_parens(f, right, n)
1316            }
1317            ExprNoExt::Greater { left, right } => {
1318                maybe_with_parens(f, left, n)?;
1319                write!(f, " > ")?;
1320                maybe_with_parens(f, right, n)
1321            }
1322            ExprNoExt::GreaterEq { left, right } => {
1323                maybe_with_parens(f, left, n)?;
1324                write!(f, " >= ")?;
1325                maybe_with_parens(f, right, n)
1326            }
1327            ExprNoExt::And { left, right } => {
1328                maybe_with_parens(f, left, n)?;
1329                write!(f, " && ")?;
1330                maybe_with_parens(f, right, n)
1331            }
1332            ExprNoExt::Or { left, right } => {
1333                maybe_with_parens(f, left, n)?;
1334                write!(f, " || ")?;
1335                maybe_with_parens(f, right, n)
1336            }
1337            ExprNoExt::Add { left, right } => {
1338                maybe_with_parens(f, left, n)?;
1339                write!(f, " + ")?;
1340                maybe_with_parens(f, right, n)
1341            }
1342            ExprNoExt::Sub { left, right } => {
1343                maybe_with_parens(f, left, n)?;
1344                write!(f, " - ")?;
1345                maybe_with_parens(f, right, n)
1346            }
1347            ExprNoExt::Mul { left, right } => {
1348                maybe_with_parens(f, left, n)?;
1349                write!(f, " * ")?;
1350                maybe_with_parens(f, right, n)
1351            }
1352            ExprNoExt::Contains { left, right } => {
1353                maybe_with_parens(f, left, n)?;
1354                write!(f, ".contains({right})")
1355            }
1356            ExprNoExt::ContainsAll { left, right } => {
1357                maybe_with_parens(f, left, n)?;
1358                write!(f, ".containsAll({right})")
1359            }
1360            ExprNoExt::ContainsAny { left, right } => {
1361                maybe_with_parens(f, left, n)?;
1362                write!(f, ".containsAny({right})")
1363            }
1364            ExprNoExt::IsEmpty { arg } => {
1365                maybe_with_parens(f, arg, n)?;
1366                write!(f, ".isEmpty()")
1367            }
1368            ExprNoExt::GetTag { left, right } => {
1369                maybe_with_parens(f, left, n)?;
1370                write!(f, ".getTag({right})")
1371            }
1372            ExprNoExt::HasTag { left, right } => {
1373                maybe_with_parens(f, left, n)?;
1374                write!(f, ".hasTag({right})")
1375            }
1376            ExprNoExt::GetAttr { left, attr } => {
1377                maybe_with_parens(f, left, n)?;
1378                write!(f, "[\"{}\"]", attr.escape_debug())
1379            }
1380            ExprNoExt::HasAttr { left, attr } => {
1381                maybe_with_parens(f, left, n)?;
1382                write!(f, " has \"{}\"", attr.escape_debug())
1383            }
1384            ExprNoExt::Like { left, pattern } => {
1385                maybe_with_parens(f, left, n)?;
1386                write!(
1387                    f,
1388                    " like \"{}\"",
1389                    crate::ast::Pattern::from(pattern.as_slice())
1390                )
1391            }
1392            ExprNoExt::Is {
1393                left,
1394                entity_type,
1395                in_expr,
1396            } => {
1397                maybe_with_parens(f, left, n)?;
1398                write!(f, " is {entity_type}")?;
1399                match in_expr {
1400                    Some(in_expr) => {
1401                        write!(f, " in ")?;
1402                        maybe_with_parens(f, in_expr, n)
1403                    }
1404                    None => Ok(()),
1405                }
1406            }
1407            ExprNoExt::If {
1408                cond_expr,
1409                then_expr,
1410                else_expr,
1411            } => {
1412                write!(f, "if ")?;
1413                maybe_with_parens(f, cond_expr, n)?;
1414                write!(f, " then ")?;
1415                maybe_with_parens(f, then_expr, n)?;
1416                write!(f, " else ")?;
1417                maybe_with_parens(f, else_expr, n)
1418            }
1419            ExprNoExt::Set(v) => {
1420                match n {
1421                    Some(n) if v.len() > n => {
1422                        // truncate to n elements
1423                        write!(f, "[")?;
1424                        for element in v.iter().take(n) {
1425                            BoundedDisplay::fmt(element, f, Some(n))?;
1426                            write!(f, ", ")?;
1427                        }
1428                        write!(f, "..]")?;
1429                        Ok(())
1430                    }
1431                    _ => {
1432                        // no truncation
1433                        write!(f, "[")?;
1434                        for (i, element) in v.iter().enumerate() {
1435                            BoundedDisplay::fmt(element, f, n)?;
1436                            if i < v.len() - 1 {
1437                                write!(f, ", ")?;
1438                            }
1439                        }
1440                        write!(f, "]")?;
1441                        Ok(())
1442                    }
1443                }
1444            }
1445            ExprNoExt::Record(m) => {
1446                match n {
1447                    Some(n) if m.len() > n => {
1448                        // truncate to n key-value pairs
1449                        write!(f, "{{")?;
1450                        for (k, v) in m.iter().take(n) {
1451                            write!(f, "\"{}\": ", k.escape_debug())?;
1452                            BoundedDisplay::fmt(v, f, Some(n))?;
1453                            write!(f, ", ")?;
1454                        }
1455                        write!(f, "..}}")?;
1456                        Ok(())
1457                    }
1458                    _ => {
1459                        // no truncation
1460                        write!(f, "{{")?;
1461                        for (i, (k, v)) in m.iter().enumerate() {
1462                            write!(f, "\"{}\": ", k.escape_debug())?;
1463                            BoundedDisplay::fmt(v, f, n)?;
1464                            if i < m.len() - 1 {
1465                                write!(f, ", ")?;
1466                            }
1467                        }
1468                        write!(f, "}}")?;
1469                        Ok(())
1470                    }
1471                }
1472            }
1473        }
1474    }
1475}
1476
1477impl std::fmt::Display for ExtFuncCall {
1478    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1479        BoundedDisplay::fmt_unbounded(self, f)
1480    }
1481}
1482
1483impl BoundedDisplay for ExtFuncCall {
1484    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
1485        // PANIC SAFETY: safe due to INVARIANT on `ExtFuncCall`
1486        #[allow(clippy::unreachable)]
1487        let Some((fn_name, args)) = self.call.iter().next() else {
1488            unreachable!("invariant violated: empty ExtFuncCall")
1489        };
1490        // search for the name and callstyle
1491        let style = Extensions::all_available().all_funcs().find_map(|ext_fn| {
1492            if &ext_fn.name().to_string() == fn_name {
1493                Some(ext_fn.style())
1494            } else {
1495                None
1496            }
1497        });
1498        match (style, args.iter().next()) {
1499            (Some(ast::CallStyle::MethodStyle), Some(receiver)) => {
1500                maybe_with_parens(f, receiver, n)?;
1501                write!(f, ".{}({})", fn_name, args.iter().skip(1).join(", "))
1502            }
1503            (_, _) => {
1504                write!(f, "{}({})", fn_name, args.iter().join(", "))
1505            }
1506        }
1507    }
1508}
1509
1510/// returns the `BoundedDisplay` representation of the Expr, adding parens around
1511/// the entire string if necessary.
1512/// E.g., won't add parens for constants or `principal` etc, but will for things
1513/// like `(2 < 5)`.
1514/// When in doubt, add the parens.
1515fn maybe_with_parens(
1516    f: &mut impl std::fmt::Write,
1517    expr: &Expr,
1518    n: Option<usize>,
1519) -> std::fmt::Result {
1520    match expr {
1521        Expr::ExprNoExt(ExprNoExt::Set(_)) |
1522        Expr::ExprNoExt(ExprNoExt::Record(_)) |
1523        Expr::ExprNoExt(ExprNoExt::Value(_)) |
1524        Expr::ExprNoExt(ExprNoExt::Var(_)) |
1525        Expr::ExprNoExt(ExprNoExt::Slot(_)) => BoundedDisplay::fmt(expr, f, n),
1526
1527        // we want parens here because things like parse((!x).y)
1528        // would be printed into !x.y which has a different meaning
1529        Expr::ExprNoExt(ExprNoExt::Not { .. }) |
1530        // we want parens here because things like parse((-x).y)
1531        // would be printed into -x.y which has a different meaning
1532        Expr::ExprNoExt(ExprNoExt::Neg { .. })  |
1533        Expr::ExprNoExt(ExprNoExt::Eq { .. }) |
1534        Expr::ExprNoExt(ExprNoExt::NotEq { .. }) |
1535        Expr::ExprNoExt(ExprNoExt::In { .. }) |
1536        Expr::ExprNoExt(ExprNoExt::Less { .. }) |
1537        Expr::ExprNoExt(ExprNoExt::LessEq { .. }) |
1538        Expr::ExprNoExt(ExprNoExt::Greater { .. }) |
1539        Expr::ExprNoExt(ExprNoExt::GreaterEq { .. }) |
1540        Expr::ExprNoExt(ExprNoExt::And { .. }) |
1541        Expr::ExprNoExt(ExprNoExt::Or { .. }) |
1542        Expr::ExprNoExt(ExprNoExt::Add { .. }) |
1543        Expr::ExprNoExt(ExprNoExt::Sub { .. }) |
1544        Expr::ExprNoExt(ExprNoExt::Mul { .. }) |
1545        Expr::ExprNoExt(ExprNoExt::Contains { .. }) |
1546        Expr::ExprNoExt(ExprNoExt::ContainsAll { .. }) |
1547        Expr::ExprNoExt(ExprNoExt::ContainsAny { .. }) |
1548        Expr::ExprNoExt(ExprNoExt::IsEmpty { .. }) |
1549        Expr::ExprNoExt(ExprNoExt::GetAttr { .. }) |
1550        Expr::ExprNoExt(ExprNoExt::HasAttr { .. }) |
1551        Expr::ExprNoExt(ExprNoExt::GetTag { .. }) |
1552        Expr::ExprNoExt(ExprNoExt::HasTag { .. }) |
1553        Expr::ExprNoExt(ExprNoExt::Like { .. }) |
1554        Expr::ExprNoExt(ExprNoExt::Is { .. }) |
1555        Expr::ExprNoExt(ExprNoExt::If { .. }) |
1556        Expr::ExtFuncCall { .. } => {
1557            write!(f, "(")?;
1558            BoundedDisplay::fmt(expr, f, n)?;
1559            write!(f, ")")?;
1560            Ok(())
1561        }
1562    }
1563}
1564
#[cfg(test)]
// PANIC SAFETY: this is unit test code
#[allow(clippy::indexing_slicing)]
// PANIC SAFETY: Unit Test Code
#[allow(clippy::panic)]
mod test {
    use crate::parser::{
        err::{ParseError, ToASTErrorKind},
        parse_expr,
    };

    use super::*;
    use ast::BoundedToString;
    use cool_asserts::assert_matches;

    /// Converting a CST whose name ends in the reserved identifier `else` must
    /// fail with exactly one `ReservedIdentifier` error naming `else`.
    #[test]
    fn test_invalid_expr_from_cst_name() {
        let e = crate::parser::text_to_cst::parse_expr("some_long_str::else").unwrap();
        assert_matches!(Expr::try_from(&e), Err(e) => {
            // `assert_eq!` (rather than `assert!(.. == ..)`) reports the
            // actual error count on failure.
            assert_eq!(e.len(), 1);
            assert_matches!(&e[0],
                ParseError::ToAST(to_ast_error) => {
                    assert_matches!(to_ast_error.kind(), ToASTErrorKind::ReservedIdentifier(s) => {
                        assert_eq!(s.to_string(), "else");
                    });
                }
            );
        });
    }

    /// Checks `Display` and `BoundedDisplay` output for sets and records,
    /// including how a bound `n` truncates (possibly nested) collections
    /// with `..`.
    #[test]
    fn display_and_bounded_display() {
        let expr = Expr::from(parse_expr(r#"[100, [3, 4, 5], -20, "foo"]"#).unwrap());
        let full = r#"[100, [3, 4, 5], (-20), "foo"]"#;
        assert_eq!(format!("{expr}"), full);
        for (bound, expected) in [
            (None, full),
            (Some(4), full),
            (Some(3), r#"[100, [3, 4, 5], (-20), ..]"#),
            (Some(2), r#"[100, [3, 4, ..], ..]"#),
            (Some(1), r#"[100, ..]"#),
            (Some(0), r#"[..]"#),
        ] {
            assert_eq!(
                BoundedToString::to_string(&expr, bound),
                expected,
                "unexpected set rendering for bound {bound:?}"
            );
        }

        let expr = Expr::from(
            parse_expr(
                r#"{
            a: 12,
            b: [3, 4, true],
            c: -20,
            "hello ∞ world": "∂µß≈¥"
        }"#,
            )
            .unwrap(),
        );
        let full = r#"{"a": 12, "b": [3, 4, true], "c": (-20), "hello ∞ world": "∂µß≈¥"}"#;
        assert_eq!(format!("{expr}"), full);
        for (bound, expected) in [
            (None, full),
            (Some(4), full),
            (Some(3), r#"{"a": 12, "b": [3, 4, true], "c": (-20), ..}"#),
            (Some(2), r#"{"a": 12, "b": [3, 4, ..], ..}"#),
            (Some(1), r#"{"a": 12, ..}"#),
            (Some(0), r#"{..}"#),
        ] {
            assert_eq!(
                BoundedToString::to_string(&expr, bound),
                expected,
                "unexpected record rendering for bound {bound:?}"
            );
        }
    }
}