1use crate::{
7 check_square,
8 index::*,
9 triangular::{IntoTriangular, SolveTriangularInplace, UPLO},
10 LinalgError, Result,
11};
12
13use ndarray::{Array, Array2, ArrayBase, Data, DataMut, Ix2, NdFloat};
14
15pub trait CholeskyInplace {
17 fn cholesky_inplace_dirty(&mut self) -> Result<&mut Self>;
20
21 fn cholesky_into_dirty(mut self) -> Result<Self>
25 where
26 Self: Sized,
27 {
28 self.cholesky_inplace_dirty()?;
29 Ok(self)
30 }
31
32 fn cholesky_inplace(&mut self) -> Result<&mut Self>;
34
35 fn cholesky_into(mut self) -> Result<Self>
38 where
39 Self: Sized,
40 {
41 self.cholesky_inplace()?;
42 Ok(self)
43 }
44}
45
46impl<A, S> CholeskyInplace for ArrayBase<S, Ix2>
47where
48 A: NdFloat,
49 S: DataMut<Elem = A>,
50{
51 fn cholesky_inplace_dirty(&mut self) -> Result<&mut Self> {
52 let n = check_square(self)?;
53
54 for j in 0..n {
55 let mut d = A::zero();
56 unsafe {
57 for k in 0..j {
58 let mut s = A::zero();
59 for i in 0..k {
60 s += *self.at((k, i)) * *self.at((j, i));
61 }
62 s = (*self.at((j, k)) - s) / *self.at((k, k));
63 *self.atm((j, k)) = s;
64 d += s * s;
65 }
66 d = *self.at((j, j)) - d;
67 }
68
69 if d <= A::zero() {
70 return Err(LinalgError::NotPositiveDefinite);
71 }
72
73 unsafe { *self.atm((j, j)) = d.sqrt() };
74 }
75 Ok(self)
76 }
77
78 fn cholesky_inplace(&mut self) -> Result<&mut Self> {
79 self.cholesky_inplace_dirty()?;
80 self.triangular_inplace(UPLO::Lower)?;
81 Ok(self)
82 }
83}
84
85pub trait Cholesky {
87 type Output;
88
89 fn cholesky_dirty(&self) -> Result<Self::Output>;
93
94 fn cholesky(&self) -> Result<Self::Output>;
97}
98
99impl<A, S> Cholesky for ArrayBase<S, Ix2>
100where
101 A: NdFloat,
102 S: Data<Elem = A>,
103{
104 type Output = Array2<A>;
105
106 fn cholesky_dirty(&self) -> Result<Self::Output> {
107 let arr = self.to_owned();
108 arr.cholesky_into_dirty()
109 }
110
111 fn cholesky(&self) -> Result<Self::Output> {
112 let arr = self.to_owned();
113 arr.cholesky_into()
114 }
115}
116
117pub trait SolveCInplace<B> {
119 fn solvec_inplace<'a>(&mut self, b: &'a mut B) -> Result<&'a mut B>;
123
124 fn solvec_into(&mut self, mut b: B) -> Result<B> {
128 self.solvec_inplace(&mut b)?;
129 Ok(b)
130 }
131}
132
133impl<A: NdFloat, Si: DataMut<Elem = A>, So: DataMut<Elem = A>> SolveCInplace<ArrayBase<So, Ix2>>
134 for ArrayBase<Si, Ix2>
135{
136 fn solvec_inplace<'a>(
137 &mut self,
138 b: &'a mut ArrayBase<So, Ix2>,
139 ) -> Result<&'a mut ArrayBase<So, Ix2>> {
140 let chol = self.cholesky_inplace_dirty()?;
141 chol.solve_triangular_inplace(b, UPLO::Lower)?;
142 chol.t().solve_triangular_inplace(b, UPLO::Upper)?;
143 Ok(b)
144 }
145}
146
147pub trait SolveC<B> {
149 type Output;
150
151 fn solvec(&mut self, b: &B) -> Result<Self::Output>;
153}
154
155impl<A: NdFloat, Si: DataMut<Elem = A>, So: Data<Elem = A>> SolveC<ArrayBase<So, Ix2>>
156 for ArrayBase<Si, Ix2>
157{
158 type Output = Array<A, Ix2>;
159
160 fn solvec(&mut self, b: &ArrayBase<So, Ix2>) -> Result<Self::Output> {
161 self.solvec_into(b.to_owned())
162 }
163}
164
165pub trait InverseCInplace {
167 type Output;
168
169 fn invc_inplace(&mut self) -> Result<Self::Output>;
173}
174
175impl<A: NdFloat, S: DataMut<Elem = A>> InverseCInplace for ArrayBase<S, Ix2> {
176 type Output = Array2<A>;
177
178 fn invc_inplace(&mut self) -> Result<Self::Output> {
179 let eye = Array2::eye(self.nrows());
180 let res = self.solvec_into(eye)?;
181 Ok(res)
182 }
183}
184
185pub trait InverseC {
187 type Output;
188
189 fn invc(&self) -> Result<Self::Output>;
191}
192
193impl<A: NdFloat, S: Data<Elem = A>> InverseC for ArrayBase<S, Ix2> {
194 type Output = Array2<A>;
195
196 fn invc(&self) -> Result<Self::Output> {
197 self.to_owned().invc_inplace()
198 }
199}
200
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    use super::*;

    /// The factor of a known matrix matches the expected lower-triangular matrix, and
    /// multiplying the factor by its transpose reproduces the input.
    #[test]
    fn decompose() {
        let arr = array![[25., 15., -5.], [15., 18., 0.], [-5., 0., 11.]];
        let lower = array![[5.0, 0.0, 0.0], [3.0, 3.0, 0.0], [-1., 1., 3.]];

        let chol = arr.cholesky().unwrap();
        assert_abs_diff_eq!(chol, lower, epsilon = 1e-7);
        assert_abs_diff_eq!(chol.dot(&chol.t()), arr, epsilon = 1e-7);
    }

    /// Non-square, indefinite and all-zero inputs are rejected with the matching error.
    #[test]
    fn bad_matrix() {
        let mut row = array![[1., 2., 3.], [3., 4., 5.]];
        assert!(matches!(
            row.cholesky(),
            Err(LinalgError::NotSquare { rows: 2, cols: 3 })
        ));
        assert!(matches!(
            row.solvec(&Array2::zeros((2, 3))),
            Err(LinalgError::NotSquare { rows: 2, cols: 3 })
        ));

        let mut non_pd = array![[1., 2.], [2., 1.]];
        assert!(matches!(
            non_pd.cholesky(),
            Err(LinalgError::NotPositiveDefinite)
        ));
        assert!(matches!(
            non_pd.solvec(&Array2::zeros((2, 3))),
            Err(LinalgError::NotPositiveDefinite)
        ));

        let zeros = array![[0., 0.], [0., 0.]];
        assert!(matches!(
            zeros.cholesky(),
            Err(LinalgError::NotPositiveDefinite)
        ));
    }

    /// Solving against `arr * x` recovers `x`.
    #[test]
    fn solvec() {
        let mut arr = array![[25., 15., -5.], [15., 18., 0.], [-5., 0., 11.]];
        let x = array![
            [10., -3., 2.2, 4.],
            [0., 2.4, -0.9, 1.1],
            [5.5, 7.6, 8.1, 10.]
        ];
        let b = arr.dot(&x);

        let out = arr.solvec(&b).unwrap();
        assert_abs_diff_eq!(out, x, epsilon = 1e-7);
    }

    /// A matrix times its computed inverse is the identity, up to rounding.
    #[test]
    fn invc() {
        let arr = array![[25., 15., -5.], [15., 18., 0.], [-5., 0., 11.]];
        let inv = arr.invc().unwrap();
        // Explicit tolerance, matching the other tests: the default is machine epsilon, which
        // leaves no headroom for rounding in the matrix product.
        assert_abs_diff_eq!(arr.dot(&inv), Array2::eye(3), epsilon = 1e-7);
    }

    /// 0x0 and 1x1 inputs are handled.
    #[test]
    fn corner_cases() {
        // `invc` borrows its receiver, so no clone is needed before calling it.
        let empty = Array2::<f64>::zeros((0, 0));
        assert_eq!(empty.cholesky().unwrap(), empty);
        assert_eq!(empty.invc().unwrap(), empty);

        let one = array![[1.]];
        assert_eq!(one.cholesky().unwrap(), one);
        assert_eq!(one.invc().unwrap(), one);
    }
}