// linfa_linalg/tridiagonal.rs

//! Tridiagonal decomposition of a symmetric matrix
2
3use ndarray::{
4    linalg::{general_mat_mul, general_mat_vec_mul},
5    s, Array1, Array2, ArrayBase, Axis, DataMut, Ix2, NdFloat, RawDataClone,
6};
7
8use crate::{
9    check_square, householder,
10    triangular::{IntoTriangular, UPLO},
11    LinalgError, Result,
12};
13
14/// Tridiagonal decomposition of a non-empty symmetric matrix
15pub trait SymmetricTridiagonal {
16    type Decomp;
17
18    /// Calculate the tridiagonal decomposition of a symmetric matrix, consisting of symmetric
19    /// tridiagonal matrix `T` and orthogonal matrix `Q`, such that `Q * T * Q.t` yields the
20    /// original matrix.
21    fn sym_tridiagonal(self) -> Result<Self::Decomp>;
22}
23
24impl<S, A> SymmetricTridiagonal for ArrayBase<S, Ix2>
25where
26    A: NdFloat,
27    S: DataMut<Elem = A>,
28{
29    type Decomp = TridiagonalDecomp<A, S>;
30
31    fn sym_tridiagonal(mut self) -> Result<Self::Decomp> {
32        let n = check_square(&self)?;
33        if n < 1 {
34            return Err(LinalgError::EmptyMatrix);
35        }
36
37        let mut off_diagonal = Array1::zeros(n - 1); // TODO can be uninit
38        let mut p = Array1::zeros(n - 1);
39
40        for i in 0..n - 1 {
41            let mut m = self.slice_mut(s![i + 1.., ..]);
42            let (mut axis, mut m) = m.multi_slice_mut((s![.., i], s![.., i + 1..]));
43
44            let norm = householder::reflection_axis_mut(&mut axis);
45            *off_diagonal.get_mut(i).unwrap() = norm.unwrap_or_else(A::zero);
46
47            if norm.is_some() {
48                let mut p = p.slice_mut(s![i..]);
49                general_mat_vec_mul(A::from(2.0f64).unwrap(), &m, &axis, A::zero(), &mut p);
50                let dot = axis.dot(&p);
51
52                let p_row = p.view().insert_axis(Axis(0));
53                let p_col = p.view().insert_axis(Axis(1));
54                let ax_row = axis.view().insert_axis(Axis(0));
55                let ax_col = axis.view().insert_axis(Axis(1));
56                general_mat_mul(-A::one(), &p_col, &ax_row, A::one(), &mut m);
57                general_mat_mul(-A::one(), &ax_col, &p_row, A::one(), &mut m);
58                general_mat_mul(dot + dot, &ax_col, &ax_row, A::one(), &mut m);
59            }
60        }
61
62        Ok(TridiagonalDecomp {
63            diag_matrix: self,
64            off_diagonal,
65        })
66    }
67}
68
69/// Full tridiagonal decomposition, containing the symmetric tridiagonal matrix `T`
70#[derive(Debug)]
71pub struct TridiagonalDecomp<A, S: DataMut<Elem = A>> {
72    // This matrix is only useful for its diagonal, which is the diagonal of the tridiagonal matrix
73    // Guaranteed to be square matrix
74    diag_matrix: ArrayBase<S, Ix2>,
75    // The off-diagonal elements of the tridiagonal matrix
76    off_diagonal: Array1<A>,
77}
78
79impl<A: Clone, S: DataMut<Elem = A> + RawDataClone> Clone for TridiagonalDecomp<A, S> {
80    fn clone(&self) -> Self {
81        Self {
82            diag_matrix: self.diag_matrix.clone(),
83            off_diagonal: self.off_diagonal.clone(),
84        }
85    }
86}
87
88impl<A: NdFloat, S: DataMut<Elem = A>> TridiagonalDecomp<A, S> {
89    /// Construct the orthogonal matrix `Q`, where `Q * T * Q.t` results in the original matrix
90    pub fn generate_q(&self) -> Array2<A> {
91        householder::assemble_q(&self.diag_matrix, 1, |i| self.off_diagonal[i])
92    }
93
94    /// Return the diagonal elements and off-diagonal elements of the tridiagonal matrix as 1D
95    /// arrays
96    pub fn into_diagonals(self) -> (Array1<A>, Array1<A>) {
97        (
98            self.diag_matrix.diag().to_owned(),
99            self.off_diagonal.mapv_into(A::abs),
100        )
101    }
102
103    /// Return the full tridiagonal matrix `T`
104    pub fn into_tridiag_matrix(mut self) -> ArrayBase<S, Ix2> {
105        self.diag_matrix.triangular_inplace(UPLO::Upper).unwrap();
106        self.diag_matrix.triangular_inplace(UPLO::Lower).unwrap();
107        for (i, off) in self.off_diagonal.into_iter().enumerate() {
108            let off = off.abs();
109            self.diag_matrix[(i + 1, i)] = off;
110            self.diag_matrix[(i, i + 1)] = off;
111        }
112        self.diag_matrix
113    }
114}
115
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    use super::*;

    #[test]
    fn sym_tridiagonal() {
        let arr = array![
            [4.0f64, 1., -2., 2.],
            [1., 2., 0., 1.],
            [-2., 0., 3., -2.],
            [2., 1., -2., -1.]
        ];

        let decomp = arr.clone().sym_tridiagonal().unwrap();
        let (diag, offdiag) = decomp.into_diagonals();
        assert_abs_diff_eq!(
            diag,
            array![4., 10. / 3., -33. / 25., 149. / 75.],
            epsilon = 1e-5
        );
        assert_abs_diff_eq!(offdiag, array![3., 5. / 3., 68. / 75.], epsilon = 1e-5);

        let decomp = arr.clone().sym_tridiagonal().unwrap();
        let q = decomp.generate_q();
        let tri = decomp.into_tridiag_matrix();
        assert_abs_diff_eq!(q.dot(&tri).dot(&q.t()), arr, epsilon = 1e-9);
        // Q must be orthogonal
        assert_abs_diff_eq!(q.dot(&q.t()), Array2::eye(4), epsilon = 1e-9);

        let one = array![[1.1f64]].sym_tridiagonal().unwrap();
        let (one_diag, one_offdiag) = one.into_diagonals();
        assert_abs_diff_eq!(one_diag, array![1.1f64]);
        assert!(one_offdiag.is_empty());
    }

    #[test]
    fn sym_tridiagonal_already_diagonal() {
        // Every column is already zero below the diagonal, which exercises the path where no
        // Householder reflection is applied; the decomposition must still be exact.
        let arr = array![[3.0f64, 0., 0.], [0., -1., 0.], [0., 0., 2.]];

        let (diag, offdiag) = arr.clone().sym_tridiagonal().unwrap().into_diagonals();
        assert_abs_diff_eq!(diag, array![3., -1., 2.]);
        assert_abs_diff_eq!(offdiag, array![0., 0.]);

        let decomp = arr.clone().sym_tridiagonal().unwrap();
        let q = decomp.generate_q();
        let tri = decomp.into_tridiag_matrix();
        assert_abs_diff_eq!(q.dot(&tri).dot(&q.t()), arr, epsilon = 1e-12);
        // Q must be orthogonal
        assert_abs_diff_eq!(q.dot(&q.t()), Array2::eye(3), epsilon = 1e-12);
    }

    #[test]
    fn sym_tridiag_error() {
        assert!(matches!(
            array![[1., 2., 3.], [5., 4., 3.0f64]].sym_tridiagonal(),
            Err(LinalgError::NotSquare { rows: 2, cols: 3 })
        ));
        assert!(matches!(
            Array2::<f64>::zeros((0, 0)).sym_tridiagonal(),
            Err(LinalgError::EmptyMatrix)
        ));
    }
}