linfa_linalg/
cholesky.rs

1//! Cholesky decomposition on symmetric positive definite matrices.
2//!
3//! This module also exports related functionality on symmetric positive definite matrices, such as
4//! solving systems and inversion.
5
6use crate::{
7    check_square,
8    index::*,
9    triangular::{IntoTriangular, SolveTriangularInplace, UPLO},
10    LinalgError, Result,
11};
12
13use ndarray::{Array, Array2, ArrayBase, Data, DataMut, Ix2, NdFloat};
14
/// Cholesky decomposition of a symmetric positive definite matrix
///
/// Every method of this trait overwrites the matrix it is called on with the lower-triangular
/// factor `L`. The `_dirty` variants skip zeroing the upper triangle; the others produce a clean
/// lower-triangular result.
pub trait CholeskyInplace {
    /// Computes decomposition `A = L * L.t` where L is a lower-triangular matrix in place.
    /// The upper triangle portion is not zeroed out.
    ///
    /// On success a mutable reference to `self` (now holding `L` in its lower triangle) is
    /// returned.
    ///
    /// # Errors
    ///
    /// The `ArrayBase` implementation in this module fails if the matrix is not square or is
    /// not positive definite. In the latter case part of `self` may already have been
    /// overwritten when the error is returned.
    fn cholesky_inplace_dirty(&mut self) -> Result<&mut Self>;

    /// Computes decomposition `A = L * L.t` where L is a lower-triangular matrix, passing by
    /// value.
    /// The upper triangle portion is not zeroed out.
    ///
    /// This is [`cholesky_inplace_dirty`](Self::cholesky_inplace_dirty) applied to an owned
    /// value, which is handed back on success and dropped on error.
    fn cholesky_into_dirty(mut self) -> Result<Self>
    where
        Self: Sized,
    {
        self.cholesky_inplace_dirty()?;
        Ok(self)
    }

    /// Computes decomposition `A = L * L.t` where L is a lower-triangular matrix in place.
    ///
    /// On success a mutable reference to `self` (now holding `L`) is returned.
    fn cholesky_inplace(&mut self) -> Result<&mut Self>;

    /// Computes decomposition `A = L * L.t` where L is a lower-triangular matrix, passing by
    /// value.
    ///
    /// This is [`cholesky_inplace`](Self::cholesky_inplace) applied to an owned value, which is
    /// handed back on success and dropped on error.
    fn cholesky_into(mut self) -> Result<Self>
    where
        Self: Sized,
    {
        self.cholesky_inplace()?;
        Ok(self)
    }
}
45
46impl<A, S> CholeskyInplace for ArrayBase<S, Ix2>
47where
48    A: NdFloat,
49    S: DataMut<Elem = A>,
50{
51    fn cholesky_inplace_dirty(&mut self) -> Result<&mut Self> {
52        let n = check_square(self)?;
53
54        for j in 0..n {
55            let mut d = A::zero();
56            unsafe {
57                for k in 0..j {
58                    let mut s = A::zero();
59                    for i in 0..k {
60                        s += *self.at((k, i)) * *self.at((j, i));
61                    }
62                    s = (*self.at((j, k)) - s) / *self.at((k, k));
63                    *self.atm((j, k)) = s;
64                    d += s * s;
65                }
66                d = *self.at((j, j)) - d;
67            }
68
69            if d <= A::zero() {
70                return Err(LinalgError::NotPositiveDefinite);
71            }
72
73            unsafe { *self.atm((j, j)) = d.sqrt() };
74        }
75        Ok(self)
76    }
77
78    fn cholesky_inplace(&mut self) -> Result<&mut Self> {
79        self.cholesky_inplace_dirty()?;
80        self.triangular_inplace(UPLO::Lower)?;
81        Ok(self)
82    }
83}
84
/// Cholesky decomposition of a symmetric positive definite matrix, without modifying the original
///
/// The `ArrayBase` implementation in this module copies the matrix with `to_owned` and
/// factorises the copy.
pub trait Cholesky {
    /// Type of the returned factor.
    type Output;

