use super::{Expr, ExprKind};
/// Iterator over an expression and all of the expressions nested inside it.
///
/// The iterator is seeded with a single root expression (see
/// `ExprIterator::new`). Each call to `next` pops one expression off the
/// internal stack, pushes that expression's direct children, and returns the
/// popped expression, so a parent is always produced before its children.
///
/// Items are borrowed from the root expression for the lifetime `'a`; `T` is
/// the type parameter of the `Expr<T>` being walked and defaults to `()`.
#[derive(Debug)]
pub struct ExprIterator<'a, T = ()> {
/// Expressions that have been discovered but not yet returned from `next`.
/// The most recently pushed expression is the next one to be produced.
expression_stack: Vec<&'a Expr<T>>,
}
impl<'a, T> ExprIterator<'a, T> {
    /// Builds an iterator whose traversal starts at `expr`.
    ///
    /// The pending stack initially contains only `expr`, so `expr` itself is
    /// the first item the iterator produces.
    pub fn new(expr: &'a Expr<T>) -> Self {
        let expression_stack = vec![expr];
        Self { expression_stack }
    }
}
impl<'a, T> Iterator for ExprIterator<'a, T> {
    type Item = &'a Expr<T>;

    /// Produces the expression on top of the pending stack.
    ///
    /// Before the expression is returned, its direct children are pushed onto
    /// the stack so that later calls visit them. Returns `None` once the stack
    /// is empty, i.e. when every expression reachable from the root has been
    /// produced.
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.expression_stack.pop()?;
        let pending = &mut self.expression_stack;

        // The match is deliberately exhaustive (no `_` arm): adding a new
        // `ExprKind` variant must force a decision about its children here.
        match current.expr_kind() {
            // Leaf expressions: nothing further to visit.
            ExprKind::Lit(_)
            | ExprKind::Unknown { .. }
            | ExprKind::Slot(_)
            | ExprKind::Var(_) => {}

            // Three children.
            ExprKind::If {
                test_expr,
                then_expr,
                else_expr,
            } => {
                pending.push(test_expr);
                pending.push(then_expr);
                pending.push(else_expr);
            }

            // Two children.
            ExprKind::And { left, right } | ExprKind::Or { left, right } => {
                pending.push(left);
                pending.push(right);
            }
            ExprKind::BinaryApp { arg1, arg2, .. } => {
                pending.push(arg1);
                pending.push(arg2);
            }

            // Exactly one child.
            ExprKind::UnaryApp { arg: child, .. }
            | ExprKind::GetAttr { expr: child, .. }
            | ExprKind::HasAttr { expr: child, .. }
            | ExprKind::Like { expr: child, .. } => pending.push(child),

            // A sequence of children, pushed in their stored order.
            ExprKind::ExtensionFunctionApp { args: children, .. } | ExprKind::Set(children) => {
                pending.extend(children.iter());
            }

            // Only the value expressions of a record are visited; the keys
            // are not expressions.
            ExprKind::Record { pairs } => {
                pending.extend(pairs.iter().map(|(_, value)| value));
            }
        }

