linfa_linalg/
qr.rs

1//! QR decomposition of rectangular matrices.
2//!
3//! Note that the QR decomposition implemented here is "thin", so `Q` has dimensions `(r, c)` and
4//! `R` is `(c, c)`, where `r` and `c` are the dimensions of the original matrix.
5//!
//! This module also provides routines, built on the decomposition, for computing matrix inverses
//! and solving least-squares problems.
8
9use crate::{
10    check_square, householder,
11    index::UncheckedIndex,
12    reflection::Reflection,
13    triangular::{self, IntoTriangular, UPLO},
14    LinalgError, Result,
15};
16
17use ndarray::{prelude::*, Data, DataMut, OwnedRepr, RawDataClone};
18
19/// QR decomposition for matrix by value
20pub trait QRInto {
21    type Decomp;
22
23    /// Decomposes the matrix into semi-orthogonal matrix `Q` and upper-triangular matrix `R`, such
24    /// that `Q * R` yields the original matrix. Matrix rows must be equal or greater than number
25    /// of columns.
26    fn qr_into(self) -> Result<Self::Decomp>;
27}
28
29impl<A: NdFloat, S: DataMut<Elem = A>> QRInto for ArrayBase<S, Ix2> {
30    type Decomp = QRDecomp<A, S>;
31
32    fn qr_into(mut self) -> Result<Self::Decomp> {
33        let (rows, cols) = self.dim();
34        if self.nrows() < self.ncols() {
35            return Err(LinalgError::NotThin { rows, cols });
36        }
37
38        let mut diag = Array::zeros(cols);
39        for i in 0..cols {
40            diag[i] = householder::clear_column(&mut self, i, 0);
41        }
42
43        Ok(QRDecomp { qr: self, diag })
44    }
45}
46
47/// QR decomposition for matrix by reference
48pub trait QR {
49    type Decomp;
50
51    /// Decomposes the matrix into semi-orthogonal matrix `Q` and upper-triangular matrix `R`, such
52    /// that `Q * R` yields the original matrix. Matrix rows must be equal or greater than number
53    /// of columns.
54    fn qr(&self) -> Result<Self::Decomp>;
55}
56
57impl<A: NdFloat, S: Data<Elem = A>> QR for ArrayBase<S, Ix2> {
58    type Decomp = QRDecomp<A, OwnedRepr<A>>;
59
60    fn qr(&self) -> Result<Self::Decomp> {
61        self.to_owned().qr_into()
62    }
63}
64
65#[derive(Debug)]
66/// Compact representation of a QR decomposition. Can be used to yield the `Q` and `R` matrices or
67/// to calculate the inverse or solve a system.
68pub struct QRDecomp<A, S: DataMut<Elem = A>> {
69    // qr must be a "tall" matrix (rows >= cols)
70    qr: ArrayBase<S, Ix2>,
71    // diag length must be equal to qr.ncols
72    diag: Array1<A>,
73}
74
75impl<A: Clone, S: DataMut<Elem = A> + RawDataClone> Clone for QRDecomp<A, S> {
76    fn clone(&self) -> Self {
77        Self {
78            qr: self.qr.clone(),
79            diag: self.diag.clone(),
80        }
81    }
82}
83
84impl<A: NdFloat, S: DataMut<Elem = A>> QRDecomp<A, S> {
85    /// Generate semi-orthogonal `Q` matrix
86    pub fn generate_q(&self) -> Array2<A> {
87        householder::assemble_q(&self.qr, 0, |i| self.diag[i])
88    }
89
90    /// Consumes `self` to generate the upper-triangular `R` matrix
91    pub fn into_r(self) -> ArrayBase<S, Ix2> {
92        let ncols = self.qr.ncols();
93        let mut r = self.qr.slice_move(s![..ncols, ..ncols]);
94        // Should zero out the lower-triangular portion (not the diagonal)
95        r.triangular_inplace(UPLO::Upper).unwrap();
96        r.diag_mut().assign(&self.diag.mapv_into(A::abs));
97        r
98    }
99
100    /// Generate both `Q` and `R`
101    pub fn into_decomp(self) -> (Array2<A>, ArrayBase<S, Ix2>) {
102        let q = self.generate_q();
103        (q, self.into_r())
104    }
105
106    /// Performs `Q.t * b` in place, without actually producing `Q`.
107    ///
108    /// `b` must have at least R rows, although the output will only reside in the first C rows of
109    /// `b` (R and C are the dimensions of the decomposed matrix).
110    fn qt_mul<Si: DataMut<Elem = A>>(&self, b: &mut ArrayBase<Si, Ix2>) {
111        let cols = self.qr.ncols();
112        for i in 0..cols {
113            let axis = self.qr.slice(s![i.., i]);
114            let refl = Reflection::new(axis, A::zero());
115
116            let mut rows = b.slice_mut(s![i.., ..]);
117            refl.reflect_cols(&mut rows);
118            rows *= self.diag[i].signum();
119        }
120    }
121
122    /// Solves `A * x = b`, where `A` is the original matrix. Used to calculate least squares for
123    /// "thin" matrices (rows >= cols).
124    pub fn solve_into<Si: DataMut<Elem = A>>(
125        &self,
126        mut b: ArrayBase<Si, Ix2>,
127    ) -> Result<ArrayBase<Si, Ix2>> {
128        if self.qr.nrows() != b.nrows() {
129            return Err(LinalgError::WrongRows {
130                expected: self.qr.nrows(),
131                actual: b.nrows(),
132            });
133        }
134        if !self.is_invertible() {
135            return Err(LinalgError::NonInvertible);
136        }
137
138        // Calculate Q.t * b and extract the result
139        self.qt_mul(&mut b);
140        let ncols = self.qr.ncols();
141        let mut b = b.slice_move(s![..ncols, ..]);
142
143        // Equivalent to solving R * x = Q.t * b
144        // This gives the solution to the linear problem
145        triangular::solve_triangular_system(
146            &self.qr.slice(s![..ncols, ..ncols]),
147            &mut b,
148            UPLO::Upper,
149            |i| unsafe { self.diag.at(i).abs() },
150        )?;
151        Ok(b)
152    }
153
154    /// Solves `A.t * x = b`, where `A` is the original matrix. Used to calculate least squares for
155    /// "wide" matrices (rows < cols).
156    pub fn solve_tr_into<Si: DataMut<Elem = A>>(
157        &self,
158        mut b: ArrayBase<Si, Ix2>,
159    ) -> Result<Array2<A>> {
160        if self.qr.ncols() != b.nrows() {
161            return Err(LinalgError::WrongRows {
162                expected: self.qr.ncols(),
163                actual: b.nrows(),
164            });
165        }
166        if !self.is_invertible() {
167            return Err(LinalgError::NonInvertible);
168        }
169
170        let ncols = self.qr.ncols();
171        // Equivalent to solving R.t * m = b, where m is upper portion of x
172        triangular::solve_triangular_system(
173            &self.qr.slice(s![..ncols, ..ncols]).t(),
174            &mut b,
175            UPLO::Lower,
176            |i| unsafe { self.diag.at(i).abs() },
177        )?;
178
179        // XXX Could implement a non-transpose version of qt_mul to reduce allocations
180        Ok(self.generate_q().dot(&b))
181    }
182
183    /// Solves `A * x = b`, where `A` is the original matrix.
184    pub fn solve<Si: Data<Elem = A>>(&self, b: &ArrayBase<Si, Ix2>) -> Result<Array2<A>> {
185        self.solve_into(b.to_owned())
186    }
187
188    /// Solves `A.t * x = b`, where `A` is the original matrix.
189    pub fn solve_tr<Si: Data<Elem = A>>(&self, b: &ArrayBase<Si, Ix2>) -> Result<Array2<A>> {
190        self.solve_tr_into(b.to_owned())
191    }
192
193    /// Checks if original matrix is invertible.
194    pub fn is_invertible(&self) -> bool {
195        // No zeros in the diagonal
196        self.diag.iter().all(|f| !f.is_zero())
197    }
198
199    /// Produce the inverse of the original matrix, if it's invertible.
200    pub fn inverse(&self) -> Result<Array2<A>> {
201        check_square(&self.qr)?;
202        self.solve_into(Array2::eye(self.diag.len()))
203    }
204}
205
206/// Use QR decomposition to calculate least squares by value
207pub trait LeastSquaresQrInto<B> {
208    type Output;
209
210    /// Find solution to `A * x = b` such that `||A * x - b||^2` is minimized
211    fn least_squares_into(self, b: B) -> Result<Self::Output>;
212}
213
214impl<A: NdFloat, Si: DataMut<Elem = A>, So: DataMut<Elem = A>>
215    LeastSquaresQrInto<ArrayBase<So, Ix2>> for ArrayBase<Si, Ix2>
216{
217    type Output = Array2<A>;
218
219    fn least_squares_into(self, b: ArrayBase<So, Ix2>) -> Result<Self::Output> {
220        let out = if self.nrows() >= self.ncols() {
221            self.qr_into()?.solve_into(b)?.into_owned()
222        } else {
223            // If array is fat (rows < cols) then take the QR of the transpose and run the
224            // transpose solving algorithm
225            self.reversed_axes().qr_into()?.solve_tr_into(b)?
226        };
227        Ok(out)
228    }
229}
230
231/// Use QR decomposition to calculate least squares by reference. The `A` matrix is still passed by
232/// value.
233pub trait LeastSquaresQr<B> {
234    type Output;
235
236    /// Find solution to `A * x = b` such that `||A * x - b||^2` is minimized
237    fn least_squares(self, b: &B) -> Result<Self::Output>;
238}
239
240impl<A: NdFloat, Si: DataMut<Elem = A>, So: Data<Elem = A>> LeastSquaresQr<ArrayBase<So, Ix2>>
241    for ArrayBase<Si, Ix2>
242{
243    type Output = Array2<A>;
244
245    fn least_squares(self, b: &ArrayBase<So, Ix2>) -> Result<Self::Output> {
246        self.least_squares_into(b.to_owned())
247    }
248}
249
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use super::*;

    /// Checks `Q` and `R` for a tall matrix, and the degenerate all-zeros case.
    #[test]
    fn qr() {
        let arr = array![[3.2, 1.3], [4.4, 5.2], [1.3, 6.7]];
        let (q, r) = arr.qr().unwrap().into_decomp();

        assert_abs_diff_eq!(
            q,
            array![
                [0.5720674, -0.4115578],
                [0.7865927, 0.0301901],
                [0.2324024, 0.9108835]
            ],
            epsilon = 1e-5
        );
        assert_abs_diff_eq!(r, array![[5.594, 6.391], [0., 5.725]], epsilon = 1e-3);

        // A zero matrix still decomposes: the asserted Q is the identity and R is zero.
        let zeros = Array2::<f64>::zeros((2, 2));
        let (q, r) = zeros.qr().unwrap().into_decomp();
        assert_abs_diff_eq!(q, Array2::eye(2));
        assert_abs_diff_eq!(r, zeros);
    }

    /// `solve` recovers `x` from `b = a * x` for square and tall `a`.
    #[test]
    fn solve() {
        let a = array![[1., 9.80], [-7., 3.3]];
        let x = array![[3.2, 1.3, 4.4], [5.2, 1.3, 6.7]];
        let b = a.dot(&x);
        let sol = a.qr_into().unwrap().solve(&b).unwrap();
        assert_abs_diff_eq!(sol, x, epsilon = 1e-5);

        assert_abs_diff_eq!(
            Array2::<f64>::eye(2)
                .qr_into()
                .unwrap()
                .solve(&Array2::zeros((2, 3)))
                .unwrap(),
            Array2::zeros((2, 3))
        );

        // Test with non-square matrix
        let a = array![[3.2, 1.3], [4.4, 5.2], [1.3, 6.7]];
        let x = array![[3.2, 1.3, 4.4], [5.2, 1.3, 6.7]];
        let b = a.dot(&x);
        let sol = a.qr_into().unwrap().solve(&b).unwrap();
        assert_abs_diff_eq!(sol, x, epsilon = 1e-5);
    }

    /// `solve_tr` on the decomposition of `a.t` solves `a * x = b`.
    #[test]
    fn solve_tr() {
        let a = array![[1., 9.80], [-7., 3.3]];
        let x = array![[3.2, 1.3, 4.4], [5.2, 1.3, 6.7]];
        let b = a.dot(&x);
        let sol = a.reversed_axes().qr_into().unwrap().solve_tr(&b).unwrap();
        assert_abs_diff_eq!(sol, x, epsilon = 1e-5);

        assert_abs_diff_eq!(
            Array2::<f64>::eye(2)
                .qr_into()
                .unwrap()
                .solve_tr(&Array2::zeros((2, 3)))
                .unwrap(),
            Array2::zeros((2, 3))
        );

        // Test with non-square matrix
        let a = array![[3.2, 1.3], [4.4, 5.2], [1.3, 6.7]].reversed_axes();
        let x = array![[3.2, 1.3, 4.4], [5.2, 1.3, 6.7]].reversed_axes();
        let b = a.dot(&x);
        let sol = a.t().to_owned().qr_into().unwrap().solve_tr(&b).unwrap();
        // `a` is 2x3, so `a * x = b` is underdetermined and has infinitely many solutions.
        // `solve_tr` returns the minimum-norm one, which need not equal `x`, but it must still
        // reproduce `b`.
        assert_abs_diff_eq!(b, a.dot(&sol), epsilon = 1e-7);
    }

    /// Inverse of a known 2x2 matrix, and of the identity.
    #[test]
    fn inverse() {
        let a = array![[1., 9.80], [-7., 3.3]];
        assert_abs_diff_eq!(
            a.qr_into().unwrap().inverse().unwrap(),
            array![[0.04589, -0.1363], [0.09735, 0.0139]],
            epsilon = 1e-4
        );

        assert_abs_diff_eq!(
            Array2::<f64>::eye(2).qr_into().unwrap().inverse().unwrap(),
            Array2::eye(2)
        );
    }

    /// Singular inputs are reported as `NonInvertible` by `inverse` and by least squares,
    /// for both the tall and the wide code paths.
    #[test]
    fn non_invertible() {
        let arr = Array2::<f64>::zeros((2, 2));
        assert!(matches!(
            arr.qr().unwrap().inverse().unwrap_err(),
            LinalgError::NonInvertible
        ));
        assert!(matches!(
            arr.least_squares_into(Array2::zeros((2, 2))).unwrap_err(),
            LinalgError::NonInvertible
        ));

        let wide = Array2::<f64>::zeros((2, 3));
        assert!(matches!(
            wide.least_squares_into(Array2::zeros((2, 2))).unwrap_err(),
            LinalgError::NonInvertible
        ));
    }

    /// The in-place `qt_mul` matches an explicit `Q.t * b`.
    #[test]
    fn qt_mul() {
        let a = array![[1., 9.80], [-7., 3.3]];
        let mut b = array![[3.2, 1.3, 4.4], [5.2, 1.3, 6.7]];
        let qr = a.qr_into().unwrap();
        let res = qr.generate_q().t().dot(&b);
        qr.qt_mul(&mut b);
        assert_abs_diff_eq!(b, res, epsilon = 1e-7);

        // Test with non-square matrix
        // Only the first `cols` (= 2) rows of the in-place result are compared.
        let arr = array![[3.2, 1.3], [4.4, 5.2], [1.3, 6.7]];
        let qr = arr.qr_into().unwrap();
        let mut b = array![[3.2, 1.3, 4.4], [5.2, 1.3, 6.7]].reversed_axes();
        let res = qr.generate_q().t().dot(&b);
        qr.qt_mul(&mut b);
        assert_abs_diff_eq!(b.slice(s![..2, ..2]), res, epsilon = 1e-7);
    }

    /// Empty input, wide input (`NotThin`) and `inverse` on a non-square matrix (`NotSquare`).
    #[test]
    fn corner() {
        let (q, r) = Array2::<f64>::zeros((0, 0))
            .qr_into()
            .unwrap()
            .into_decomp();
        assert!(q.is_empty());
        assert!(r.is_empty());

        assert!(matches!(
            Array2::<f64>::zeros((2, 3)).qr_into().unwrap_err(),
            LinalgError::NotThin { rows: 2, cols: 3 }
        ));
        assert!(matches!(
            Array2::<f64>::zeros((3, 2))
                .qr_into()
                .unwrap()
                .inverse()
                .unwrap_err(),
            LinalgError::NotSquare { rows: 3, cols: 2 }
        ));
    }
}