linfa_linalg/
eigh.rs

1//! Eigendecomposition for symmetric square matrices
2
3use ndarray::{s, Array1, Array2, ArrayBase, Data, DataMut, Ix2, NdFloat};
4
5use crate::{
6    check_square, givens::GivensRotation, index::*, tridiagonal::SymmetricTridiagonal, Order,
7    Result,
8};
9
10fn symmetric_eig<A: NdFloat, S: DataMut<Elem = A>>(
11    mut matrix: ArrayBase<S, Ix2>,
12    eigenvectors: bool,
13    eps: A,
14) -> Result<(Array1<A>, Option<Array2<A>>)> {
15    let dim = check_square(&matrix)?;
16    if dim < 1 {
17        return Ok((
18            Array1::zeros(0),
19            if eigenvectors {
20                Some(Array2::zeros((0, 0)))
21            } else {
22                None
23            },
24        ));
25    }
26
27    let amax = matrix
28        .iter()
29        .map(|f| f.abs())
30        .fold(A::neg_infinity(), |a, b| a.max(b));
31
32    if amax != A::zero() {
33        matrix /= amax;
34    }
35
36    let tridiag_decomp = matrix.sym_tridiagonal()?;
37    let mut q_mat = if eigenvectors {
38        Some(tridiag_decomp.generate_q())
39    } else {
40        None
41    };
42    let (mut diag, mut off_diag) = tridiag_decomp.into_diagonals();
43
44    if dim == 1 {
45        diag *= amax;
46        return Ok((diag, q_mat));
47    }
48
49    let (mut start, mut end) = delimit_subproblem(&diag, &mut off_diag, dim - 1, eps);
50
51    while end != start {
52        let subdim = end - start + 1;
53
54        #[allow(clippy::comparison_chain)]
55        if subdim > 2 {
56            let m = end - 1;
57            let n = end;
58
59            let mut x = diag[start] - wilkinson_shift(diag[m], diag[n], off_diag[m]);
60            let mut y = off_diag[start];
61
62            for i in start..n {
63                let j = i + 1;
64
65                if let Some((rot, norm)) = GivensRotation::cancel_y(x, y) {
66                    if i > start {
67                        unsafe { *off_diag.atm(i - 1) = norm };
68                    }
69
70                    let cc = rot.c() * rot.c();
71                    let ss = rot.s() * rot.s();
72                    let cs = rot.c() * rot.s();
73                    unsafe {
74                        let mii = *diag.at(i);
75                        let mjj = *diag.at(j);
76                        let mij = *off_diag.at(i);
77                        let b = cs * mij * A::from(2.0f64).unwrap();
78                        *diag.atm(i) = cc * mii + ss * mjj - b;
79                        *diag.atm(j) = ss * mii + cc * mjj + b;
80                        *off_diag.atm(i) = cs * (mii - mjj) + mij * (cc - ss);
81
82                        if i != n - 1 {
83                            x = *off_diag.at(i);
84                            y = -rot.s() * *off_diag.at(i + 1);
85                            *off_diag.atm(i + 1) *= rot.c();
86                        }
87                    }
88
89                    if let Some(q) = &mut q_mat {
90                        rot.clone()
91                            .inverse()
92                            .rotate_rows(&mut q.slice_mut(s![.., i..i + 2]))
93                            .unwrap();
94                    }
95                } else {
96                    break;
97                }
98            }
99
100            if off_diag[m].abs() <= eps * (diag[m].abs() + diag[n].abs()) {
101                end -= 1;
102            }
103        } else if subdim == 2 {
104            let eigvals = compute_2x2_eigvals(
105                diag[start],
106                off_diag[start],
107                off_diag[start],
108                diag[start + 1],
109            )
110            .unwrap(); // XXX not sure when this unwrap panics
111            let basis = (eigvals.0 - diag[start + 1], off_diag[start]);
112
113            diag[start] = eigvals.0;
114            diag[start + 1] = eigvals.1;
115
116            if let (Some(q), Some((rot, _))) =
117                (&mut q_mat, GivensRotation::try_new(basis.0, basis.1, eps))
118            {
119                rot.rotate_rows(&mut q.slice_mut(s![.., start..start + 2]))
120                    .unwrap();
121            }
122            end -= 1;
123        }
124
125        let sub = delimit_subproblem(&diag, &mut off_diag, end, eps);
126        start = sub.0;
127        end = sub.1;
128    }
129
130    diag *= amax;
131    Ok((diag, q_mat))
132}
133
134fn delimit_subproblem<A: NdFloat>(
135    diag: &Array1<A>,
136    off_diag: &mut Array1<A>,
137    end: usize,
138    eps: A,
139) -> (usize, usize) {
140    let mut n = end;
141
142    while n > 0 {
143        let m = n - 1;
144        unsafe {
145            if off_diag.at(m).abs() > eps * (diag.at(n).abs() + diag.at(m).abs()) {
146                break;
147            }
148        }
149        n -= 1;
150    }
151
152    if n == 0 {
153        return (0, 0);
154    }
155
156    let mut new_start = n - 1;
157    while new_start > 0 {
158        let m = new_start - 1;
159        unsafe {
160            if off_diag.at(m).is_zero()
161                || off_diag.at(m).abs() <= eps * (diag.at(new_start).abs() + diag.at(m).abs())
162            {
163                *off_diag.atm(m) = A::zero();
164                break;
165            }
166        }
167        new_start -= 1;
168    }
169
170    (new_start, n)
171}
172
173/// Computes the wilkinson shift, i.e., the 2x2 symmetric matrix eigenvalue to its tailing
174/// component `tnn`.
175///
176/// The inputs are interpreted as the 2x2 matrix:
177///     tmm  tmn
178///     tmn  tnn
179pub(crate) fn wilkinson_shift<A: NdFloat>(tmm: A, tnn: A, tmn: A) -> A {
180    if !tmn.is_zero() {
181        let tmn_sq = tmn * tmn;
182        let d = (tmm - tnn) * A::from(0.5).unwrap();
183        tnn - tmn_sq / (d + d.signum() * (d * d + tmn_sq).sqrt())
184    } else {
185        tnn
186    }
187}
188
189fn compute_2x2_eigvals<A: NdFloat>(h00: A, h10: A, h01: A, h11: A) -> Option<(A, A)> {
190    let val = (h00 - h11) * A::from(0.5f64).unwrap();
191    let discr = h10 * h01 + val * val;
192    if discr >= A::zero() {
193        let sqrt_discr = discr.sqrt();
194        let half_tra = (h00 + h11) * A::from(0.5f64).unwrap();
195        Some((half_tra + sqrt_discr, half_tra - sqrt_discr))
196    } else {
197        None
198    }
199}
200
201/// Eigendecomposition of symmetric matrices
202pub trait EighInto: Sized {
203    type EigVal;
204    type EigVec;
205
206    /// Calculate eigenvalues and eigenvectors of symmetric matrices, consuming the original
207    fn eigh_into(self) -> Result<(Self::EigVal, Self::EigVec)>;
208}
209
210impl<A: NdFloat, S: DataMut<Elem = A>> EighInto for ArrayBase<S, Ix2> {
211    type EigVal = Array1<A>;
212    type EigVec = Array2<A>;
213
214    fn eigh_into(self) -> Result<(Self::EigVal, Self::EigVec)> {
215        let (val, vecs) = symmetric_eig(self, true, A::epsilon())?;
216        Ok((val, vecs.unwrap()))
217    }
218}
219
220/// Eigendecomposition of symmetric matrices
221pub trait Eigh {
222    type EigVal;
223    type EigVec;
224
225    /// Calculate eigenvalues and eigenvectors of symmetric matrices
226    fn eigh(&self) -> Result<(Self::EigVal, Self::EigVec)>;
227}
228
229impl<A: NdFloat, S: Data<Elem = A>> Eigh for ArrayBase<S, Ix2> {
230    type EigVal = Array1<A>;
231    type EigVec = Array2<A>;
232
233    fn eigh(&self) -> Result<(Self::EigVal, Self::EigVec)> {
234        self.to_owned().eigh_into()
235    }
236}
237
238/// Eigenvalues of symmetric matrices
239pub trait EigValshInto {
240    type EigVal;
241
242    /// Calculate eigenvalues of symmetric matrices without eigenvectors, consuming the original
243    fn eigvalsh_into(self) -> Result<Self::EigVal>;
244}
245
246impl<A: NdFloat, S: DataMut<Elem = A>> EigValshInto for ArrayBase<S, Ix2> {
247    type EigVal = Array1<A>;
248
249    fn eigvalsh_into(self) -> Result<Self::EigVal> {
250        symmetric_eig(self, false, A::epsilon()).map(|(vals, _)| vals)
251    }
252}
253
254/// Eigenvalues of symmetric matrices
255pub trait EigValsh {
256    type EigVal;
257
258    /// Calculate eigenvalues of symmetric matrices without eigenvectors
259    fn eigvalsh(&self) -> Result<Self::EigVal>;
260}
261
262impl<A: NdFloat, S: Data<Elem = A>> EigValsh for ArrayBase<S, Ix2> {
263    type EigVal = Array1<A>;
264
265    fn eigvalsh(&self) -> Result<Self::EigVal> {
266        self.to_owned().eigvalsh_into()
267    }
268}
269
270/// Sorting of eigendecomposition by the eigenvalues.
271///
272/// ## Panic
273///
274/// Will panic if shape or layout of inputs differ from eigen output, or if input contains NaN.
275pub trait EigSort: Sized {
276    fn sort_eig(self, order: Order) -> Self;
277
278    /// Sort eigendecomposition by the eigenvalues in ascending order
279    fn sort_eig_asc(self) -> Self {
280        self.sort_eig(Order::Smallest)
281    }
282
283    /// Sort eigendecomposition by the eigenvalues in descending order
284    fn sort_eig_desc(self) -> Self {
285        self.sort_eig(Order::Largest)
286    }
287}
288
289/// Implementation on output of `EigValsh` traits
290impl<A: NdFloat> EigSort for Array1<A> {
291    fn sort_eig(mut self, order: Order) -> Self {
292        // Panics on non-standard layouts, which is fine because our eigenvals have standard layout
293        let slice = self.as_slice_mut().unwrap();
294        // Panic only happens with NaN values
295        match order {
296            Order::Largest => slice.sort_by(|a, b| cmp_floats(b, a)),
297            Order::Smallest => slice.sort_by(|a, b| cmp_floats(a, b)),
298        }
299        self
300    }
301}
302
303/// Implementation on output of `Eigh` traits
304impl<A: NdFloat> EigSort for (Array1<A>, Array2<A>) {
305    fn sort_eig(self, order: Order) -> Self {
306        let (mut vals, vecs) = self;
307        let mut value_idx: Vec<_> = vals.iter().copied().enumerate().collect();
308        // Panic only happens with NaN values
309        match order {
310            Order::Largest => value_idx.sort_by(|a, b| cmp_floats(&b.1, &a.1)),
311            Order::Smallest => value_idx.sort_by(|a, b| cmp_floats(&a.1, &b.1)),
312        }
313
314        let mut out = Array2::zeros(vecs.dim());
315        for (out_idx, &(arr_idx, _)) in value_idx.iter().enumerate() {
316            out.column_mut(out_idx).assign(&vecs.column(arr_idx));
317        }
318        vals.iter_mut()
319            .zip(value_idx.iter())
320            .for_each(|(si, (_, f))| *si = *f);
321        (vals, out)
322    }
323}
324
325#[inline]
326pub(crate) fn cmp_floats<A: NdFloat>(a: &A, b: &A) -> std::cmp::Ordering {
327    a.partial_cmp(b).expect("NaN values in array")
328}
329
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::array;
    use ndarray::Axis;

    use crate::LinalgError;

    use super::*;

    #[test]
    fn eigvals_2x2() {
        let (e1, e2) = compute_2x2_eigvals(5., 4., 3., 2.).unwrap();
        assert_abs_diff_eq!(e1, 7.2749172, epsilon = 1e-5);
        assert_abs_diff_eq!(e2, -0.2749172, epsilon = 1e-5);

        let (e1, e2) = compute_2x2_eigvals(6., 2., -1., 3.).unwrap();
        assert_abs_diff_eq!(e1, 5., epsilon = 1e-5);
        assert_abs_diff_eq!(e2, 4., epsilon = 1e-5);

        let (e1, e2) = compute_2x2_eigvals(6., 2., 2., 6.).unwrap();
        assert_abs_diff_eq!(e1, 8., epsilon = 1e-5);
        assert_abs_diff_eq!(e2, 4., epsilon = 1e-5);

        // Complex eigenvalues are reported as `None`.
        assert_eq!(compute_2x2_eigvals(-2., 3., -3., -2.), None);
    }

    #[test]
    fn wilkinson_shift_picks_closest_eigenvalue() {
        // [[1, 2], [2, 4]] has eigenvalues 0 and 5; 5 is closest to the trailing entry 4.
        assert_abs_diff_eq!(wilkinson_shift(1., 4., 2.), 5., epsilon = 1e-12);
        // [[4, 2], [2, 1]] has the same eigenvalues; 0 is closest to the trailing entry 1.
        assert_abs_diff_eq!(wilkinson_shift(4., 1., 2.), 0., epsilon = 1e-12);
        // Already-diagonal block: the shift is the trailing entry itself.
        assert_eq!(wilkinson_shift(3., -7., 0.), -7.);
    }

    #[test]
    fn sort_eig_orders_values_and_columns() {
        let vals = array![2., -1., 3.];
        assert_eq!(vals.clone().sort_eig_asc(), array![-1., 2., 3.]);
        assert_eq!(vals.sort_eig_desc(), array![3., 2., -1.]);

        // Columns must follow their eigenvalues.
        let vecs = array![[1., 2., 3.], [4., 5., 6.]];
        let (vals, vecs) = (array![1., 3., 2.], vecs).sort_eig_desc();
        assert_eq!(vals, array![3., 2., 1.]);
        assert_eq!(vecs, array![[2., 3., 1.], [5., 6., 4.]]);
    }

    #[test]
    fn symm_eigvals() {
        let (vals, vecs) = symmetric_eig(array![[6., 2.], [2., 6.]], false, f64::EPSILON).unwrap();
        assert_abs_diff_eq!(vals, array![8., 4.]);
        assert_eq!(vecs, None);

        let (vals, vecs) = symmetric_eig(
            array![[1., -5., 7.], [-5., 2., -9.], [7., -9., 3.]],
            false,
            f64::EPSILON,
        )
        .unwrap();
        let vals = vals.sort_eig_asc();
        assert_abs_diff_eq!(vals, array![-6.86819, -3.41558, 16.28378], epsilon = 1e-5);
        assert_eq!(vecs, None);
    }

    /// Checks eigenvalues against `exp_vals` (descending), orthonormality of the eigenvectors,
    /// and `a * v = lambda * v` for every eigenpair.
    fn test_eigvecs(a: Array2<f64>, exp_vals: Array1<f64>) {
        let n = a.nrows();
        let (vals, vecs) = symmetric_eig(a.clone(), true, f64::EPSILON).unwrap();
        let (vals, vecs) = (vals, vecs.unwrap()).sort_eig_desc();
        assert_abs_diff_eq!(vals, exp_vals, epsilon = 1e-5);

        let s = vecs.t().dot(&vecs);
        assert_abs_diff_eq!(s, Array2::eye(n), epsilon = 1e-5);

        for (i, v) in vecs.axis_iter(Axis(1)).enumerate() {
            let av = a.dot(&v);
            let ev = v.mapv(|x| vals[i] * x);
            assert_abs_diff_eq!(av, ev, epsilon = 1e-5);
        }
    }

    #[test]
    fn sym_eigvecs1() {
        test_eigvecs(
            array![[3., 1., 1.], [1., 3., 1.], [1., 1., 3.]],
            array![5., 2., 2.],
        );
    }

    #[test]
    fn sym_eigvecs2() {
        test_eigvecs(array![[6., 2.], [2., 6.]], array![8., 4.]);
    }

    #[test]
    fn sym_eigvecs3() {
        test_eigvecs(
            array![[1., -5., 7.], [-5., 2., -9.], [7., -9., 3.]],
            array![16.28378, -3.41558, -6.86819],
        );
    }

    #[test]
    fn corner() {
        assert_eq!(
            symmetric_eig(Array2::zeros((0, 0)), false, f64::EPSILON).unwrap(),
            (Array1::zeros(0), None)
        );
        assert_eq!(
            symmetric_eig(Array2::zeros((0, 0)), true, f64::EPSILON).unwrap(),
            (Array1::zeros(0), Some(Array2::zeros((0, 0))))
        );

        symmetric_eig(Array2::zeros((1, 1)), true, f64::EPSILON).unwrap();
        symmetric_eig(Array2::zeros((4, 4)), true, f64::EPSILON).unwrap();
        assert!(matches!(
            symmetric_eig(Array2::zeros((3, 1)), true, f64::EPSILON),
            Err(LinalgError::NotSquare { rows: 3, cols: 1 })
        ));
        // Non-symmetric cases
        symmetric_eig(array![[5., 4.], [3., 2.]], true, f64::EPSILON).unwrap();
        symmetric_eig(array![[-2., 3.], [-3., -2.]], true, f64::EPSILON).unwrap();
    }
}