1use super::{Expr, ExprKind};
18
19#[derive(Debug)]
22pub struct ExprIterator<'a, T = ()> {
23 expression_stack: Vec<&'a Expr<T>>,
29}
30
31impl<'a, T> ExprIterator<'a, T> {
32 pub fn new(expr: &'a Expr<T>) -> Self {
34 Self {
35 expression_stack: vec![expr],
36 }
37 }
38}
39
40impl<'a, T> Iterator for ExprIterator<'a, T> {
41 type Item = &'a Expr<T>;
42
43 fn next(&mut self) -> Option<Self::Item> {
44 let next_expr = self.expression_stack.pop()?;
45 match next_expr.expr_kind() {
46 ExprKind::Lit(_) => (),
47 ExprKind::Unknown(_) => (),
48 ExprKind::Slot(_) => (),
49 ExprKind::Var(_) => (),
50 ExprKind::If {
51 test_expr,
52 then_expr,
53 else_expr,
54 } => {
55 self.expression_stack.push(test_expr);
56 self.expression_stack.push(then_expr);
57 self.expression_stack.push(else_expr);
58 }
59 ExprKind::And { left, right } | ExprKind::Or { left, right } => {
60 self.expression_stack.push(left);
61 self.expression_stack.push(right);
62 }
63 ExprKind::UnaryApp { arg, .. } => {
64 self.expression_stack.push(arg);
65 }
66 ExprKind::BinaryApp { arg1, arg2, .. } => {
67 self.expression_stack.push(arg1);
68 self.expression_stack.push(arg2);
69 }
70 ExprKind::GetAttr { expr, attr: _ }
71 | ExprKind::HasAttr { expr, attr: _ }
72 | ExprKind::Like { expr, pattern: _ }
73 | ExprKind::Is {
74 expr,
75 entity_type: _,
76 } => {
77 self.expression_stack.push(expr);
78 }
79 ExprKind::ExtensionFunctionApp { args: exprs, .. } | ExprKind::Set(exprs) => {
80 self.expression_stack.extend(exprs.as_ref());
81 }
82 ExprKind::Record(map) => {
83 self.expression_stack.extend(map.values());
84 }
85 }
86 Some(next_expr)
87 }
88}
89
#[cfg(test)]
mod test {
    // Tests for `Expr::subexpressions`, which walks an expression tree
    // with `ExprIterator`.
    use std::collections::HashSet;

    use crate::ast::{BinaryOp, Expr, SlotId, UnaryOp, Var};

    #[test]
    fn literals() {
        let e = Expr::val(true);
        let v: HashSet<_> = e.subexpressions().collect();

        assert_eq!(v.len(), 1);
        assert!(v.contains(&Expr::val(true)));
    }

    #[test]
    fn slots() {
        let e = Expr::slot(SlotId::principal());
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 1);
        assert!(v.contains(&Expr::slot(SlotId::principal())));
    }

    #[test]
    fn variables() {
        let e = Expr::var(Var::Principal);
        let v: HashSet<_> = e.subexpressions().collect();
        let s = HashSet::from([&e]);
        assert_eq!(v, s);
    }

    #[test]
    fn ite() {
        let e = Expr::ite(Expr::val(true), Expr::val(false), Expr::val(0));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(
            v,
            HashSet::from([&e, &Expr::val(true), &Expr::val(false), &Expr::val(0)])
        );
    }

    #[test]
    fn and() {
        let e = Expr::and(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    #[test]
    fn or() {
        let e = Expr::or(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    #[test]
    fn unary() {
        let e = Expr::unary_app(UnaryOp::Not, Expr::val(false));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn binary() {
        let e = Expr::binary_app(BinaryOp::Eq, Expr::val(false), Expr::val(true));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn ext() {
        let e = Expr::call_extension_fn(
            "test".parse().unwrap(),
            vec![Expr::val(false), Expr::val(true)],
        );
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn has_attr() {
        let e = Expr::has_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn get_attr() {
        let e = Expr::get_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    #[test]
    fn set() {
        let e = Expr::set(vec![Expr::val(false), Expr::val(true)]);
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn set_duplicates() {
        let e = Expr::set(vec![Expr::val(true), Expr::val(true)]);
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 3);
        assert!(v.contains(&&Expr::val(true)));
    }

    #[test]
    fn record() {
        let e = Expr::record(vec![
            ("test".into(), Expr::val(true)),
            ("another".into(), Expr::val(false)),
        ])
        .unwrap();
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    #[test]
    fn is() {
        let e = Expr::is_entity_type(Expr::val(1), "T".parse().unwrap());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(1)])
        );
    }

    #[test]
    fn duplicates() {
        let e = Expr::ite(Expr::val(true), Expr::val(true), Expr::val(true));
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 4);
        assert!(v.contains(&&e));
        assert!(v.contains(&&Expr::val(true)));
    }

    #[test]
    fn deeply_nested() {
        let e = Expr::get_attr(
            Expr::get_attr(Expr::and(Expr::val(1), Expr::val(0)), "attr2".into()),
            "attr1".into(),
        );
        let set: HashSet<_> = e.subexpressions().collect();
        // Outer get_attr, inner get_attr, `and`, and its two literals: the
        // exact count guards against extra or missing items.
        assert_eq!(set.len(), 5);
        assert!(set.contains(&e));
        assert!(set.contains(&Expr::get_attr(
            Expr::and(Expr::val(1), Expr::val(0)),
            "attr2".into()
        )));
        assert!(set.contains(&Expr::and(Expr::val(1), Expr::val(0))));
        assert!(set.contains(&Expr::val(1)));
        assert!(set.contains(&Expr::val(0)));
    }

    #[test]
    fn root_first_and_stays_exhausted() {
        // The traversal is pre-order, so the root is always yielded first.
        let compound = Expr::ite(Expr::val(true), Expr::val(false), Expr::val(0));
        assert_eq!(compound.subexpressions().next(), Some(&compound));

        // After the last item, further calls keep returning `None`.
        let leaf = Expr::val(true);
        let mut it = leaf.subexpressions();
        assert_eq!(it.next(), Some(&leaf));
        assert_eq!(it.next(), None);
        assert_eq!(it.next(), None);
    }
}