// snarkvm_algorithms/r1cs/linear_combination.rs

// Copyright 2024-2025 Aleo Network Foundation
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::r1cs::Variable;
17use snarkvm_fields::Field;
18
19use std::{
20    cmp::Ordering,
21    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub},
22};
23
24/// This represents a linear combination of some variables, with coefficients
25/// in the field `F`.
26/// The `(coeff, var)` pairs in a `LinearCombination` are kept sorted according
27/// to the index of the variable in its constraint system.
28#[derive(Debug, Clone, PartialEq, Eq, Hash)]
29pub struct LinearCombination<F: Field>(pub Vec<(Variable, F)>);
30
31impl<F: Field> AsRef<[(Variable, F)]> for LinearCombination<F> {
32    #[inline]
33    fn as_ref(&self) -> &[(Variable, F)] {
34        &self.0
35    }
36}
37
38impl<F: Field> From<(F, Variable)> for LinearCombination<F> {
39    #[inline]
40    fn from((coeff, var): (F, Variable)) -> Self {
41        LinearCombination(vec![(var, coeff)])
42    }
43}
44
45impl<F: Field> From<Variable> for LinearCombination<F> {
46    #[inline]
47    fn from(var: Variable) -> Self {
48        LinearCombination(vec![(var, F::one())])
49    }
50}
51
52impl<F: Field> LinearCombination<F> {
53    /// Outputs an empty linear combination.
54    #[inline]
55    pub fn zero() -> LinearCombination<F> {
56        LinearCombination(Vec::new())
57    }
58
59    /// Replaces the contents of `self` with those of `other`.
60    #[inline]
61    pub fn replace_in_place(&mut self, other: Self) {
62        self.0.clear();
63        self.0.extend_from_slice(&other.0)
64    }
65
66    /// Negate the coefficients of all variables in `self`.
67    #[inline]
68    pub fn negate_in_place(&mut self) {
69        self.0.iter_mut().for_each(|(_, coeff)| *coeff = -(*coeff));
70    }
71
72    /// Double the coefficients of all variables in `self`.
73    #[inline]
74    pub fn double_in_place(&mut self) {
75        self.0.iter_mut().for_each(|(_, coeff)| {
76            coeff.double_in_place();
77        });
78    }
79
80    /// Get the location of a variable in `self`.
81    #[inline]
82    pub fn get_var_loc(&self, search_var: &Variable) -> Result<usize, usize> {
83        if self.0.len() < 6 {
84            let mut found_index = 0;
85            for (i, (var, _)) in self.0.iter().enumerate() {
86                if var >= search_var {
87                    found_index = i;
88                    break;
89                } else {
90                    found_index += 1;
91                }
92            }
93            if self.0.get(found_index).map(|x| &x.0 == search_var).unwrap_or_default() {
94                Ok(found_index)
95            } else {
96                Err(found_index)
97            }
98        } else {
99            self.0.binary_search_by_key(search_var, |&(cur_var, _)| cur_var)
100        }
101    }
102}
103
104impl<F: Field> Add<(F, Variable)> for LinearCombination<F> {
105    type Output = Self;
106
107    #[inline]
108    fn add(mut self, coeff_var: (F, Variable)) -> Self {
109        self += coeff_var;
110        self
111    }
112}
113
114impl<F: Field> AddAssign<(F, Variable)> for LinearCombination<F> {
115    #[inline]
116    fn add_assign(&mut self, (coeff, var): (F, Variable)) {
117        match self.get_var_loc(&var) {
118            Ok(found) => {
119                self.0[found].1 += &coeff;
120                if self.0[found].1.is_zero() {
121                    self.0.remove(found);
122                }
123            }
124            Err(not_found) => self.0.insert(not_found, (var, coeff)),
125        }
126    }
127}
128
129impl<F: Field> Sub<(F, Variable)> for LinearCombination<F> {
130    type Output = Self;
131
132    #[inline]
133    fn sub(self, (coeff, var): (F, Variable)) -> Self {
134        self + (-coeff, var)
135    }
136}
137
138impl<F: Field> Neg for LinearCombination<F> {
139    type Output = Self;
140
141    #[inline]
142    fn neg(mut self) -> Self {
143        self.negate_in_place();
144        self
145    }
146}
147
148impl<F: Field> Mul<F> for LinearCombination<F> {
149    type Output = Self;
150
151    #[inline]
152    fn mul(mut self, scalar: F) -> Self {
153        self *= scalar;
154        self
155    }
156}
157
158impl<F: Field> MulAssign<F> for LinearCombination<F> {
159    #[inline]
160    fn mul_assign(&mut self, scalar: F) {
161        self.0.iter_mut().for_each(|(_, coeff)| *coeff *= &scalar);
162    }
163}
164
165impl<F: Field> Add<Variable> for LinearCombination<F> {
166    type Output = Self;
167
168    #[inline]
169    fn add(self, other: Variable) -> LinearCombination<F> {
170        self + (F::one(), other)
171    }
172}
173
174impl<F: Field> Sub<Variable> for LinearCombination<F> {
175    type Output = LinearCombination<F>;
176
177    #[inline]
178    fn sub(self, other: Variable) -> LinearCombination<F> {
179        self - (F::one(), other)
180    }
181}
182
183fn op_impl<F: Field, F1, F2>(
184    cur: &LinearCombination<F>,
185    other: &LinearCombination<F>,
186    push_fn: F1,
187    combine_fn: F2,
188) -> LinearCombination<F>
189where
190    F1: Fn(F) -> F,
191    F2: Fn(F, F) -> F,
192{
193    let mut new_vec = Vec::with_capacity(cur.0.len() + other.0.len());
194    let mut i = 0;
195    let mut j = 0;
196    while i < cur.0.len() && j < other.0.len() {
197        let self_cur = &cur.0[i];
198        let other_cur = &other.0[j];
199        match self_cur.0.cmp(&other_cur.0) {
200            Ordering::Greater => {
201                new_vec.push((other_cur.0, push_fn(other_cur.1)));
202                j += 1;
203            }
204            Ordering::Less => {
205                new_vec.push(*self_cur);
206                i += 1;
207            }
208            Ordering::Equal => {
209                new_vec.push((self_cur.0, combine_fn(self_cur.1, other_cur.1)));
210                i += 1;
211                j += 1;
212            }
213        }
214    }
215    new_vec.extend_from_slice(&cur.0[i..]);
216    while j < other.0.len() {
217        new_vec.push((other.0[j].0, push_fn(other.0[j].1)));
218        j += 1;
219    }
220    LinearCombination(new_vec)
221}
222
223impl<F: Field> Add<&LinearCombination<F>> for &LinearCombination<F> {
224    type Output = LinearCombination<F>;
225
226    fn add(self, other: &LinearCombination<F>) -> LinearCombination<F> {
227        if other.0.is_empty() {
228            return self.clone();
229        } else if self.0.is_empty() {
230            return other.clone();
231        }
232        op_impl(self, other, |coeff| coeff, |cur_coeff, other_coeff| cur_coeff + other_coeff)
233    }
234}
235
236impl<F: Field> Add<LinearCombination<F>> for &LinearCombination<F> {
237    type Output = LinearCombination<F>;
238
239    fn add(self, other: LinearCombination<F>) -> LinearCombination<F> {
240        if self.0.is_empty() {
241            return other;
242        } else if other.0.is_empty() {
243            return self.clone();
244        }
245        op_impl(self, &other, |coeff| coeff, |cur_coeff, other_coeff| cur_coeff + other_coeff)
246    }
247}
248
249impl<'a, F: Field> Add<&'a LinearCombination<F>> for LinearCombination<F> {
250    type Output = LinearCombination<F>;
251
252    fn add(self, other: &'a LinearCombination<F>) -> LinearCombination<F> {
253        if other.0.is_empty() {
254            return self;
255        } else if self.0.is_empty() {
256            return other.clone();
257        }
258        op_impl(&self, other, |coeff| coeff, |cur_coeff, other_coeff| cur_coeff + other_coeff)
259    }
260}
261
262impl<F: Field> Add<LinearCombination<F>> for LinearCombination<F> {
263    type Output = Self;
264
265    fn add(self, other: Self) -> Self {
266        if other.0.is_empty() {
267            return self;
268        } else if self.0.is_empty() {
269            return other;
270        }
271        op_impl(&self, &other, |coeff| coeff, |cur_coeff, other_coeff| cur_coeff + other_coeff)
272    }
273}
274
275impl<F: Field> Sub<&LinearCombination<F>> for &LinearCombination<F> {
276    type Output = LinearCombination<F>;
277
278    fn sub(self, other: &LinearCombination<F>) -> LinearCombination<F> {
279        if other.0.is_empty() {
280            let cur = self.clone();
281            return cur;
282        } else if self.0.is_empty() {
283            let mut other = other.clone();
284            other.negate_in_place();
285            return other;
286        }
287
288        op_impl(self, other, |coeff| -coeff, |cur_coeff, other_coeff| cur_coeff - other_coeff)
289    }
290}
291
292impl<'a, F: Field> Sub<&'a LinearCombination<F>> for LinearCombination<F> {
293    type Output = LinearCombination<F>;
294
295    fn sub(self, other: &'a LinearCombination<F>) -> LinearCombination<F> {
296        if other.0.is_empty() {
297            return self;
298        } else if self.0.is_empty() {
299            let mut other = other.clone();
300            other.negate_in_place();
301            return other;
302        }
303        op_impl(&self, other, |coeff| -coeff, |cur_coeff, other_coeff| cur_coeff - other_coeff)
304    }
305}
306
307impl<F: Field> Sub<LinearCombination<F>> for &LinearCombination<F> {
308    type Output = LinearCombination<F>;
309
310    fn sub(self, mut other: LinearCombination<F>) -> LinearCombination<F> {
311        if self.0.is_empty() {
312            other.negate_in_place();
313            return other;
314        } else if other.0.is_empty() {
315            return self.clone();
316        }
317
318        op_impl(self, &other, |coeff| -coeff, |cur_coeff, other_coeff| cur_coeff - other_coeff)
319    }
320}
321
322impl<F: Field> Sub<LinearCombination<F>> for LinearCombination<F> {
323    type Output = LinearCombination<F>;
324
325    fn sub(self, mut other: LinearCombination<F>) -> LinearCombination<F> {
326        if other.0.is_empty() {
327            return self;
328        } else if self.0.is_empty() {
329            other.negate_in_place();
330            return other;
331        }
332        op_impl(&self, &other, |coeff| -coeff, |cur_coeff, other_coeff| cur_coeff - other_coeff)
333    }
334}
335
336impl<F: Field> Add<(F, &LinearCombination<F>)> for &LinearCombination<F> {
337    type Output = LinearCombination<F>;
338
339    #[allow(clippy::suspicious_arithmetic_impl)]
340    fn add(self, (mul_coeff, other): (F, &LinearCombination<F>)) -> LinearCombination<F> {
341        if other.0.is_empty() {
342            return self.clone();
343        } else if self.0.is_empty() {
344            let mut other = other.clone();
345            other.mul_assign(mul_coeff);
346            return other;
347        }
348        op_impl(self, other, |coeff| mul_coeff * coeff, |cur_coeff, other_coeff| cur_coeff + (mul_coeff * other_coeff))
349    }
350}
351
352impl<'a, F: Field> Add<(F, &'a LinearCombination<F>)> for LinearCombination<F> {
353    type Output = LinearCombination<F>;
354
355    #[allow(clippy::suspicious_arithmetic_impl)]
356    fn add(self, (mul_coeff, other): (F, &'a LinearCombination<F>)) -> LinearCombination<F> {
357        if other.0.is_empty() {
358            return self;
359        } else if self.0.is_empty() {
360            let mut other = other.clone();
361            other.mul_assign(mul_coeff);
362            return other;
363        }
364        op_impl(&self, other, |coeff| mul_coeff * coeff, |cur_coeff, other_coeff| cur_coeff + (mul_coeff * other_coeff))
365    }
366}
367
368impl<F: Field> Add<(F, LinearCombination<F>)> for &LinearCombination<F> {
369    type Output = LinearCombination<F>;
370
371    #[allow(clippy::suspicious_arithmetic_impl)]
372    fn add(self, (mul_coeff, mut other): (F, LinearCombination<F>)) -> LinearCombination<F> {
373        if other.0.is_empty() {
374            return self.clone();
375        } else if self.0.is_empty() {
376            other.mul_assign(mul_coeff);
377            return other;
378        }
379        op_impl(self, &other, |coeff| mul_coeff * coeff, |cur_coeff, other_coeff| cur_coeff + (mul_coeff * other_coeff))
380    }
381}
382
383impl<F: Field> Add<(F, Self)> for LinearCombination<F> {
384    type Output = Self;
385
386    #[allow(clippy::suspicious_arithmetic_impl)]
387    fn add(self, (mul_coeff, other): (F, Self)) -> Self {
388        if other.0.is_empty() {
389            return self;
390        } else if self.0.is_empty() {
391            let mut other = other;
392            other.mul_assign(mul_coeff);
393            return other;
394        }
395        op_impl(
396            &self,
397            &other,
398            |coeff| mul_coeff * coeff,
399            |cur_coeff, other_coeff| cur_coeff + (mul_coeff * other_coeff),
400        )
401    }
402}
403
404impl<F: Field> Sub<(F, &LinearCombination<F>)> for &LinearCombination<F> {
405    type Output = LinearCombination<F>;
406
407    fn sub(self, (coeff, other): (F, &LinearCombination<F>)) -> LinearCombination<F> {
408        self + (-coeff, other)
409    }
410}
411
412impl<'a, F: Field> Sub<(F, &'a LinearCombination<F>)> for LinearCombination<F> {
413    type Output = LinearCombination<F>;
414
415    fn sub(self, (coeff, other): (F, &'a LinearCombination<F>)) -> LinearCombination<F> {
416        self + (-coeff, other)
417    }
418}
419
420impl<F: Field> Sub<(F, LinearCombination<F>)> for &LinearCombination<F> {
421    type Output = LinearCombination<F>;
422
423    fn sub(self, (coeff, other): (F, LinearCombination<F>)) -> LinearCombination<F> {
424        self + (-coeff, other)
425    }
426}
427
428impl<F: Field> Sub<(F, LinearCombination<F>)> for LinearCombination<F> {
429    type Output = LinearCombination<F>;
430
431    fn sub(self, (coeff, other): (F, LinearCombination<F>)) -> LinearCombination<F> {
432        self + (-coeff, other)
433    }
434}
435
#[cfg(test)]
mod tests {
    use crate::r1cs::Index;

    use super::*;
    use snarkvm_curves::bls12_377::Fr;

    #[test]
    fn linear_combination_append() {
        // Repeatedly adding terms for the same variable merges them into one entry.
        let mut combo = LinearCombination::<Fr>::zero();
        for i in 0..100u64 {
            combo += (i.into(), Variable::new_unchecked(Index::Public(0)));
        }
        assert_eq!(combo.0.len(), 1);
    }

    #[test]
    fn linear_combination_get_var_loc() {
        // Grow past 6 entries so both the linear-scan and the binary-search
        // branches of `get_var_loc` are exercised.
        let absent = Variable::new_unchecked(Index::Private(0));
        let mut combo = LinearCombination::<Fr>::zero();
        for i in 0..10 {
            combo += (1u64.into(), Variable::new_unchecked(Index::Public(i)));
            for j in 0..=i {
                let present = Variable::new_unchecked(Index::Public(j));
                assert!(combo.get_var_loc(&present).is_ok());
            }
            assert!(combo.get_var_loc(&absent).is_err());
        }
        assert_eq!(combo.0.len(), 10);
    }

    #[test]
    fn linear_combination_add_merges_shared_variables() {
        let v0 = Variable::new_unchecked(Index::Public(0));
        let v1 = Variable::new_unchecked(Index::Public(1));
        let v2 = Variable::new_unchecked(Index::Public(2));

        let a = LinearCombination::<Fr>::from(v0) + v1;
        let b = LinearCombination::<Fr>::from(v1) + v2;
        let sum = &a + &b;

        // v1 appears in both operands, so its coefficients are combined.
        assert_eq!(sum.0.len(), 3);
        let loc = sum.get_var_loc(&v1).expect("v1 appears in both operands");
        assert_eq!(sum.0[loc].1, Fr::from(2u64));
    }
}