cairo_lang_eq_solver/
lib.rs

1//! Equation solving for Sierra generation.
2pub mod expr;
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use cairo_lang_utils::casts::IntoOrPanic;
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9pub use expr::Expr;
10use good_lp::{Expression, Solution, SolverModel, default_solver, variable, variables};
11
12/// Solving a set of equations and returning the values of the symbols contained in them.
13/// # Arguments
14/// * `equations` - The equations to solve.
15/// * `minimization_vars` - Vars to minimize - in ranked ordering - first minimize the sum of the
16///   first set - then the sum of the second and so on.
17/// # Returns
18/// * `Some(OrderedHashMap<Var, i64>)` - The solutions to the equations.
19/// * `None` - The equations are unsolvable.
20pub fn try_solve_equations<Var: Clone + Debug + PartialEq + Eq + Hash>(
21    mut equations: Vec<Expr<Var>>,
22    minimization_vars: Vec<Vec<Var>>,
23) -> Option<OrderedHashMap<Var, i64>> {
24    let mut accumulated_solution: OrderedHashMap<Var, i64> = Default::default();
25    let (final_iter, high_rank_iters) = minimization_vars.split_last()?;
26    // Iterating over the non-last minimization var layers.
27    for target_vars in high_rank_iters {
28        let layer_solution = try_solve_equations_iteration(&equations, target_vars)?;
29        let target_vars_solution = target_vars
30            .iter()
31            .map(|v| (v.clone(), *layer_solution.get(v).unwrap()))
32            .collect::<OrderedHashMap<_, _>>();
33        equations = equations
34            .into_iter()
35            .filter_map(|eq| {
36                let const_term = eq
37                    .var_to_coef
38                    .iter()
39                    .filter_map(|(var, coef)| Some(target_vars_solution.get(var)? * coef))
40                    .sum::<i64>()
41                    + eq.const_term as i64;
42                let var_to_coef: OrderedHashMap<_, _> = eq
43                    .var_to_coef
44                    .into_iter()
45                    .filter(|(var, _coef)| !target_vars_solution.contains_key(var))
46                    .collect();
47                if var_to_coef.is_empty() {
48                    assert_eq!(const_term, 0, "Zeroed out equations should be zeroed out.");
49                    return None;
50                }
51                Some(Expr { var_to_coef, const_term: const_term.into_or_panic::<i32>() })
52            })
53            .collect();
54        accumulated_solution.extend(target_vars_solution);
55    }
56    let final_layer_solution = try_solve_equations_iteration(&equations, final_iter)?;
57    accumulated_solution.extend(final_layer_solution);
58    Some(accumulated_solution)
59}
60
61/// Solving a set of equations and returning the values of the symbols contained in them.
62/// # Arguments
63/// * `equations` - The equations to solve.
64/// * `target_vars` - Minimize the sum of those variables.
65/// # Returns
66/// * `Some(OrderedHashMap<Var, i64>)` - The solutions to the equations.
67/// * `None` - The equations are unsolvable.
68fn try_solve_equations_iteration<Var: Clone + Debug + PartialEq + Eq + Hash>(
69    equations: &[Expr<Var>],
70    target_vars: &[Var],
71) -> Option<OrderedHashMap<Var, i64>> {
72    let mut vars = variables!();
73    let mut orig_to_solver_var = OrderedHashMap::<_, _>::default();
74    // Add all variables to structure and map.
75    for eq in equations {
76        for var in eq.var_to_coef.keys() {
77            orig_to_solver_var
78                .entry(var.clone())
79                .or_insert_with_key(|var| vars.add(variable().min(0).name(format!("{var:?}"))));
80        }
81    }
82    let target: Expression = target_vars.iter().map(|v| *orig_to_solver_var.get(v).unwrap()).sum();
83
84    let mut problem = vars.minimise(target).using(default_solver);
85    // Adding constraints for all equations.
86    for eq in equations.iter() {
87        let as_solver_expr = |expr: &Expr<Var>| {
88            Expression::from_other_affine(expr.const_term)
89                + expr
90                    .var_to_coef
91                    .iter()
92                    .map(|(var, coef)| (*coef as i32) * *orig_to_solver_var.get(var).unwrap())
93                    .sum::<Expression>()
94        };
95        problem = problem.with(as_solver_expr(eq).eq(Expression::from_other_affine(0)));
96    }
97    let solution = problem.solve().ok()?;
98    Some(
99        orig_to_solver_var
100            .into_iter()
101            .map(|(orig, solver)| (orig, solution.value(solver).round() as i64))
102            .collect(),
103    )
104}