    /// Computes decomposition `A = L * L.t` where L is a lower-triangular matrix without modifying
    /// or consuming the original.
    /// The upper triangle portion is not zeroed out.
    ///
    /// Returns the factor as a new value, or an error if the decomposition fails.
    fn cholesky_dirty(&self) -> Result<Self::Output>;

    /// Computes decomposition `A = L * L.t` where L is a lower-triangular matrix without modifying
    /// or consuming the original.
    ///
    /// Returns the factor as a new value, or an error if the decomposition fails.
    fn cholesky(&self) -> Result<Self::Output>;
}
98
99impl<A, S> Cholesky for ArrayBase<S, Ix2>
100where
101    A: NdFloat,
102    S: Data<Elem = A>,
103{
104    type Output = Array2<A>;
105
106    fn cholesky_dirty(&self) -> Result<Self::Output> {
107        let arr = self.to_owned();
108        arr.cholesky_into_dirty()
109    }
110
111    fn cholesky(&self) -> Result<Self::Output> {
112        let arr = self.to_owned();
113        arr.cholesky_into()
114    }
115}
116
/// Solves a symmetric positive definite system
///
/// Both methods take `self` mutably because the coefficient matrix is decomposed in place.
pub trait SolveCInplace<B> {
    /// Solves `self * x = b`, where `self` is symmetric positive definite, modifying `b` inplace.
    ///
    /// As a side effect, `self` is used to calculate an in-place Cholesky decomposition.
    ///
    /// On success the returned reference is `b`, which then holds the solution `x`. If an
    /// error is returned, `self` may nevertheless have been overwritten already.
    fn solvec_inplace<'a>(&mut self, b: &'a mut B) -> Result<&'a mut B>;

    /// Solves `self * x = b`, where `self` is symmetric positive definite, consuming `b`.
    ///
    /// As a side effect, `self` is used to calculate an in-place Cholesky decomposition.
    ///
    /// `b` is reused as storage for the solution and handed back on success.
    fn solvec_into(&mut self, mut b: B) -> Result<B> {
        self.solvec_inplace(&mut b)?;
        Ok(b)
    }
}
132
133impl<A: NdFloat, Si: DataMut<Elem = A>, So: DataMut<Elem = A>> SolveCInplace<ArrayBase<So, Ix2>>
134    for ArrayBase<Si, Ix2>
135{
136    fn solvec_inplace<'a>(
137        &mut self,
138        b: &'a mut ArrayBase<So, Ix2>,
139    ) -> Result<&'a mut ArrayBase<So, Ix2>> {
140        let chol = self.cholesky_inplace_dirty()?;
141        chol.solve_triangular_inplace(b, UPLO::Lower)?;
142        chol.t().solve_triangular_inplace(b, UPLO::Upper)?;
143        Ok(b)
144    }
145}
146
/// Solves a symmetric positive definite system
pub trait SolveC<B> {
    /// Type of the returned solution.
    type Output;

    /// Solves `self * x = b`, where `self` is symmetric positive definite.
    ///
    /// `b` is only read. Note that `self` is taken mutably: the `ArrayBase` implementation in
    /// this module computes an in-place Cholesky decomposition of `self`, so `self` no longer
    /// holds the original matrix afterwards.
    fn solvec(&mut self, b: &B) -> Result<Self::Output>;
}
154
155impl<A: NdFloat, Si: DataMut<Elem = A>, So: Data<Elem = A>> SolveC<ArrayBase<So, Ix2>>
156    for ArrayBase<Si, Ix2>
157{
158    type Output = Array<A, Ix2>;
159
160    fn solvec(&mut self, b: &ArrayBase<So, Ix2>) -> Result<Self::Output> {
161        self.solvec_into(b.to_owned())
162    }
163}
164
/// Inverse of a symmetric positive definite matrix
pub trait InverseCInplace {
    /// Type of the returned inverse.
    type Output;

