// cedar_policy_core/ast/expr_iterator.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use super::{Expr, ExprKind};
18
19/// This structure implements the iterator used to traverse subexpressions of an
20/// expression.
21#[derive(Debug)]
22pub struct ExprIterator<'a, T = ()> {
23    /// The stack of expressions that need to be visited. To get the next
24    /// expression, the iterator will pop from the stack. If the stack is empty,
25    /// then the iterator is finished. Otherwise, any subexpressions of that
26    /// expression are then pushed onto the stack, and the popped expression is
27    /// returned.
28    expression_stack: Vec<&'a Expr<T>>,
29}
30
31impl<'a, T> ExprIterator<'a, T> {
32    /// Construct an expr iterator
33    pub fn new(expr: &'a Expr<T>) -> Self {
34        Self {
35            expression_stack: vec![expr],
36        }
37    }
38}
39
40impl<'a, T> Iterator for ExprIterator<'a, T> {
41    type Item = &'a Expr<T>;
42
43    fn next(&mut self) -> Option<Self::Item> {
44        let next_expr = self.expression_stack.pop()?;
45        match next_expr.expr_kind() {
46            ExprKind::Lit(_) => (),
47            ExprKind::Unknown(_) => (),
48            ExprKind::Slot(_) => (),
49            ExprKind::Var(_) => (),
50            ExprKind::If {
51                test_expr,
52                then_expr,
53                else_expr,
54            } => {
55                self.expression_stack.push(test_expr);
56                self.expression_stack.push(then_expr);
57                self.expression_stack.push(else_expr);
58            }
59            ExprKind::And { left, right } | ExprKind::Or { left, right } => {
60                self.expression_stack.push(left);
61                self.expression_stack.push(right);
62            }
63            ExprKind::UnaryApp { arg, .. } => {
64                self.expression_stack.push(arg);
65            }
66            ExprKind::BinaryApp { arg1, arg2, .. } => {
67                self.expression_stack.push(arg1);
68                self.expression_stack.push(arg2);
69            }
70            ExprKind::GetAttr { expr, attr: _ }
71            | ExprKind::HasAttr { expr, attr: _ }
72            | ExprKind::Like { expr, pattern: _ }
73            | ExprKind::Is {
74                expr,
75                entity_type: _,
76            } => {
77                self.expression_stack.push(expr);
78            }
79            ExprKind::ExtensionFunctionApp { args: exprs, .. } | ExprKind::Set(exprs) => {
80                self.expression_stack.extend(exprs.as_ref());
81            }
82            ExprKind::Record(map) => {
83                self.expression_stack.extend(map.values());
84            }
85        }
86        Some(next_expr)
87    }
88}
89
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::ast::{BinaryOp, Expr, SlotId, UnaryOp, Var};

    #[test]
    fn literals() {
        let e = Expr::val(true);
        let v: HashSet<_> = e.subexpressions().collect();

        assert_eq!(v.len(), 1);
        assert!(v.contains(&Expr::val(true)));
    }

    #[test]
    fn slots() {
        let e = Expr::slot(SlotId::principal());
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 1);
        assert!(v.contains(&Expr::slot(SlotId::principal())));
    }

    #[test]
    fn variables() {
        let e = Expr::var(Var::Principal);
        let v: HashSet<_> = e.subexpressions().collect();
        let s = HashSet::from([&e]);
        assert_eq!(v, s);
    }

    #[test]
    fn ite() {
        let e = Expr::ite(Expr::val(true), Expr::val(false), Expr::val(0));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(
            v,
            HashSet::from([&e, &Expr::val(true), &Expr::val(false), &Expr::val(0)])
        );
    }

    #[test]
    fn and() {
        // Using `1 && false` because `true && false` would be simplified to
        // `false` by `Expr::and`.
        let e = Expr::and(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    #[test]
    fn or() {
        // Using `1 || false` because `true || false` would be simplified to
        // `true` by `Expr::or`.
        let e = Expr::or(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    #[test]
    fn unary() {
        let e = Expr::unary_app(UnaryOp::Not, Expr::val(false));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn binary() {
        let e = Expr::binary_app(BinaryOp::Eq, Expr::val(false), Expr::val(true));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn ext() {
        let e = Expr::call_extension_fn(
            "test".parse().unwrap(),
            vec![Expr::val(false), Expr::val(true)],
        );
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn has_attr() {
        let e = Expr::has_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn get_attr() {
        let e = Expr::get_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn set() {
        let e = Expr::set(vec![Expr::val(false), Expr::val(true)]);
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn set_duplicates() {
        let e = Expr::set(vec![Expr::val(true), Expr::val(true)]);
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 3);
        // Both occurrences of `true` are yielded; nothing is deduplicated.
        assert_eq!(v.iter().filter(|x| ***x == Expr::val(true)).count(), 2);
    }

    #[test]
    fn record() {
        let e = Expr::record(vec![
            ("test".into(), Expr::val(true)),
            ("another".into(), Expr::val(false)),
        ])
        .unwrap();
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn is() {
        let e = Expr::is_entity_type(Expr::val(1), "T".parse().unwrap());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(1)])
        );
    }

    #[test]
    fn duplicates() {
        let e = Expr::ite(Expr::val(true), Expr::val(true), Expr::val(true));
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 4);
        assert!(v.contains(&&e));
        // Each of the three `true` children is yielded separately.
        assert_eq!(v.iter().filter(|x| ***x == Expr::val(true)).count(), 3);
    }

    #[test]
    fn deeply_nested() {
        let e = Expr::get_attr(
            Expr::get_attr(Expr::and(Expr::val(1), Expr::val(0)), "attr2".into()),
            "attr1".into(),
        );
        let set: HashSet<_> = e.subexpressions().collect();
        assert!(set.contains(&e));
        assert!(set.contains(&Expr::get_attr(
            Expr::and(Expr::val(1), Expr::val(0)),
            "attr2".into()
        )));
        assert!(set.contains(&Expr::and(Expr::val(1), Expr::val(0))));
        assert!(set.contains(&Expr::val(1)));
        assert!(set.contains(&Expr::val(0)));
        // Outer `get_attr`, inner `get_attr`, `and`, `1` and `0`: exactly five
        // nodes, each yielded once.
        assert_eq!(e.subexpressions().count(), 5);
    }

    #[test]
    fn exhausted_iterator_stays_exhausted() {
        let e = Expr::val(true);
        let mut it = e.subexpressions();
        assert_eq!(it.next(), Some(&e));
        assert_eq!(it.next(), None);
        assert_eq!(it.next(), None);
    }
}