1use crate::{
10 check_square, householder,
11 index::UncheckedIndex,
12 reflection::Reflection,
13 triangular::{self, IntoTriangular, UPLO},
14 LinalgError, Result,
15};
16
17use ndarray::{prelude::*, Data, DataMut, OwnedRepr, RawDataClone};
18
19pub trait QRInto {
21 type Decomp;
22
23 fn qr_into(self) -> Result<Self::Decomp>;
27}
28
29impl<A: NdFloat, S: DataMut<Elem = A>> QRInto for ArrayBase<S, Ix2> {
30 type Decomp = QRDecomp<A, S>;
31
32 fn qr_into(mut self) -> Result<Self::Decomp> {
33 let (rows, cols) = self.dim();
34 if self.nrows() < self.ncols() {
35 return Err(LinalgError::NotThin { rows, cols });
36 }
37
38 let mut diag = Array::zeros(cols);
39 for i in 0..cols {
40 diag[i] = householder::clear_column(&mut self, i, 0);
41 }
42
43 Ok(QRDecomp { qr: self, diag })
44 }
45}
46
47pub trait QR {
49 type Decomp;
50
51 fn qr(&self) -> Result<Self::Decomp>;
55}
56
57impl<A: NdFloat, S: Data<Elem = A>> QR for ArrayBase<S, Ix2> {
58 type Decomp = QRDecomp<A, OwnedRepr<A>>;
59
60 fn qr(&self) -> Result<Self::Decomp> {
61 self.to_owned().qr_into()
62 }
63}
64
65#[derive(Debug)]
66pub struct QRDecomp<A, S: DataMut<Elem = A>> {
69 qr: ArrayBase<S, Ix2>,
71 diag: Array1<A>,
73}
74
75impl<A: Clone, S: DataMut<Elem = A> + RawDataClone> Clone for QRDecomp<A, S> {
76 fn clone(&self) -> Self {
77 Self {
78 qr: self.qr.clone(),
79 diag: self.diag.clone(),
80 }
81 }
82}
83
84impl<A: NdFloat, S: DataMut<Elem = A>> QRDecomp<A, S> {
85 pub fn generate_q(&self) -> Array2<A> {
87 householder::assemble_q(&self.qr, 0, |i| self.diag[i])
88 }
89
90 pub fn into_r(self) -> ArrayBase<S, Ix2> {
92 let ncols = self.qr.ncols();
93 let mut r = self.qr.slice_move(s![..ncols, ..ncols]);
94 r.triangular_inplace(UPLO::Upper).unwrap();
96 r.diag_mut().assign(&self.diag.mapv_into(A::abs));
97 r
98 }
99
100 pub fn into_decomp(self) -> (Array2<A>, ArrayBase<S, Ix2>) {
102 let q = self.generate_q();
103 (q, self.into_r())
104 }
105
106 fn qt_mul<Si: DataMut<Elem = A>>(&self, b: &mut ArrayBase<Si, Ix2>) {
111 let cols = self.qr.ncols();
112 for i in 0..cols {
113 let axis = self.qr.slice(s![i.., i]);
114 let refl = Reflection::new(axis, A::zero());
115
116 let mut rows = b.slice_mut(s![i.., ..]);
117 refl.reflect_cols(&mut rows);
118 rows *= self.diag[i].signum();
119 }
120 }
121
122 pub fn solve_into<Si: DataMut<Elem = A>>(
125 &self,
126 mut b: ArrayBase<Si, Ix2>,
127 ) -> Result<ArrayBase<Si, Ix2>> {
128 if self.qr.nrows() != b.nrows() {
129 return Err(LinalgError::WrongRows {
130 expected: self.qr.nrows(),
131 actual: b.nrows(),
132 });
133 }
134 if !self.is_invertible() {
135 return Err(LinalgError::NonInvertible);
136 }
137
138 self.qt_mul(&mut b);
140 let ncols = self.qr.ncols();
141 let mut b = b.slice_move(s![..ncols, ..]);
142
143 triangular::solve_triangular_system(
146 &self.qr.slice(s![..ncols, ..ncols]),
147 &mut b,
148 UPLO::Upper,
149 |i| unsafe { self.diag.at(i).abs() },
150 )?;
151 Ok(b)
152 }
153
154 pub fn solve_tr_into<Si: DataMut<Elem = A>>(
157 &self,
158 mut b: ArrayBase<Si, Ix2>,
159 ) -> Result<Array2<A>> {
160 if self.qr.ncols() != b.nrows() {
161 return Err(LinalgError::WrongRows {
162 expected: self.qr.ncols(),
163 actual: b.nrows(),
164 });
165 }
166 if !self.is_invertible() {
167 return Err(LinalgError::NonInvertible);
168 }
169
170 let ncols = self.qr.ncols();
171 triangular::solve_triangular_system(
173 &self.qr.slice(s![..ncols, ..ncols]).t(),
174 &mut b,
175 UPLO::Lower,
176 |i| unsafe { self.diag.at(i).abs() },
177 )?;
178
179 Ok(self.generate_q().dot(&b))
181 }
182
183 pub fn solve<Si: Data<Elem = A>>(&self, b: &ArrayBase<Si, Ix2>) -> Result<Array2<A>> {
185 self.solve_into(b.to_owned())
186 }
187
188 pub fn solve_tr<Si: Data<Elem = A>>(&self, b: &ArrayBase<Si, Ix2>) -> Result<Array2<A>> {
190 self.solve_tr_into(b.to_owned())
191 }
192
193 pub fn is_invertible(&self) -> bool {
195 self.diag.iter().all(|f| !f.is_zero())
197 }
198
199 pub fn inverse(&self) -> Result<Array2<A>> {
201 check_square(&self.qr)?;
202 self.solve_into(Array2::eye(self.diag.len()))
203 }
204}
205
206pub trait LeastSquaresQrInto<B> {
208 type Output;
209
210 fn least_squares_into(self, b: B) -> Result<Self::Output>;
212}
213
214impl<A: NdFloat, Si: DataMut<Elem = A>, So: DataMut<Elem = A>>
215 LeastSquaresQrInto<ArrayBase<So, Ix2>> for ArrayBase<Si, Ix2>
216{
217 type Output = Array2<A>;
218
219 fn least_squares_into(self, b: ArrayBase<So, Ix2>) -> Result<Self::Output> {
220 let out = if self.nrows() >= self.ncols() {
221 self.qr_into()?.solve_into(b)?.into_owned()
222 } else {
223 self.reversed_axes().qr_into()?.solve_tr_into(b)?
226 };
227 Ok(out)
228 }
229}
230
231pub trait LeastSquaresQr<B> {
234 type Output;
235
236 fn least_squares(self, b: &B) -> Result<Self::Output>;
238}
239
240impl<A: NdFloat, Si: DataMut<Elem = A>, So: Data<Elem = A>> LeastSquaresQr<ArrayBase<So, Ix2>>
241 for ArrayBase<Si, Ix2>
242{
243 type Output = Array2<A>;
244
245 fn least_squares(self, b: &ArrayBase<So, Ix2>) -> Result<Self::Output> {
246 self.least_squares_into(b.to_owned())
247 }
248}
249
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Well-conditioned 2x2 fixture shared by several tests.
    fn square_matrix() -> Array2<f64> {
        array![[1., 9.80], [-7., 3.3]]
    }

    /// 3x2 (thin) fixture.
    fn tall_matrix() -> Array2<f64> {
        array![[3.2, 1.3], [4.4, 5.2], [1.3, 6.7]]
    }

    /// 2x3 fixture used as a known solution / right-hand side.
    fn rhs_source() -> Array2<f64> {
        array![[3.2, 1.3, 4.4], [5.2, 1.3, 6.7]]
    }

    #[test]
    fn qr() {
        let (q, r) = tall_matrix().qr().unwrap().into_decomp();
        let expected_q = array![
            [0.5720674, -0.4115578],
            [0.7865927, 0.0301901],
            [0.2324024, 0.9108835]
        ];
        assert_abs_diff_eq!(q, expected_q, epsilon = 1e-5);
        let expected_r = array![[5.594, 6.391], [0., 5.725]];
        assert_abs_diff_eq!(r, expected_r, epsilon = 1e-3);

        let zero = Array2::<f64>::zeros((2, 2));
        let (q, r) = zero.qr().unwrap().into_decomp();
        assert_abs_diff_eq!(q, Array2::eye(2));
        assert_abs_diff_eq!(r, zero);
    }

    #[test]
    fn solve() {
        let x = rhs_source();

        let square = square_matrix();
        let rhs = square.dot(&x);
        let sol = square.qr_into().unwrap().solve(&rhs).unwrap();
        assert_abs_diff_eq!(sol, x, epsilon = 1e-5);

        let zero_rhs = Array2::<f64>::zeros((2, 3));
        let sol = Array2::<f64>::eye(2)
            .qr_into()
            .unwrap()
            .solve(&zero_rhs)
            .unwrap();
        assert_abs_diff_eq!(sol, zero_rhs);

        // Overdetermined but consistent: the exact solution is recovered.
        let tall = tall_matrix();
        let rhs = tall.dot(&x);
        let sol = tall.qr_into().unwrap().solve(&rhs).unwrap();
        assert_abs_diff_eq!(sol, x, epsilon = 1e-5);
    }

    #[test]
    fn solve_tr() {
        let x = rhs_source();
        let rhs = square_matrix().dot(&x);
        let sol = square_matrix()
            .reversed_axes()
            .qr_into()
            .unwrap()
            .solve_tr(&rhs)
            .unwrap();
        assert_abs_diff_eq!(sol, x, epsilon = 1e-5);

        let zero_rhs = Array2::<f64>::zeros((2, 3));
        let sol = Array2::<f64>::eye(2)
            .qr_into()
            .unwrap()
            .solve_tr(&zero_rhs)
            .unwrap();
        assert_abs_diff_eq!(sol, zero_rhs);

        // Underdetermined system: the solution is not unique, so only the
        // residual is checked.
        let wide = tall_matrix().reversed_axes();
        let rhs = wide.dot(&rhs_source().reversed_axes());
        let sol = wide.t().to_owned().qr_into().unwrap().solve_tr(&rhs).unwrap();
        assert_abs_diff_eq!(rhs, wide.dot(&sol), epsilon = 1e-7);
    }

    #[test]
    fn inverse() {
        let inv = square_matrix().qr_into().unwrap().inverse().unwrap();
        let expected = array![[0.04589, -0.1363], [0.09735, 0.0139]];
        assert_abs_diff_eq!(inv, expected, epsilon = 1e-4);

        let identity = Array2::<f64>::eye(2);
        let inv = identity.clone().qr_into().unwrap().inverse().unwrap();
        assert_abs_diff_eq!(inv, identity);
    }

    #[test]
    fn non_invertible() {
        let singular = Array2::<f64>::zeros((2, 2));
        let err = singular.qr().unwrap().inverse().unwrap_err();
        assert!(matches!(err, LinalgError::NonInvertible));

        let err = singular
            .least_squares_into(Array2::zeros((2, 2)))
            .unwrap_err();
        assert!(matches!(err, LinalgError::NonInvertible));

        let wide = Array2::<f64>::zeros((2, 3));
        let err = wide.least_squares_into(Array2::zeros((2, 2))).unwrap_err();
        assert!(matches!(err, LinalgError::NonInvertible));
    }

    #[test]
    fn qt_mul() {
        let mut rhs = rhs_source();
        let decomp = square_matrix().qr_into().unwrap();
        let expected = decomp.generate_q().t().dot(&rhs);
        decomp.qt_mul(&mut rhs);
        assert_abs_diff_eq!(rhs, expected, epsilon = 1e-7);

        // Thin case: only the leading `ncols` rows correspond to Qᵀ b.
        let decomp = tall_matrix().qr_into().unwrap();
        let mut rhs = rhs_source().reversed_axes();
        let expected = decomp.generate_q().t().dot(&rhs);
        decomp.qt_mul(&mut rhs);
        assert_abs_diff_eq!(rhs.slice(s![..2, ..2]), expected, epsilon = 1e-7);
    }

    #[test]
    fn corner() {
        let (q, r) = Array2::<f64>::zeros((0, 0))
            .qr_into()
            .unwrap()
            .into_decomp();
        assert!(q.is_empty());
        assert!(r.is_empty());

        let thin_err = Array2::<f64>::zeros((2, 3)).qr_into().unwrap_err();
        assert!(matches!(thin_err, LinalgError::NotThin { rows: 2, cols: 3 }));

        let square_err = Array2::<f64>::zeros((3, 2))
            .qr_into()
            .unwrap()
            .inverse()
            .unwrap_err();
        assert!(matches!(
            square_err,
            LinalgError::NotSquare { rows: 3, cols: 2 }
        ));
    }
}