1use ndarray::{
4 linalg::{general_mat_mul, general_mat_vec_mul},
5 s, Array1, Array2, ArrayBase, Axis, DataMut, Ix2, NdFloat, RawDataClone,
6};
7
8use crate::{
9 check_square, householder,
10 triangular::{IntoTriangular, UPLO},
11 LinalgError, Result,
12};
13
14pub trait SymmetricTridiagonal {
16 type Decomp;
17
18 fn sym_tridiagonal(self) -> Result<Self::Decomp>;
22}
23
24impl<S, A> SymmetricTridiagonal for ArrayBase<S, Ix2>
25where
26 A: NdFloat,
27 S: DataMut<Elem = A>,
28{
29 type Decomp = TridiagonalDecomp<A, S>;
30
31 fn sym_tridiagonal(mut self) -> Result<Self::Decomp> {
32 let n = check_square(&self)?;
33 if n < 1 {
34 return Err(LinalgError::EmptyMatrix);
35 }
36
37 let mut off_diagonal = Array1::zeros(n - 1); let mut p = Array1::zeros(n - 1);
39
40 for i in 0..n - 1 {
41 let mut m = self.slice_mut(s![i + 1.., ..]);
42 let (mut axis, mut m) = m.multi_slice_mut((s![.., i], s![.., i + 1..]));
43
44 let norm = householder::reflection_axis_mut(&mut axis);
45 *off_diagonal.get_mut(i).unwrap() = norm.unwrap_or_else(A::zero);
46
47 if norm.is_some() {
48 let mut p = p.slice_mut(s![i..]);
49 general_mat_vec_mul(A::from(2.0f64).unwrap(), &m, &axis, A::zero(), &mut p);
50 let dot = axis.dot(&p);
51
52 let p_row = p.view().insert_axis(Axis(0));
53 let p_col = p.view().insert_axis(Axis(1));
54 let ax_row = axis.view().insert_axis(Axis(0));
55 let ax_col = axis.view().insert_axis(Axis(1));
56 general_mat_mul(-A::one(), &p_col, &ax_row, A::one(), &mut m);
57 general_mat_mul(-A::one(), &ax_col, &p_row, A::one(), &mut m);
58 general_mat_mul(dot + dot, &ax_col, &ax_row, A::one(), &mut m);
59 }
60 }
61
62 Ok(TridiagonalDecomp {
63 diag_matrix: self,
64 off_diagonal,
65 })
66 }
67}
68
69#[derive(Debug)]
71pub struct TridiagonalDecomp<A, S: DataMut<Elem = A>> {
72 diag_matrix: ArrayBase<S, Ix2>,
75 off_diagonal: Array1<A>,
77}
78
79impl<A: Clone, S: DataMut<Elem = A> + RawDataClone> Clone for TridiagonalDecomp<A, S> {
80 fn clone(&self) -> Self {
81 Self {
82 diag_matrix: self.diag_matrix.clone(),
83 off_diagonal: self.off_diagonal.clone(),
84 }
85 }
86}
87
88impl<A: NdFloat, S: DataMut<Elem = A>> TridiagonalDecomp<A, S> {
89 pub fn generate_q(&self) -> Array2<A> {
91 householder::assemble_q(&self.diag_matrix, 1, |i| self.off_diagonal[i])
92 }
93
94 pub fn into_diagonals(self) -> (Array1<A>, Array1<A>) {
97 (
98 self.diag_matrix.diag().to_owned(),
99 self.off_diagonal.mapv_into(A::abs),
100 )
101 }
102
103 pub fn into_tridiag_matrix(mut self) -> ArrayBase<S, Ix2> {
105 self.diag_matrix.triangular_inplace(UPLO::Upper).unwrap();
106 self.diag_matrix.triangular_inplace(UPLO::Lower).unwrap();
107 for (i, off) in self.off_diagonal.into_iter().enumerate() {
108 let off = off.abs();
109 self.diag_matrix[(i + 1, i)] = off;
110 self.diag_matrix[(i, i + 1)] = off;
111 }
112 self.diag_matrix
113 }
114}
115
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    use super::*;

    /// Symmetric 4x4 fixture used by the decomposition test.
    fn symmetric_4x4() -> Array2<f64> {
        array![
            [4.0f64, 1., -2., 2.],
            [1., 2., 0., 1.],
            [-2., 0., 3., -2.],
            [2., 1., -2., -1.]
        ]
    }

    #[test]
    fn sym_tridiagonal() {
        let arr = symmetric_4x4();

        // The diagonals must match the known reference values.
        let (diag, offdiag) = arr.clone().sym_tridiagonal().unwrap().into_diagonals();
        assert_abs_diff_eq!(
            diag,
            array![4., 10. / 3., -33. / 25., 149. / 75.],
            epsilon = 1e-5
        );
        assert_abs_diff_eq!(offdiag, array![3., 5. / 3., 68. / 75.], epsilon = 1e-5);

        // Q * T * Q^T must reproduce the input, and Q must be orthogonal.
        let decomp = arr.clone().sym_tridiagonal().unwrap();
        let q = decomp.generate_q();
        let tri = decomp.into_tridiag_matrix();
        assert_abs_diff_eq!(q.dot(&tri).dot(&q.t()), arr, epsilon = 1e-9);
        assert_abs_diff_eq!(q.dot(&q.t()), Array2::eye(4), epsilon = 1e-9);

        // A 1x1 matrix is its own diagonal and has no off-diagonal entries.
        let single = array![[1.1f64]].sym_tridiagonal().unwrap();
        let (one_diag, one_offdiag) = single.into_diagonals();
        assert_abs_diff_eq!(one_diag, array![1.1f64]);
        assert!(one_offdiag.is_empty());
    }

    #[test]
    fn sym_tridiag_error() {
        // Non-square input is rejected with its dimensions reported.
        let non_square = array![[1., 2., 3.], [5., 4., 3.0f64]].sym_tridiagonal();
        assert!(matches!(
            non_square,
            Err(LinalgError::NotSquare { rows: 2, cols: 3 })
        ));

        // An empty matrix is rejected as well.
        let empty = Array2::<f64>::zeros((0, 0)).sym_tridiagonal();
        assert!(matches!(empty, Err(LinalgError::EmptyMatrix)));
    }
}