    /// Computes inverse of symmetric positive definite matrix.
    ///
    /// As a side effect, `self` is used to calculate an in-place Cholesky decomposition.
    ///
    /// The inverse is returned as a new value; `self` does not hold the original matrix
    /// afterwards.
    fn invc_inplace(&mut self) -> Result<Self::Output>;
}
174
175impl<A: NdFloat, S: DataMut<Elem = A>> InverseCInplace for ArrayBase<S, Ix2> {
176    type Output = Array2<A>;
177
178    fn invc_inplace(&mut self) -> Result<Self::Output> {
179        let eye = Array2::eye(self.nrows());
180        let res = self.solvec_into(eye)?;
181        Ok(res)
182    }
183}
184
/// Inverse of a symmetric positive definite matrix
///
/// Unlike [`InverseCInplace`], this trait only borrows the matrix immutably.
pub trait InverseC {
    /// Type of the returned inverse.
    type Output;

    /// Computes inverse of symmetric positive definite matrix.
    ///
    /// The inverse is returned as a new value, or an error if the computation fails.
    fn invc(&self) -> Result<Self::Output>;
}
192
193impl<A: NdFloat, S: Data<Elem = A>> InverseC for ArrayBase<S, Ix2> {
194    type Output = Array2<A>;
195
196    fn invc(&self) -> Result<Self::Output> {
197        self.to_owned().invc_inplace()
198    }
199}
200
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    use super::*;

    /// The factor of a known SPD matrix matches the expected `L` and reproduces the input.
    #[test]
    fn decompose() {
        let arr = array![[25., 15., -5.], [15., 18., 0.], [-5., 0., 11.]];
        let lower = array![[5.0, 0.0, 0.0], [3.0, 3.0, 0.0], [-1., 1., 3.]];

        let chol = arr.cholesky().unwrap();
        assert_abs_diff_eq!(chol, lower, epsilon = 1e-7);
        assert_abs_diff_eq!(chol.dot(&chol.t()), arr, epsilon = 1e-7);
    }

    /// Non-square and non-positive-definite inputs are reported as errors.
    #[test]
    fn bad_matrix() {
        let mut row = array![[1., 2., 3.], [3., 4., 5.]];
        assert!(matches!(
            row.cholesky(),
            Err(LinalgError::NotSquare { rows: 2, cols: 3 })
        ));
        assert!(matches!(
            row.solvec(&Array2::zeros((2, 3))),
            Err(LinalgError::NotSquare { rows: 2, cols: 3 })
        ));

        let mut non_pd = array![[1., 2.], [2., 1.]];
        assert!(matches!(
            non_pd.cholesky(),
            Err(LinalgError::NotPositiveDefinite)
        ));
        assert!(matches!(
            non_pd.solvec(&Array2::zeros((2, 3))),
            Err(LinalgError::NotPositiveDefinite)
        ));

        let zeros = array![[0., 0.], [0., 0.]];
        assert!(matches!(
            zeros.cholesky(),
            Err(LinalgError::NotPositiveDefinite)
        ));
    }

    /// Solving `A * x = b` recovers the `x` that `b` was built from.
    #[test]
    fn solvec() {
        let mut arr = array![[25., 15., -5.], [15., 18., 0.], [-5., 0., 11.]];
        let x = array![
            [10., -3., 2.2, 4.],
            [0., 2.4, -0.9, 1.1],
            [5.5, 7.6, 8.1, 10.]
        ];
        let b = arr.dot(&x);

        let out = arr.solvec(&b).unwrap();
        assert_abs_diff_eq!(out, x, epsilon = 1e-7);
    }

    /// `A * inv(A)` is the identity up to rounding.
    #[test]
    fn invc() {
        let arr = array![[25., 15., -5.], [15., 18., 0.], [-5., 0., 11.]];
        let inv = arr.invc().unwrap();
        // Use the same tolerance as the other tests: the default tolerance is machine epsilon,
        // which is too tight for a product involving a computed inverse.
        assert_abs_diff_eq!(arr.dot(&inv), Array2::eye(3), epsilon = 1e-7);
    }

    /// Empty and 1x1 matrices are handled.
    #[test]
    fn corner_cases() {
        let empty = Array2::<f64>::zeros((0, 0));
        assert_eq!(empty.cholesky().unwrap(), empty);
        // `invc` borrows immutably, so no clone is needed to keep `empty` usable.
        assert_eq!(empty.invc().unwrap(), empty);

        let one = array![[1.]];
        assert_eq!(one.cholesky().unwrap(), one);
        assert_eq!(one.invc().unwrap(), one);
    }
}