snarkvm_r1cs/
linear_combination.rs

1// Copyright (C) 2019-2023 Aleo Systems Inc.
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7// http://www.apache.org/licenses/LICENSE-2.0
8
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::Variable;
16use snarkvm_fields::Field;
17
18use std::{
19    cmp::Ordering,
20    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub},
21};
22
23/// This represents a linear combination of some variables, with coefficients
24/// in the field `F`.
25/// The `(coeff, var)` pairs in a `LinearCombination` are kept sorted according
26/// to the index of the variable in its constraint system.
27#[derive(Debug, Clone, PartialEq, Eq, Hash)]
28pub struct LinearCombination<F: Field>(pub Vec<(Variable, F)>);
29
30impl<F: Field> AsRef<[(Variable, F)]> for LinearCombination<F> {
31    #[inline]
32    fn as_ref(&self) -> &[(Variable, F)] {
33        &self.0
34    }
35}
36
37impl<F: Field> From<(F, Variable)> for LinearCombination<F> {
38    #[inline]
39    fn from((coeff, var): (F, Variable)) -> Self {
40        LinearCombination(vec![(var, coeff)])
41    }
42}
43
44impl<F: Field> From<Variable> for LinearCombination<F> {
45    #[inline]
46    fn from(var: Variable) -> Self {
47        LinearCombination(vec![(var, F::one())])
48    }
49}
50
51impl<F: Field> LinearCombination<F> {
52    /// Outputs an empty linear combination.
53    #[inline]
54    pub fn zero() -> LinearCombination<F> {
55        LinearCombination(Vec::new())
56    }
57
58    /// Replaces the contents of `self` with those of `other`.
59    #[inline]
60    pub fn replace_in_place(&mut self, other: Self) {
61        self.0.clear();
62        self.0.extend_from_slice(&other.0)
63    }
64
65    /// Negate the coefficients of all variables in `self`.
66    #[inline]
67    pub fn negate_in_place(&mut self) {
68        self.0.iter_mut().for_each(|(_, coeff)| *coeff = -(*coeff));
69    }
70
71    /// Double the coefficients of all variables in `self`.
72    #[inline]
73    pub fn double_in_place(&mut self) {
74        self.0.iter_mut().for_each(|(_, coeff)| {
75            coeff.double_in_place();
76        });
77    }
78
79    /// Get the location of a variable in `self`.
80    #[inline]
81    pub fn get_var_loc(&self, search_var: &Variable) -> Result<usize, usize> {
82        if self.0.len() < 6 {
83            let mut found_index = 0;
84            for (i, (var, _)) in self.0.iter().enumerate() {
85                if var >= search_var {
86                    found_index = i;
87                    break;
88                } else {
89                    found_index += 1;
90                }
91            }
92            if self.0.get(found_index).map(|x| &x.0 == search_var).unwrap_or_default() {
93                Ok(found_index)
94            } else {
95                Err(found_index)
96            }
97        } else {
98            self.0.binary_search_by_key(search_var, |&(cur_var, _)| cur_var)
99        }
100    }
101}
102
103impl<F: Field> Add<(F, Variable)> for LinearCombination<F> {
104    type Output = Self;
105
106    #[inline]
107    fn add(mut self, coeff_var: (F, Variable)) -> Self {
108        self += coeff_var;
109        self
110    }
111}
112
113impl<F: Field> AddAssign<(F, Variable)> for LinearCombination<F> {
114    #[inline]
115    fn add_assign(&mut self, (coeff, var): (F, Variable)) {
116        match self.get_var_loc(&var) {
117            Ok(found) => self.0[found].1 += &coeff,
118            Err(not_found) => self.0.insert(not_found, (var, coeff)),
119        }
120    }
121}
122
123impl<F: Field> Sub<(F, Variable)> for LinearCombination<F> {
124    type Output = Self;
125
126    #[inline]
127    fn sub(self, (coeff, var): (F, Variable)) -> Self {
128        self + (-coeff, var)
129    }
130}
131
132impl<F: Field> Neg for LinearCombination<F> {
133    type Output = Self;
134
135    #[inline]
136    fn neg(mut self) -> Self {
137        self.negate_in_place();
138        self
139    }
140}
141
142impl<F: Field> Mul<F> for LinearCombination<F> {
143    type Output = Self;
144
145    #[inline]
146    fn mul(mut self, scalar: F) -> Self {
147        self *= scalar;
148        self
149    }
150}
151
152impl<F: Field> MulAssign<F> for LinearCombination<F> {
153    #[inline]
154    fn mul_assign(&mut self, scalar: F) {
155        self.0.iter_mut().for_each(|(_, coeff)| *coeff *= &scalar);
156    }
157}
158
159impl<F: Field> Add<Variable> for LinearCombination<F> {
160    type Output = Self;
161
162    #[inline]
163    fn add(self, other: Variable) -> LinearCombination<F> {
164        self + (F::one(), other)
165    }
166}
167
168impl<F: Field> Sub<Variable> for LinearCombination<F> {
169    type Output = LinearCombination<F>;
170
171    #[inline]
172    fn sub(self, other: Variable) -> LinearCombination<F> {
173        self - (F::one(), other)
174    }
175}
176
177fn op_impl<F: Field, F1, F2>(
178    cur: &LinearCombination<F>,
179    other: &LinearCombination<F>,
180    push_fn: F1,
181    combine_fn: F2,
182) -> LinearCombination<F>
183where
184    F1: Fn(F) -> F,
185    F2: Fn(F, F) -> F,
186{
187    let mut new_vec = Vec::with_capacity(cur.0.len() + other.0.len());
188    let mut i = 0;
189    let mut j = 0;
190    while i < cur.0.len() && j < other.0.len() {
191        let self_cur = &cur.0[i];
192        let other_cur = &other.0[j];
193        match self_cur.0.cmp(&other_cur.0) {
194            Ordering::Greater => {
195                new_vec.push((other_cur.0, push_fn(other_cur.1)));
196                j += 1;
197            }
198            Ordering::Less => {
199                new_vec.push(*self_cur);
200                i += 1;
201            }
202            Ordering::Equal => {
203                new_vec.push((self_cur.0, combine_fn(self_cur.1, other_cur.1)));
204                i += 1;
205                j += 1;
206            }
207        }
208    }
209    new_vec.extend_from_slice(&cur.0[i..]);
210    while j < other.0.len() {
211        new_vec.push((other.0[j].0, push_fn(other.0[j].1)));
212        j += 1;
213    }
214    LinearCombination(new_vec)
215}
216
217impl<F: Field> Add<&LinearCombination<F>> for &LinearCombination<F> {
218    type Output = LinearCombination<F>;
219
220    fn add(self, other: &LinearCombination<F>) -> LinearCombination<F> {
221        if other.0.is_empty() {
222            return self.clone();
223        } else if self.0.is_empty() {
224            return other.clone();
225        }
226        op_impl(self, other, |coeff| coeff, |cur_coeff, other_coeff| cur_coeff + other_coeff)
227    }
228}
229
230impl<F: Field> Add<LinearCombination<F>> for &LinearCombination<F> {
231    type Output = LinearCombination<F>;
232
233    fn add(self, other: LinearCombination<F>) -> LinearCombination<F> {
234        if self.0.is_empty() {
235            return other;
236        } else if other.0.is_empty() {
237            return self.clone();
238        }
239        op_impl(self, &other, |coeff| coeff, |cur_coeff, other_coeff| cur_coeff + other_coeff)
240    }
241}
242
243impl<'a, F: Field> Add<&'a LinearCombination<F>> for LinearCombination<F> {
244    type Output = LinearCombination<F>;
245
246    fn add(self, other: &'a LinearCombination<F>) -> LinearCombination<F> {
247        if other.0.is_empty() {
248            return self;
249        } else if self.0.is_empty() {
250            return other.clone();
251        }
252        op_impl(&self, other, |coeff| coeff, |cur_coeff, other_coeff| cur_coeff + other_coeff)
253    }
254}
255
256impl<F: Field> Add<LinearCombination<F>> for LinearCombination<F> {
257    type Output = Self;
258
259    fn add(self, other: Self) -> Self {
260        if other.0.is_empty() {
261            return self;
262        } else if self.0.is_empty() {
263            return other;
264        }
265        op_impl(&self, &other, |coeff| coeff, |cur_coeff, other_coeff| cur_coeff + other_coeff)
266    }
267}
268
269impl<F: Field> Sub<&LinearCombination<F>> for &LinearCombination<F> {
270    type Output = LinearCombination<F>;
271
272    fn sub(self, other: &LinearCombination<F>) -> LinearCombination<F> {
273        if other.0.is_empty() {
274            let cur = self.clone();
275            return cur;
276        } else if self.0.is_empty() {
277            let mut other = other.clone();
278            other.negate_in_place();
279            return other;
280        }
281
282        op_impl(self, other, |coeff| -coeff, |cur_coeff, other_coeff| cur_coeff - other_coeff)
283    }
284}
285
286impl<'a, F: Field> Sub<&'a LinearCombination<F>> for LinearCombination<F> {
287    type Output = LinearCombination<F>;
288
289    fn sub(self, other: &'a LinearCombination<F>) -> LinearCombination<F> {
290        if other.0.is_empty() {
291            return self;
292        } else if self.0.is_empty() {
293            let mut other = other.clone();
294            other.negate_in_place();
295            return other;
296        }
297        op_impl(&self, other, |coeff| -coeff, |cur_coeff, other_coeff| cur_coeff - other_coeff)
298    }
299}
300
301impl<F: Field> Sub<LinearCombination<F>> for &LinearCombination<F> {
302    type Output = LinearCombination<F>;
303
304    fn sub(self, mut other: LinearCombination<F>) -> LinearCombination<F> {
305        if self.0.is_empty() {
306            other.negate_in_place();
307            return other;
308        } else if other.0.is_empty() {
309            return self.clone();
310        }
311
312        op_impl(self, &other, |coeff| -coeff, |cur_coeff, other_coeff| cur_coeff - other_coeff)
313    }
314}
315
316impl<F: Field> Sub<LinearCombination<F>> for LinearCombination<F> {
317    type Output = LinearCombination<F>;
318
319    fn sub(self, mut other: LinearCombination<F>) -> LinearCombination<F> {
320        if other.0.is_empty() {
321            return self;
322        } else if self.0.is_empty() {
323            other.negate_in_place();
324            return other;
325        }
326        op_impl(&self, &other, |coeff| -coeff, |cur_coeff, other_coeff| cur_coeff - other_coeff)
327    }
328}
329
330impl<F: Field> Add<(F, &LinearCombination<F>)> for &LinearCombination<F> {
331    type Output = LinearCombination<F>;
332
333    #[allow(clippy::suspicious_arithmetic_impl)]
334    fn add(self, (mul_coeff, other): (F, &LinearCombination<F>)) -> LinearCombination<F> {
335        if other.0.is_empty() {
336            return self.clone();
337        } else if self.0.is_empty() {
338            let mut other = other.clone();
339            other.mul_assign(mul_coeff);
340            return other;
341        }
342        op_impl(self, other, |coeff| mul_coeff * coeff, |cur_coeff, other_coeff| cur_coeff + (mul_coeff * other_coeff))
343    }
344}
345
346impl<'a, F: Field> Add<(F, &'a LinearCombination<F>)> for LinearCombination<F> {
347    type Output = LinearCombination<F>;
348
349    #[allow(clippy::suspicious_arithmetic_impl)]
350    fn add(self, (mul_coeff, other): (F, &'a LinearCombination<F>)) -> LinearCombination<F> {
351        if other.0.is_empty() {
352            return self;
353        } else if self.0.is_empty() {
354            let mut other = other.clone();
355            other.mul_assign(mul_coeff);
356            return other;
357        }
358        op_impl(&self, other, |coeff| mul_coeff * coeff, |cur_coeff, other_coeff| cur_coeff + (mul_coeff * other_coeff))
359    }
360}
361
362impl<F: Field> Add<(F, LinearCombination<F>)> for &LinearCombination<F> {
363    type Output = LinearCombination<F>;
364
365    #[allow(clippy::suspicious_arithmetic_impl)]
366    fn add(self, (mul_coeff, mut other): (F, LinearCombination<F>)) -> LinearCombination<F> {
367        if other.0.is_empty() {
368            return self.clone();
369        } else if self.0.is_empty() {
370            other.mul_assign(mul_coeff);
371            return other;
372        }
373        op_impl(self, &other, |coeff| mul_coeff * coeff, |cur_coeff, other_coeff| cur_coeff + (mul_coeff * other_coeff))
374    }
375}
376
377impl<F: Field> Add<(F, Self)> for LinearCombination<F> {
378    type Output = Self;
379
380    #[allow(clippy::suspicious_arithmetic_impl)]
381    fn add(self, (mul_coeff, other): (F, Self)) -> Self {
382        if other.0.is_empty() {
383            return self;
384        } else if self.0.is_empty() {
385            let mut other = other;
386            other.mul_assign(mul_coeff);
387            return other;
388        }
389        op_impl(
390            &self,
391            &other,
392            |coeff| mul_coeff * coeff,
393            |cur_coeff, other_coeff| cur_coeff + (mul_coeff * other_coeff),
394        )
395    }
396}
397
398impl<F: Field> Sub<(F, &LinearCombination<F>)> for &LinearCombination<F> {
399    type Output = LinearCombination<F>;
400
401    fn sub(self, (coeff, other): (F, &LinearCombination<F>)) -> LinearCombination<F> {
402        self + (-coeff, other)
403    }
404}
405
406impl<'a, F: Field> Sub<(F, &'a LinearCombination<F>)> for LinearCombination<F> {
407    type Output = LinearCombination<F>;
408
409    fn sub(self, (coeff, other): (F, &'a LinearCombination<F>)) -> LinearCombination<F> {
410        self + (-coeff, other)
411    }
412}
413
414impl<F: Field> Sub<(F, LinearCombination<F>)> for &LinearCombination<F> {
415    type Output = LinearCombination<F>;
416
417    fn sub(self, (coeff, other): (F, LinearCombination<F>)) -> LinearCombination<F> {
418        self + (-coeff, other)
419    }
420}
421
422impl<F: Field> Sub<(F, LinearCombination<F>)> for LinearCombination<F> {
423    type Output = LinearCombination<F>;
424
425    fn sub(self, (coeff, other): (F, LinearCombination<F>)) -> LinearCombination<F> {
426        self + (-coeff, other)
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Index;
    use snarkvm_curves::bls12_377::Fr;

    /// Repeatedly adding a term for the same variable must accumulate into a
    /// single entry rather than appending a new one each time.
    #[test]
    fn linear_combination_append() {
        let mut combo = LinearCombination::<Fr>::zero();
        (0..100u64).for_each(|i| {
            combo += (i.into(), Variable::new_unchecked(Index::Public(0)));
        });
        assert_eq!(combo.0.len(), 1);
    }
}
445}