cairo_lang_eq_solver/
lib.rs1pub mod expr;
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use cairo_lang_utils::casts::IntoOrPanic;
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9pub use expr::Expr;
10use good_lp::{Expression, Solution, SolverModel, default_solver, variable, variables};
11
12pub fn try_solve_equations<Var: Clone + Debug + PartialEq + Eq + Hash>(
21 mut equations: Vec<Expr<Var>>,
22 minimization_vars: Vec<Vec<Var>>,
23) -> Option<OrderedHashMap<Var, i64>> {
24 let mut accumulated_solution: OrderedHashMap<Var, i64> = Default::default();
25 let (final_iter, high_rank_iters) = minimization_vars.split_last()?;
26 for target_vars in high_rank_iters {
28 let layer_solution = try_solve_equations_iteration(&equations, target_vars)?;
29 let target_vars_solution = target_vars
30 .iter()
31 .map(|v| (v.clone(), *layer_solution.get(v).unwrap()))
32 .collect::<OrderedHashMap<_, _>>();
33 equations = equations
34 .into_iter()
35 .filter_map(|eq| {
36 let const_term = eq
37 .var_to_coef
38 .iter()
39 .filter_map(|(var, coef)| Some(target_vars_solution.get(var)? * coef))
40 .sum::<i64>()
41 + eq.const_term as i64;
42 let var_to_coef: OrderedHashMap<_, _> = eq
43 .var_to_coef
44 .into_iter()
45 .filter(|(var, _coef)| !target_vars_solution.contains_key(var))
46 .collect();
47 if var_to_coef.is_empty() {
48 assert_eq!(const_term, 0, "Zeroed out equations should be zeroed out.");
49 return None;
50 }
51 Some(Expr { var_to_coef, const_term: const_term.into_or_panic::<i32>() })
52 })
53 .collect();
54 accumulated_solution.extend(target_vars_solution);
55 }
56 let final_layer_solution = try_solve_equations_iteration(&equations, final_iter)?;
57 accumulated_solution.extend(final_layer_solution);
58 Some(accumulated_solution)
59}
60
61fn try_solve_equations_iteration<Var: Clone + Debug + PartialEq + Eq + Hash>(
69 equations: &[Expr<Var>],
70 target_vars: &[Var],
71) -> Option<OrderedHashMap<Var, i64>> {
72 let mut vars = variables!();
73 let mut orig_to_solver_var = OrderedHashMap::<_, _>::default();
74 for eq in equations {
76 for var in eq.var_to_coef.keys() {
77 orig_to_solver_var
78 .entry(var.clone())
79 .or_insert_with_key(|var| vars.add(variable().min(0).name(format!("{var:?}"))));
80 }
81 }
82 let target: Expression = target_vars.iter().map(|v| *orig_to_solver_var.get(v).unwrap()).sum();
83
84 let mut problem = vars.minimise(target).using(default_solver);
85 for eq in equations.iter() {
87 let as_solver_expr = |expr: &Expr<Var>| {
88 Expression::from_other_affine(expr.const_term)
89 + expr
90 .var_to_coef
91 .iter()
92 .map(|(var, coef)| (*coef as i32) * *orig_to_solver_var.get(var).unwrap())
93 .sum::<Expression>()
94 };
95 problem = problem.with(as_solver_expr(eq).eq(Expression::from_other_affine(0)));
96 }
97 let solution = problem.solve().ok()?;
98 Some(
99 orig_to_solver_var
100 .into_iter()
101 .map(|(orig, solver)| (orig, solution.value(solver).round() as i64))
102 .collect(),
103 )
104}