use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
/// Callback interface through which `Demand` reports what happens to variables.
///
/// Every method has an empty default body, so an implementor only overrides the
/// notifications it is interested in.
pub trait DemandReporter<Var> {
    /// Position type attached to `dup` and `last_use` notifications.
    type UsePosition: Copy;
    /// Position type attached to `drop` notifications.
    type IntroducePosition: Copy;

    /// Notified when `_var` is not demanded at the introduce position `_at`.
    fn drop(&mut self, _at: Self::IntroducePosition, _var: Var) {}

    /// Notified when `_var` is used at `_at` while it is already demanded.
    fn dup(&mut self, _at: Self::UsePosition, _var: Var) {}

    /// Notified when `_var` is used at `_at` and was not demanded before.
    ///
    /// `_index` is the variable's index in the list the caller passed in (the used-variable
    /// slice, or the remapping).
    fn last_use(&mut self, _at: Self::UsePosition, _index: usize, _var: Var) {}

    /// Notified when a remapping targets `_var` but `_var` is not demanded.
    fn unused_mapped_var(&mut self, _var: Var) {}
}
/// The set of variables that are currently demanded (still required).
///
/// The field is public, so callers may inspect or adjust the set directly in addition to
/// using the methods defined on `Demand`.
#[derive(Clone)]
pub struct Demand<Var>
where
    Var: std::hash::Hash + Eq + Copy,
{
    /// Variables currently demanded.
    pub vars: OrderedHashSet<Var>,
}
impl<Var: std::hash::Hash + Eq + Copy> Default for Demand<Var> {
fn default() -> Self {
Self { vars: Default::default() }
}
}
impl<Var> Demand<Var>
where
    Var: std::hash::Hash + Eq + Copy,
{
    /// Consumes the demand and returns `true` when no variable is left in it.
    pub fn finalize(self) -> bool {
        let Self { vars } = self;
        vars.is_empty()
    }

    /// Applies a variable remapping to the demand.
    ///
    /// Each `(dst, src)` pair is handled in iteration order:
    /// - if `dst` is demanded, it is removed and `src` is added in its place. The reporter
    ///   receives `last_use(position, index, src)` when `src` was not demanded yet, or
    ///   `dup(position, src)` when it already was. `index` is the pair's position in
    ///   `remapping`.
    /// - otherwise `unused_mapped_var(dst)` is reported and the demand is left unchanged.
    ///
    /// NOTE(review): pairs are applied one at a time, not as a parallel assignment. A `dst`
    /// equal to an earlier pair's `src` therefore removes that `src` again. Presumably callers
    /// guarantee that remapping targets never coincide with sources — confirm.
    pub fn apply_remapping<V: Into<Var>, T: DemandReporter<Var>>(
        &mut self,
        reporter: &mut T,
        remapping: impl Iterator<Item = (V, V)>,
        position: T::UsePosition,
    ) {
        for (pair_index, (target, source)) in remapping.enumerate() {
            let source: Var = source.into();
            let target: Var = target.into();
            if !self.vars.swap_remove(&target) {
                reporter.unused_mapped_var(target);
                continue;
            }
            if !self.vars.insert(source) {
                reporter.dup(position, source);
            } else {
                reporter.last_use(position, pair_index, source);
            }
        }
    }

    /// Adds `vars`, all used at `position`, to the demand.
    ///
    /// The slice is walked from its last element to its first:
    /// - a variable that was not demanded yet is reported via `last_use` with its index in
    ///   `vars`;
    /// - a variable that was already demanded, including by a later entry of the same slice,
    ///   is reported via `dup`.
    pub fn variables_used<V: Copy + Into<Var>, T: DemandReporter<Var>>(
        &mut self,
        reporter: &mut T,
        vars: &[V],
        position: T::UsePosition,
    ) {
        for (var_index, used) in vars.iter().copied().enumerate().rev() {
            let used: Var = used.into();
            if self.vars.insert(used) {
                reporter.last_use(position, var_index, used);
            } else {
                reporter.dup(position, used);
            }
        }
    }

    /// Removes `vars`, introduced at `position`, from the demand.
    ///
    /// Each variable that was not demanded is reported via `drop`, in slice order.
    pub fn variables_introduced<V: Copy + Into<Var>, T: DemandReporter<Var>>(
        &mut self,
        reporter: &mut T,
        vars: &[V],
        position: T::IntroducePosition,
    ) {
        for &introduced in vars {
            let introduced: Var = introduced.into();
            let was_demanded = self.vars.swap_remove(&introduced);
            if !was_demanded {
                reporter.drop(position, introduced);
            }
        }
    }

    /// Merges the demands of several arms into their union.
    ///
    /// The union is built by inserting each arm's variables, arm by arm. Then, for every
    /// variable of the union and for every arm (in slice order) whose demand lacks it, `drop`
    /// is reported at that arm's position.
    pub fn merge_demands<T: DemandReporter<Var>>(
        demands: &[(Self, T::IntroducePosition)],
        reporter: &mut T,
    ) -> Self {
        let mut merged = Self::default();
        merged.vars.extend(demands.iter().flat_map(|(arm, _)| arm.vars.iter().copied()));
        for var in merged.vars.iter() {
            demands
                .iter()
                .filter(|(arm, _)| !arm.vars.contains(var))
                .for_each(|(_, position)| reporter.drop(*position, *var));
        }
        merged
    }
}