snarkvm_r1cs/
constraint_variable.rs

1// Copyright (C) 2019-2023 Aleo Systems Inc.
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7// http://www.apache.org/licenses/LICENSE-2.0
8
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::{LinearCombination, Variable};
16use snarkvm_fields::Field;
17
18use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
19
20/// Either a `Variable` or a `LinearCombination`.
21#[derive(Clone, Debug)]
22pub enum ConstraintVariable<F: Field> {
23    /// A wrapper around a `LinearCombination`.
24    LC(LinearCombination<F>),
25    /// A wrapper around a `Variable`.
26    Var(Variable),
27}
28
29impl<F: Field> From<Variable> for ConstraintVariable<F> {
30    #[inline]
31    fn from(var: Variable) -> Self {
32        ConstraintVariable::Var(var)
33    }
34}
35
36impl<F: Field> From<(F, Variable)> for ConstraintVariable<F> {
37    #[inline]
38    fn from(coeff_var: (F, Variable)) -> Self {
39        ConstraintVariable::LC(coeff_var.into())
40    }
41}
42
43impl<F: Field> From<LinearCombination<F>> for ConstraintVariable<F> {
44    #[inline]
45    fn from(lc: LinearCombination<F>) -> Self {
46        ConstraintVariable::LC(lc)
47    }
48}
49
50impl<F: Field> From<(F, LinearCombination<F>)> for ConstraintVariable<F> {
51    #[inline]
52    fn from((coeff, mut lc): (F, LinearCombination<F>)) -> Self {
53        lc *= coeff;
54        ConstraintVariable::LC(lc)
55    }
56}
57
58impl<F: Field> From<(F, ConstraintVariable<F>)> for ConstraintVariable<F> {
59    #[inline]
60    fn from((coeff, var): (F, ConstraintVariable<F>)) -> Self {
61        match var {
62            ConstraintVariable::LC(lc) => (coeff, lc).into(),
63            ConstraintVariable::Var(var) => (coeff, var).into(),
64        }
65    }
66}
67
68impl<F: Field> ConstraintVariable<F> {
69    /// Returns an empty linear combination.
70    #[inline]
71    pub fn zero() -> Self {
72        ConstraintVariable::LC(LinearCombination::zero())
73    }
74
75    /// Negate the coefficients of all variables in `self`.
76    pub fn negate_in_place(&mut self) {
77        match self {
78            ConstraintVariable::LC(ref mut lc) => lc.negate_in_place(),
79            ConstraintVariable::Var(var) => *self = (-F::one(), *var).into(),
80        }
81    }
82
83    /// Double the coefficients of all variables in `self`.
84    pub fn double_in_place(&mut self) {
85        match self {
86            ConstraintVariable::LC(lc) => lc.double_in_place(),
87            ConstraintVariable::Var(var) => *self = (F::one().double(), *var).into(),
88        }
89    }
90}
91
92impl<F: Field> Add<LinearCombination<F>> for ConstraintVariable<F> {
93    type Output = LinearCombination<F>;
94
95    #[inline]
96    fn add(self, other_lc: LinearCombination<F>) -> LinearCombination<F> {
97        match self {
98            ConstraintVariable::LC(lc) => other_lc + lc,
99            ConstraintVariable::Var(var) => other_lc + var,
100        }
101    }
102}
103
104impl<F: Field> Sub<LinearCombination<F>> for ConstraintVariable<F> {
105    type Output = LinearCombination<F>;
106
107    #[inline]
108    fn sub(self, other_lc: LinearCombination<F>) -> LinearCombination<F> {
109        let result = match self {
110            ConstraintVariable::LC(lc) => other_lc - lc,
111            ConstraintVariable::Var(var) => other_lc - var,
112        };
113        -result
114    }
115}
116
117impl<F: Field> Add<LinearCombination<F>> for &ConstraintVariable<F> {
118    type Output = LinearCombination<F>;
119
120    #[inline]
121    fn add(self, other_lc: LinearCombination<F>) -> LinearCombination<F> {
122        match self {
123            ConstraintVariable::LC(lc) => other_lc + lc,
124            ConstraintVariable::Var(var) => other_lc + *var,
125        }
126    }
127}
128
129impl<F: Field> Sub<LinearCombination<F>> for &ConstraintVariable<F> {
130    type Output = LinearCombination<F>;
131
132    #[inline]
133    fn sub(self, other_lc: LinearCombination<F>) -> LinearCombination<F> {
134        let result = match self {
135            ConstraintVariable::LC(lc) => other_lc - lc,
136            ConstraintVariable::Var(var) => other_lc - *var,
137        };
138        -result
139    }
140}
141
142impl<F: Field> Add<(F, Variable)> for ConstraintVariable<F> {
143    type Output = Self;
144
145    #[inline]
146    fn add(self, var: (F, Variable)) -> Self {
147        let lc = match self {
148            ConstraintVariable::LC(lc) => lc + var,
149            ConstraintVariable::Var(var2) => LinearCombination::from(var2) + var,
150        };
151        ConstraintVariable::LC(lc)
152    }
153}
154
155impl<F: Field> AddAssign<(F, Variable)> for ConstraintVariable<F> {
156    #[inline]
157    fn add_assign(&mut self, var: (F, Variable)) {
158        match self {
159            ConstraintVariable::LC(ref mut lc) => *lc += var,
160            ConstraintVariable::Var(var2) => *self = ConstraintVariable::LC(LinearCombination::from(*var2) + var),
161        };
162    }
163}
164
165impl<F: Field> Neg for ConstraintVariable<F> {
166    type Output = Self;
167
168    #[inline]
169    fn neg(mut self) -> Self {
170        self.negate_in_place();
171        self
172    }
173}
174
175impl<F: Field> Mul<F> for ConstraintVariable<F> {
176    type Output = Self;
177
178    #[inline]
179    fn mul(self, scalar: F) -> Self {
180        match self {
181            ConstraintVariable::LC(lc) => ConstraintVariable::LC(lc * scalar),
182            ConstraintVariable::Var(var) => (scalar, var).into(),
183        }
184    }
185}
186
187impl<F: Field> MulAssign<F> for ConstraintVariable<F> {
188    #[inline]
189    fn mul_assign(&mut self, scalar: F) {
190        match self {
191            ConstraintVariable::LC(lc) => *lc *= scalar,
192            ConstraintVariable::Var(var) => *self = (scalar, *var).into(),
193        }
194    }
195}
196
197impl<F: Field> Sub<(F, Variable)> for ConstraintVariable<F> {
198    type Output = Self;
199
200    #[inline]
201    fn sub(self, (coeff, var): (F, Variable)) -> Self {
202        self + (-coeff, var)
203    }
204}
205
206impl<F: Field> Add<Variable> for ConstraintVariable<F> {
207    type Output = Self;
208
209    fn add(self, other: Variable) -> Self {
210        self + (F::one(), other)
211    }
212}
213
214impl<F: Field> Sub<Variable> for ConstraintVariable<F> {
215    type Output = Self;
216
217    #[inline]
218    fn sub(self, other: Variable) -> Self {
219        self - (F::one(), other)
220    }
221}
222
223impl<'a, F: Field> Add<&'a Self> for ConstraintVariable<F> {
224    type Output = Self;
225
226    #[inline]
227    fn add(self, other: &'a Self) -> Self {
228        let lc = match self {
229            ConstraintVariable::LC(lc2) => lc2,
230            ConstraintVariable::Var(var) => var.into(),
231        };
232        let lc2 = match other {
233            ConstraintVariable::LC(lc2) => lc + lc2,
234            ConstraintVariable::Var(var) => lc + *var,
235        };
236        ConstraintVariable::LC(lc2)
237    }
238}
239
240impl<'a, F: Field> Sub<&'a Self> for ConstraintVariable<F> {
241    type Output = Self;
242
243    #[inline]
244    fn sub(self, other: &'a Self) -> Self {
245        let lc = match self {
246            ConstraintVariable::LC(lc2) => lc2,
247            ConstraintVariable::Var(var) => var.into(),
248        };
249        let lc2 = match other {
250            ConstraintVariable::LC(lc2) => lc - lc2,
251            ConstraintVariable::Var(var) => lc - *var,
252        };
253        ConstraintVariable::LC(lc2)
254    }
255}
256
257impl<F: Field> Add<&ConstraintVariable<F>> for &ConstraintVariable<F> {
258    type Output = ConstraintVariable<F>;
259
260    #[inline]
261    fn add(self, other: &ConstraintVariable<F>) -> Self::Output {
262        (ConstraintVariable::zero() + self) + other
263    }
264}
265
266impl<F: Field> Sub<&ConstraintVariable<F>> for &ConstraintVariable<F> {
267    type Output = ConstraintVariable<F>;
268
269    #[inline]
270    fn sub(self, other: &ConstraintVariable<F>) -> Self::Output {
271        (ConstraintVariable::zero() + self) - other
272    }
273}
274
275impl<'a, F: Field> Add<(F, &'a Self)> for ConstraintVariable<F> {
276    type Output = Self;
277
278    #[inline]
279    fn add(self, (coeff, other): (F, &'a Self)) -> Self {
280        let mut lc = match self {
281            ConstraintVariable::LC(lc2) => lc2,
282            ConstraintVariable::Var(var) => LinearCombination::zero() + var,
283        };
284
285        lc = match other {
286            ConstraintVariable::LC(lc2) => lc + (coeff, lc2),
287            ConstraintVariable::Var(var) => lc + (coeff, *var),
288        };
289        ConstraintVariable::LC(lc)
290    }
291}
292
293impl<'a, F: Field> Sub<(F, &'a Self)> for ConstraintVariable<F> {
294    type Output = Self;
295
296    #[inline]
297    #[allow(clippy::suspicious_arithmetic_impl)]
298    fn sub(self, (coeff, other): (F, &'a Self)) -> Self {
299        let mut lc = match self {
300            ConstraintVariable::LC(lc2) => lc2,
301            ConstraintVariable::Var(var) => LinearCombination::zero() + var,
302        };
303        lc = match other {
304            ConstraintVariable::LC(lc2) => lc - (coeff, lc2),
305            ConstraintVariable::Var(var) => lc - (coeff, *var),
306        };
307        ConstraintVariable::LC(lc)
308    }
309}