cairo_lang_eq_solver/
expr.rs

1use std::fmt::Debug;
2use std::hash::Hash;
3
4use cairo_lang_utils::collection_arithmetics::{HasZero, add_maps, sub_maps};
5use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
6
7#[cfg(test)]
8#[path = "expr_test.rs"]
9mod test;
10
11/// An linear expression of variables.
12#[derive(Clone, Debug, PartialEq, Eq)]
13pub struct Expr<Var: Clone + Debug + PartialEq + Eq + Hash> {
14    /// The constant term of the expression.
15    pub const_term: i32,
16    /// The coefficient for every variable in the expression.
17    pub var_to_coef: OrderedHashMap<Var, i64>,
18}
19impl<Var: Clone + Debug + PartialEq + Eq + Hash> Expr<Var> {
20    /// Creates a cost expression based on const value only.
21    pub fn from_const(const_term: i32) -> Self {
22        Self { const_term, var_to_coef: Default::default() }
23    }
24
25    /// Creates a cost expression based on variable only.
26    pub fn from_var(var: Var) -> Self {
27        Self { const_term: 0, var_to_coef: [(var, 1)].into_iter().collect() }
28    }
29}
30
31impl<Var: Clone + Debug + PartialEq + Eq + Hash> HasZero for Expr<Var> {
32    fn zero() -> Self {
33        Self::from_const(0)
34    }
35}
36
37// Expr operators can be optimized if necessary.
38impl<Var: Clone + Debug + PartialEq + Eq + Hash> std::ops::Add for Expr<Var> {
39    type Output = Self;
40    fn add(self, other: Self) -> Self {
41        Self {
42            const_term: self.const_term + other.const_term,
43            var_to_coef: add_maps(self.var_to_coef, other.var_to_coef),
44        }
45    }
46}
47
48impl<Var: Clone + Debug + PartialEq + Eq + Hash> std::ops::Sub for Expr<Var> {
49    type Output = Self;
50    fn sub(self, other: Self) -> Self {
51        Self {
52            const_term: self.const_term - other.const_term,
53            var_to_coef: sub_maps(self.var_to_coef, other.var_to_coef),
54        }
55    }
56}