        Some(current)
    }
}
/// Unit tests for the subexpression iterator.
///
/// Most tests collect the iterator into a `HashSet`, so they check *which*
/// expressions are produced rather than the order in which they appear. The
/// `*duplicates` tests collect into a `Vec` instead, to confirm that
/// structurally equal subexpressions are each produced separately.
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::ast::{BinaryOp, Expr, SlotId, UnaryOp, Var};

    /// A literal has no children, so it yields only itself.
    #[test]
    fn literals() {
        let e = Expr::val(true);
        let v: HashSet<_> = e.subexpressions().collect();

        assert_eq!(v.len(), 1);
        assert!(v.contains(&Expr::val(true)));
    }

    /// A slot has no children, so it yields only itself.
    #[test]
    fn slots() {
        let e = Expr::slot(SlotId::principal());
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 1);
        assert!(v.contains(&Expr::slot(SlotId::principal())));
    }

    /// A variable has no children, so it yields only itself.
    #[test]
    fn variables() {
        let e = Expr::var(Var::Principal);
        let v: HashSet<_> = e.subexpressions().collect();
        let s = HashSet::from([&e]);
        assert_eq!(v, s);
    }

    /// An if-then-else yields itself plus its test, then and else branches.
    #[test]
    fn ite() {
        let e = Expr::ite(Expr::val(true), Expr::val(false), Expr::val(0));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(
            v,
            HashSet::from([&e, &Expr::val(true), &Expr::val(false), &Expr::val(0)])
        );
    }

    /// A conjunction yields itself plus both operands.
    #[test]
    fn and() {
        let e = Expr::and(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    /// A disjunction yields itself plus both operands.
    #[test]
    fn or() {
        let e = Expr::or(Expr::val(1), Expr::val(false));
        let v: HashSet<_> = e.subexpressions().collect();
        assert_eq!(v, HashSet::from([&e, &Expr::val(1), &Expr::val(false)]));
    }

    /// A unary application yields itself plus its single argument.
    #[test]
    fn unary() {
        let e = Expr::unary_app(UnaryOp::Not, Expr::val(false));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    /// A binary application yields itself plus both arguments.
    #[test]
    fn binary() {
        let e = Expr::binary_app(BinaryOp::Eq, Expr::val(false), Expr::val(true));
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    /// An extension function call yields itself plus every argument.
    #[test]
    fn ext() {
        let e = Expr::call_extension_fn(
            "test".parse().unwrap(),
            vec![Expr::val(false), Expr::val(true)],
        );
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    /// A has-attribute test yields itself plus the expression being tested.
    #[test]
    fn has_attr() {
        let e = Expr::has_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    /// An attribute access yields itself plus the expression being accessed.
    #[test]
    fn get_attr() {
        let e = Expr::get_attr(Expr::val(false), "test".into());
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false)])
        );
    }

    /// A set yields itself plus every element.
    #[test]
    fn set() {
        let e = Expr::set(vec![Expr::val(false), Expr::val(true)]);
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    /// Equal set elements are produced once per occurrence, not deduplicated.
    #[test]
    fn set_duplicates() {
        let e = Expr::set(vec![Expr::val(true), Expr::val(true)]);
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 3);
        assert!(v.contains(&&Expr::val(true)));
    }

    /// A record yields itself plus the value expression of every entry.
    #[test]
    fn record() {
        let e = Expr::record(vec![
            ("test".into(), Expr::val(true)),
            ("another".into(), Expr::val(false)),
        ]);
        assert_eq!(
            e.subexpressions().collect::<HashSet<_>>(),
            HashSet::from([&e, &Expr::val(false), &Expr::val(true)])
        );
    }

    /// Equal children of an if-then-else are produced once per occurrence.
    #[test]
    fn duplicates() {
        let e = Expr::ite(Expr::val(true), Expr::val(true), Expr::val(true));
        let v: Vec<_> = e.subexpressions().collect();
        assert_eq!(v.len(), 4);
        assert!(v.contains(&&e));
        assert!(v.contains(&&Expr::val(true)));
    }

    /// Every level of a nested expression is reached, down to the literals.
    #[test]
    fn deeply_nested() {
        let e = Expr::get_attr(
            Expr::get_attr(Expr::and(Expr::val(1), Expr::val(0)), "attr2".into()),
            "attr1".into(),
        );
        let set: HashSet<_> = e.subexpressions().collect();
        assert!(set.contains(&e));
        assert!(set.contains(&Expr::get_attr(
            Expr::and(Expr::val(1), Expr::val(0)),
            "attr2".into()
        )));
        assert!(set.contains(&Expr::and(Expr::val(1), Expr::val(0))));
        assert!(set.contains(&Expr::val(1)));
        assert!(set.contains(&Expr::val(0)));
    }
}