linfa_linalg/
triangular.rs

1//! Traits for creating and manipulating triangular matrices
2
3use crate::{check_square, index::*, LinalgError, Result};
4
5use ndarray::{s, Array, ArrayBase, Data, DataMut, Ix2, NdFloat, SliceArg};
6use num_traits::Zero;
7
/// Selects which triangle of a matrix an operation refers to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UPLO {
    /// The upper triangle: entries on or above the main diagonal.
    Upper,
    /// The lower triangle: entries on or below the main diagonal.
    Lower,
}
14
15/// Transform square matrix into triangular matrix
16pub trait IntoTriangular {
17    /// Transform square matrix into a strict triangular matrix in place, zeroing out the other
18    /// elements.
19    fn triangular_inplace(&mut self, uplo: UPLO) -> Result<&mut Self>;
20
21    /// Transform square matrix into a strict triangular matrix, zeroing out the other elements.
22    fn into_triangular(self, uplo: UPLO) -> Result<Self>
23    where
24        Self: Sized;
25}
26
27impl<A, S> IntoTriangular for ArrayBase<S, Ix2>
28where
29    A: Zero,
30    S: DataMut<Elem = A>,
31{
32    fn into_triangular(mut self, uplo: UPLO) -> Result<Self> {
33        self.triangular_inplace(uplo)?;
34        Ok(self)
35    }
36
37    fn triangular_inplace(&mut self, uplo: UPLO) -> Result<&mut Self> {
38        let n = check_square(self)?;
39        if uplo == UPLO::Upper {
40            for i in 0..n {
41                for j in 0..i {
42                    unsafe { *self.atm((i, j)) = A::zero() };
43                }
44            }
45        } else {
46            for i in 0..n {
47                for j in i + 1..n {
48                    unsafe { *self.atm((i, j)) = A::zero() };
49                }
50            }
51        }
52        Ok(self)
53    }
54}
55
56/// Operations on triangular matrices
57pub trait Triangular {
58    /// Check if matrix is triangular
59    fn is_triangular(&self, uplo: UPLO) -> bool;
60}
61
62impl<A, S> Triangular for ArrayBase<S, Ix2>
63where
64    A: Zero,
65    S: Data<Elem = A>,
66{
67    fn is_triangular(&self, uplo: UPLO) -> bool {
68        if let Ok(n) = check_square(self) {
69            if uplo == UPLO::Upper {
70                for i in 0..n {
71                    for j in 0..i {
72                        if unsafe { !self.at((i, j)).is_zero() } {
73                            return false;
74                        }
75                    }
76                }
77            } else {
78                for i in 0..n {
79                    for j in i + 1..n {
80                        if unsafe { !self.at((i, j)).is_zero() } {
81                            return false;
82                        }
83                    }
84                }
85            }
86            true
87        } else {
88            false
89        }
90    }
91}
92
93#[inline]
94/// Common code for upper and lower triangular solvers
95fn solve_triangular_system_common<A: NdFloat, I: Iterator<Item = usize>, S: SliceArg<Ix2>>(
96    a: &ArrayBase<impl Data<Elem = A>, Ix2>,
97    b: &mut ArrayBase<impl DataMut<Elem = A>, Ix2>,
98    row_iter_fn: impl Fn(usize) -> I,
99    row_slice_fn: impl Fn(usize, usize) -> S,
100    diag_fn: impl Fn(usize) -> A,
101) -> Result<()> {
102    let rows = check_square(a)?;
103    if b.nrows() != rows {
104        return Err(LinalgError::WrongRows {
105            expected: rows,
106            actual: b.nrows(),
107        });
108    }
109    let cols = b.ncols();
110
111    // XXX Switching the col and row loops might lead to better cache locality for row-major
112    // layouts of b
113    for k in 0..cols {
114        for i in row_iter_fn(rows) {
115            let coeff;
116            unsafe {
117                let diag = diag_fn(i);
118                coeff = *b.at((i, k)) / diag;
119                *b.atm((i, k)) = coeff;
120            }
121
122            b.slice_mut(row_slice_fn(i, k))
123                .scaled_add(-coeff, &a.slice(row_slice_fn(i, i)));
124        }
125    }
126
127    Ok(())
128}
129
130/// Generalized implementation for both upper and lower triangular solvers.
131/// Index passed into `diag_fn` is guaranteed to be within the bounds of `MIN(a.nrows, a.ncols)`.
132/// Ensure that the return of `diag_fn` is non-zero, otherwise output will be wrong.
133pub(crate) fn solve_triangular_system<A: NdFloat>(
134    a: &ArrayBase<impl Data<Elem = A>, Ix2>,
135    b: &mut ArrayBase<impl DataMut<Elem = A>, Ix2>,
136    uplo: UPLO,
137    diag_fn: impl Fn(usize) -> A,
138) -> Result<()> {
139    if uplo == UPLO::Upper {
140        solve_triangular_system_common(a, b, |rows| (0..rows).rev(), |r, c| s![..r, c], diag_fn)
141    } else {
142        solve_triangular_system_common(a, b, |rows| (0..rows), |r, c| s![r + 1.., c], diag_fn)
143    }
144}
145
146/// Solves a triangular system
147pub trait SolveTriangularInplace<B> {
148    /// Solves `self * x = b` where `self` is a triangular matrix, modifying `b` into `x` in-place.
149    fn solve_triangular_inplace<'a>(&self, b: &'a mut B, uplo: UPLO) -> Result<&'a mut B>;
150
151    /// Solves `self * x = b` where `self` is a triangular matrix, consuming `b`.
152    fn solve_triangular_into(&self, mut b: B, uplo: UPLO) -> Result<B> {
153        self.solve_triangular_inplace(&mut b, uplo)?;
154        Ok(b)
155    }
156}
157
158impl<A: NdFloat, Si: Data<Elem = A>, So: DataMut<Elem = A>>
159    SolveTriangularInplace<ArrayBase<So, Ix2>> for ArrayBase<Si, Ix2>
160{
161    fn solve_triangular_inplace<'a>(
162        &self,
163        b: &'a mut ArrayBase<So, Ix2>,
164        uplo: UPLO,
165    ) -> Result<&'a mut ArrayBase<So, Ix2>> {
166        solve_triangular_system(self, b, uplo, |i| unsafe { *self.at((i, i)) })?;
167        Ok(b)
168    }
169}
170
171/// Solves a triangular system
172pub trait SolveTriangular<B> {
173    type Output;
174
175    /// Solves `self * x = b` where `self` is a triangular matrix.
176    fn solve_triangular(&self, b: &B, uplo: UPLO) -> Result<Self::Output>;
177}
178
179impl<A: NdFloat, Si: Data<Elem = A>, So: Data<Elem = A>> SolveTriangular<ArrayBase<So, Ix2>>
180    for ArrayBase<Si, Ix2>
181{
182    type Output = Array<A, Ix2>;
183
184    fn solve_triangular(&self, b: &ArrayBase<So, Ix2>, uplo: UPLO) -> Result<Self::Output> {
185        self.solve_triangular_into(b.to_owned(), uplo)
186    }
187}
188
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array2};

    use crate::LinalgError;

    use super::*;

    /// 0x0 and 1x1 matrices are triangular both ways and are left unchanged by the conversion.
    #[test]
    fn corner_cases() {
        let empty = Array2::<f64>::zeros((0, 0));
        assert!(empty.is_triangular(UPLO::Lower));
        assert!(empty.is_triangular(UPLO::Upper));
        let empty_lower = empty.clone().into_triangular(UPLO::Lower).unwrap();
        assert_eq!(empty_lower, empty);

        let single = array![[1]];
        assert!(single.is_triangular(UPLO::Lower));
        assert!(single.is_triangular(UPLO::Upper));
        let single_upper = single.clone().into_triangular(UPLO::Upper).unwrap();
        assert_eq!(single_upper, single);
        let single_lower = single.clone().into_triangular(UPLO::Lower).unwrap();
        assert_eq!(single_lower, single);
    }

    /// Non-square matrices are never triangular and cannot be converted.
    #[test]
    fn non_square() {
        // 2 rows x 3 columns
        let wide = array![[1, 2, 3], [3, 4, 5]];
        assert!(!wide.is_triangular(UPLO::Lower));
        assert!(!wide.is_triangular(UPLO::Upper));
        let wide_result = wide.into_triangular(UPLO::Lower);
        assert!(matches!(
            wide_result,
            Err(LinalgError::NotSquare { rows: 2, cols: 3 })
        ));

        // 3 rows x 2 columns
        let tall = array![[1, 2], [3, 5], [6, 8]];
        assert!(!tall.is_triangular(UPLO::Lower));
        assert!(!tall.is_triangular(UPLO::Upper));
        let tall_result = tall.into_triangular(UPLO::Upper);
        assert!(matches!(
            tall_result,
            Err(LinalgError::NotSquare { rows: 3, cols: 2 })
        ));
    }

    /// A dense square matrix is not triangular until converted; the diagonal survives.
    #[test]
    fn square() {
        let dense = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        assert!(!dense.is_triangular(UPLO::Lower));
        assert!(!dense.is_triangular(UPLO::Upper));

        let as_upper = dense.clone().into_triangular(UPLO::Upper).unwrap();
        assert_eq!(as_upper, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
        assert!(!as_upper.is_triangular(UPLO::Lower));
        assert!(as_upper.is_triangular(UPLO::Upper));

        let as_lower = dense.into_triangular(UPLO::Lower).unwrap();
        assert!(as_lower.is_triangular(UPLO::Lower));
        assert!(!as_lower.is_triangular(UPLO::Upper));
        assert_eq!(as_lower, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
    }

    /// Round trip: build `b = A * x` from a known `x`, then recover `x` with the solver.
    #[test]
    fn solve_triangular() {
        let lower = array![[1.0, 0.0], [3.0, 4.0]];
        assert!(lower.is_triangular(UPLO::Lower));
        let expected = array![[2.2, 3.1, 2.2], [1.0, 0.0, 5.7]];
        let rhs = lower.dot(&expected);
        let solved = lower.solve_triangular_into(rhs, UPLO::Lower).unwrap();
        assert_abs_diff_eq!(solved, expected, epsilon = 1e-7);

        let upper = array![[4.4, 2.1], [0.0, 4.3]];
        assert!(upper.is_triangular(UPLO::Upper));
        let rhs = upper.dot(&expected);
        let solved = upper.solve_triangular_into(rhs, UPLO::Upper).unwrap();
        assert_abs_diff_eq!(solved, expected, epsilon = 1e-7);
    }

    /// Degenerate systems: empty, 1x1, and a matrix with zeros on the diagonal.
    #[test]
    fn solve_corner_cases() {
        let empty = Array2::<f64>::zeros((0, 0));
        let empty_solution = empty.solve_triangular(&empty, UPLO::Upper).unwrap();
        assert_eq!(empty_solution.dim(), (0, 0));

        let one = Array2::<f64>::ones((1, 1));
        let one_solution = one.solve_triangular(&one, UPLO::Upper).unwrap();
        assert_abs_diff_eq!(one_solution, one);

        // Just make sure that zeroed diagonals won't crash
        let diag_zero = array![[0., 3.], [2., 0.]];
        let zeros = Array2::<f64>::zeros((2, 2));
        diag_zero.solve_triangular(&zeros, UPLO::Lower).unwrap();
    }

    /// Shape problems are reported through the matching `LinalgError` variant.
    #[test]
    fn solve_error() {
        let non_square = array![[1.2f64, 3.3]];
        let not_square_err = non_square
            .solve_triangular(&non_square, UPLO::Lower)
            .unwrap_err();
        assert!(matches!(not_square_err, LinalgError::NotSquare { .. }));

        let square = array![[1.1, 2.2], [3.3, 2.1]];
        let wrong_rows_err = square
            .solve_triangular(&array![[2.2, 3.3]], UPLO::Upper)
            .unwrap_err();
        assert!(matches!(
            wrong_rows_err,
            LinalgError::WrongRows {
                expected: 2,
                actual: 1
            }
        ));
    }
}