1use ndarray::{s, Array1, Array2, ArrayBase, Data, DataMut, Ix2, NdFloat};
4
5use crate::{
6 check_square, givens::GivensRotation, index::*, tridiagonal::SymmetricTridiagonal, Order,
7 Result,
8};
9
/// Eigendecomposition of a square matrix that is treated as symmetric.
///
/// The matrix is divided by its largest absolute entry, reduced to symmetric
/// tridiagonal form with `sym_tridiagonal`, and the tridiagonal eigenproblem is
/// then solved with implicit QR sweeps (Wilkinson shift, Givens rotations) plus
/// a direct solve for 2x2 blocks. Finally the eigenvalues are multiplied back by
/// the scale factor.
///
/// * `matrix` - consumed; only squareness is checked. Symmetry is not verified
///   and non-symmetric input is not rejected (which entries are actually read
///   is up to `sym_tridiagonal`). If `S` is a mutable view, the viewed data is
///   modified (at least by the in-place scaling below).
/// * `eigenvectors` - when `true`, the orthogonal transform is accumulated and
///   returned; its columns are the eigenvectors, in the same order as the
///   returned eigenvalues.
/// * `eps` - relative deflation tolerance: an off-diagonal entry `e[m]` is
///   treated as zero once `|e[m]| <= eps * (|d[m]| + |d[m + 1]|)`.
///
/// Returns `(eigenvalues, eigenvectors)`, where `eigenvectors` is `Some` only
/// when requested. Eigenvalues are **not** sorted; see [`EigSort`].
///
/// # Errors
///
/// Propagates the error from `check_square` (e.g. `LinalgError::NotSquare`)
/// and any error returned by `sym_tridiagonal`.
///
/// # Panics
///
/// The 2x2 branch unwraps `compute_2x2_eigvals`; there the discriminant is
/// `e² + ((a - c) / 2)²`, which is non-negative for finite input, so this can
/// only fail on NaN. The `rotate_rows(..).unwrap()` calls panic if that helper
/// returns an error (presumably a shape mismatch — TODO confirm).
fn symmetric_eig<A: NdFloat, S: DataMut<Elem = A>>(
    mut matrix: ArrayBase<S, Ix2>,
    eigenvectors: bool,
    eps: A,
) -> Result<(Array1<A>, Option<Array2<A>>)> {
    let dim = check_square(&matrix)?;
    // Empty input: nothing to decompose, but keep the requested output shape.
    if dim < 1 {
        return Ok((
            Array1::zeros(0),
            if eigenvectors {
                Some(Array2::zeros((0, 0)))
            } else {
                None
            },
        ));
    }

    // Largest absolute entry. Dividing by it puts every entry in [-1, 1]; the
    // eigenvalues are multiplied by `amax` again before returning. A zero
    // matrix is left unscaled to avoid dividing by zero.
    let amax = matrix
        .iter()
        .map(|f| f.abs())
        .fold(A::neg_infinity(), |a, b| a.max(b));

    if amax != A::zero() {
        matrix /= amax;
    }

    let tridiag_decomp = matrix.sym_tridiagonal()?;
    // Q of the tridiagonal reduction; every later rotation is accumulated into
    // it so that its columns end up being the eigenvectors.
    let mut q_mat = if eigenvectors {
        Some(tridiag_decomp.generate_q())
    } else {
        None
    };
    // `diag[k]` is the main diagonal; `off_diag[k]` couples `diag[k]` and
    // `diag[k + 1]`.
    let (mut diag, mut off_diag) = tridiag_decomp.into_diagonals();

    // A 1x1 matrix has no off-diagonal: its single diagonal entry is already
    // the eigenvalue.
    if dim == 1 {
        diag *= amax;
        return Ok((diag, q_mat));
    }

    // Inclusive row range [start, end] of the trailing block that is still
    // coupled; start == end means everything has been deflated.
    let (mut start, mut end) = delimit_subproblem(&diag, &mut off_diag, dim - 1, eps);

    while end != start {
        let subdim = end - start + 1;

        // A `match` on `subdim.cmp(&2)` would need an unreachable `Less` arm:
        // `delimit_subproblem` returns `start < end` here, so `subdim >= 2`.
        #[allow(clippy::comparison_chain)]
        if subdim > 2 {
            let m = end - 1;
            let n = end;

            // Implicit QR sweep: (x, y) is the first column of
            // (T - shift * I) restricted to the active block. Each rotation
            // zeroes `y` and pushes the resulting bulge one row further down.
            let mut x = diag[start] - wilkinson_shift(diag[m], diag[n], off_diag[m]);
            let mut y = off_diag[start];

            for i in start..n {
                let j = i + 1;

                if let Some((rot, norm)) = GivensRotation::cancel_y(x, y) {
                    if i > start {
                        // The previous rotation's bulge collapses into
                        // `off_diag[i - 1]`.
                        // SAFETY: start < i < n <= dim - 1, so i - 1 is a valid
                        // index into `off_diag` (assumed to have length
                        // dim - 1, as returned by `into_diagonals`).
                        unsafe { *off_diag.atm(i - 1) = norm };
                    }

                    let cc = rot.c() * rot.c();
                    let ss = rot.s() * rot.s();
                    let cs = rot.c() * rot.s();
                    // SAFETY: i < j = i + 1 <= n <= dim - 1 < diag.len(), and
                    // i <= n - 1 < off_diag.len(); `off_diag.at(i + 1)` is only
                    // reached when i + 1 <= n - 1.
                    unsafe {
                        // Apply the rotation to the 2x2 diagonal block
                        // [[mii, mij], [mij, mjj]] from both sides.
                        let mii = *diag.at(i);
                        let mjj = *diag.at(j);
                        let mij = *off_diag.at(i);
                        let b = cs * mij * A::from(2.0f64).unwrap();
                        *diag.atm(i) = cc * mii + ss * mjj - b;
                        *diag.atm(j) = ss * mii + cc * mjj + b;
                        *off_diag.atm(i) = cs * (mii - mjj) + mij * (cc - ss);

                        // The rotation creates a bulge below the tridiagonal
                        // band; set up the next (x, y) to chase it.
                        if i != n - 1 {
                            x = *off_diag.at(i);
                            y = -rot.s() * *off_diag.at(i + 1);
                            *off_diag.atm(i + 1) *= rot.c();
                        }
                    }

                    // Accumulate the rotation into columns i and i + 1 of Q.
                    if let Some(q) = &mut q_mat {
                        rot.clone()
                            .inverse()
                            .rotate_rows(&mut q.slice_mut(s![.., i..i + 2]))
                            .unwrap();
                    }
                } else {
                    // `cancel_y` found nothing to rotate (presumably `y` is
                    // already zero), so the bulge chase stops here.
                    break;
                }
            }

            // Deflate the bottom row once its coupling has become negligible.
            if off_diag[m].abs() <= eps * (diag[m].abs() + diag[n].abs()) {
                end -= 1;
            }
        } else if subdim == 2 {
            // Solve the trailing 2x2 block [[a, e], [e, c]] directly.
            let eigvals = compute_2x2_eigvals(
                diag[start],
                off_diag[start],
                off_diag[start],
                diag[start + 1],
            )
            .unwrap();
            // (λ₁ - c, e) is an eigenvector of [[a, e], [e, c]] for the larger
            // eigenvalue λ₁ = eigvals.0; the rotation built from it maps
            // columns `start`/`start + 1` of Q onto the block's eigenvectors.
            let basis = (eigvals.0 - diag[start + 1], off_diag[start]);

            diag[start] = eigvals.0;
            diag[start + 1] = eigvals.1;

            // When `try_new` returns `None` (presumably the basis norm is below
            // `eps`), Q is left unchanged.
            if let (Some(q), Some((rot, _))) =
                (&mut q_mat, GivensRotation::try_new(basis.0, basis.1, eps))
            {
                rot.rotate_rows(&mut q.slice_mut(s![.., start..start + 2]))
                    .unwrap();
            }
            end -= 1;
        }

        // Re-delimit the active block in case some coupling became negligible.
        let sub = delimit_subproblem(&diag, &mut off_diag, end, eps);
        start = sub.0;
        end = sub.1;
    }

    // Undo the initial normalisation.
    diag *= amax;
    Ok((diag, q_mat))
}
133
134fn delimit_subproblem<A: NdFloat>(
135 diag: &Array1<A>,
136 off_diag: &mut Array1<A>,
137 end: usize,
138 eps: A,
139) -> (usize, usize) {
140 let mut n = end;
141
142 while n > 0 {
143 let m = n - 1;
144 unsafe {
145 if off_diag.at(m).abs() > eps * (diag.at(n).abs() + diag.at(m).abs()) {
146 break;
147 }
148 }
149 n -= 1;
150 }
151
152 if n == 0 {
153 return (0, 0);
154 }
155
156 let mut new_start = n - 1;
157 while new_start > 0 {
158 let m = new_start - 1;
159 unsafe {
160 if off_diag.at(m).is_zero()
161 || off_diag.at(m).abs() <= eps * (diag.at(new_start).abs() + diag.at(m).abs())
162 {
163 *off_diag.atm(m) = A::zero();
164 break;
165 }
166 }
167 new_start -= 1;
168 }
169
170 (new_start, n)
171}
172
173pub(crate) fn wilkinson_shift<A: NdFloat>(tmm: A, tnn: A, tmn: A) -> A {
180 if !tmn.is_zero() {
181 let tmn_sq = tmn * tmn;
182 let d = (tmm - tnn) * A::from(0.5).unwrap();
183 tnn - tmn_sq / (d + d.signum() * (d * d + tmn_sq).sqrt())
184 } else {
185 tnn
186 }
187}
188
189fn compute_2x2_eigvals<A: NdFloat>(h00: A, h10: A, h01: A, h11: A) -> Option<(A, A)> {
190 let val = (h00 - h11) * A::from(0.5f64).unwrap();
191 let discr = h10 * h01 + val * val;
192 if discr >= A::zero() {
193 let sqrt_discr = discr.sqrt();
194 let half_tra = (h00 + h11) * A::from(0.5f64).unwrap();
195 Some((half_tra + sqrt_discr, half_tra - sqrt_discr))
196 } else {
197 None
198 }
199}
200
201pub trait EighInto: Sized {
203 type EigVal;
204 type EigVec;
205
206 fn eigh_into(self) -> Result<(Self::EigVal, Self::EigVec)>;
208}
209
210impl<A: NdFloat, S: DataMut<Elem = A>> EighInto for ArrayBase<S, Ix2> {
211 type EigVal = Array1<A>;
212 type EigVec = Array2<A>;
213
214 fn eigh_into(self) -> Result<(Self::EigVal, Self::EigVec)> {
215 let (val, vecs) = symmetric_eig(self, true, A::epsilon())?;
216 Ok((val, vecs.unwrap()))
217 }
218}
219
220pub trait Eigh {
222 type EigVal;
223 type EigVec;
224
225 fn eigh(&self) -> Result<(Self::EigVal, Self::EigVec)>;
227}
228
229impl<A: NdFloat, S: Data<Elem = A>> Eigh for ArrayBase<S, Ix2> {
230 type EigVal = Array1<A>;
231 type EigVec = Array2<A>;
232
233 fn eigh(&self) -> Result<(Self::EigVal, Self::EigVec)> {
234 self.to_owned().eigh_into()
235 }
236}
237
238pub trait EigValshInto {
240 type EigVal;
241
242 fn eigvalsh_into(self) -> Result<Self::EigVal>;
244}
245
246impl<A: NdFloat, S: DataMut<Elem = A>> EigValshInto for ArrayBase<S, Ix2> {
247 type EigVal = Array1<A>;
248
249 fn eigvalsh_into(self) -> Result<Self::EigVal> {
250 symmetric_eig(self, false, A::epsilon()).map(|(vals, _)| vals)
251 }
252}
253
254pub trait EigValsh {
256 type EigVal;
257
258 fn eigvalsh(&self) -> Result<Self::EigVal>;
260}
261
262impl<A: NdFloat, S: Data<Elem = A>> EigValsh for ArrayBase<S, Ix2> {
263 type EigVal = Array1<A>;
264
265 fn eigvalsh(&self) -> Result<Self::EigVal> {
266 self.to_owned().eigvalsh_into()
267 }
268}
269
270pub trait EigSort: Sized {
276 fn sort_eig(self, order: Order) -> Self;
277
278 fn sort_eig_asc(self) -> Self {
280 self.sort_eig(Order::Smallest)
281 }
282
283 fn sort_eig_desc(self) -> Self {
285 self.sort_eig(Order::Largest)
286 }
287}
288
289impl<A: NdFloat> EigSort for Array1<A> {
291 fn sort_eig(mut self, order: Order) -> Self {
292 let slice = self.as_slice_mut().unwrap();
294 match order {
296 Order::Largest => slice.sort_by(|a, b| cmp_floats(b, a)),
297 Order::Smallest => slice.sort_by(|a, b| cmp_floats(a, b)),
298 }
299 self
300 }
301}
302
303impl<A: NdFloat> EigSort for (Array1<A>, Array2<A>) {
305 fn sort_eig(self, order: Order) -> Self {
306 let (mut vals, vecs) = self;
307 let mut value_idx: Vec<_> = vals.iter().copied().enumerate().collect();
308 match order {
310 Order::Largest => value_idx.sort_by(|a, b| cmp_floats(&b.1, &a.1)),
311 Order::Smallest => value_idx.sort_by(|a, b| cmp_floats(&a.1, &b.1)),
312 }
313
314 let mut out = Array2::zeros(vecs.dim());
315 for (out_idx, &(arr_idx, _)) in value_idx.iter().enumerate() {
316 out.column_mut(out_idx).assign(&vecs.column(arr_idx));
317 }
318 vals.iter_mut()
319 .zip(value_idx.iter())
320 .for_each(|(si, (_, f))| *si = *f);
321 (vals, out)
322 }
323}
324
325#[inline]
326pub(crate) fn cmp_floats<A: NdFloat>(a: &A, b: &A) -> std::cmp::Ordering {
327 a.partial_cmp(b).expect("NaN values in array")
328}
329
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Axis};

    use crate::LinalgError;

    use super::*;

    #[test]
    fn eigvals_2x2() {
        // (h00, h10, h01, h11) followed by the expected (larger, smaller) pair.
        let real_cases = [
            ((5., 4., 3., 2.), (7.2749172, -0.2749172)),
            ((6., 2., -1., 3.), (5., 4.)),
            ((6., 2., 2., 6.), (8., 4.)),
        ];
        for &((h00, h10, h01, h11), (hi, lo)) in real_cases.iter() {
            let (e1, e2) = compute_2x2_eigvals(h00, h10, h01, h11).unwrap();
            assert_abs_diff_eq!(e1, hi, epsilon = 1e-5);
            assert_abs_diff_eq!(e2, lo, epsilon = 1e-5);
        }

        // Complex-conjugate eigenvalues are reported as `None`.
        assert_eq!(compute_2x2_eigvals(-2., 3., -3., -2.), None);
    }

    #[test]
    fn symm_eigvals() {
        let (vals, vecs) = symmetric_eig(array![[6., 2.], [2., 6.]], false, f64::EPSILON).unwrap();
        assert!(vecs.is_none());
        assert_abs_diff_eq!(vals, array![8., 4.]);

        let a = array![[1., -5., 7.], [-5., 2., -9.], [7., -9., 3.]];
        let (vals, vecs) = symmetric_eig(a, false, f64::EPSILON).unwrap();
        assert!(vecs.is_none());
        assert_abs_diff_eq!(
            vals.sort_eig_asc(),
            array![-6.86819, -3.41558, 16.28378],
            epsilon = 1e-5
        );
    }

    /// Checks that the eigenvalues match `exp_vals` (descending order), that
    /// the eigenvectors are orthonormal, and that `a · v_i == λ_i · v_i` holds
    /// for every column `v_i`.
    fn test_eigvecs(a: Array2<f64>, exp_vals: Array1<f64>) {
        let n = a.nrows();
        let (vals, vecs) = symmetric_eig(a.clone(), true, f64::EPSILON).unwrap();
        let (vals, vecs) = (vals, vecs.unwrap()).sort_eig_desc();
        assert_abs_diff_eq!(vals, exp_vals, epsilon = 1e-5);

        let gram = vecs.t().dot(&vecs);
        assert_abs_diff_eq!(gram, Array2::eye(n), epsilon = 1e-5);

        for (v, &lambda) in vecs.axis_iter(Axis(1)).zip(vals.iter()) {
            let av = a.dot(&v);
            let lambda_v = v.mapv(|x| lambda * x);
            assert_abs_diff_eq!(av, lambda_v, epsilon = 1e-5);
        }
    }

    #[test]
    fn sym_eigvecs1() {
        let a = array![[3., 1., 1.], [1., 3., 1.], [1., 1., 3.]];
        test_eigvecs(a, array![5., 2., 2.]);
    }

    #[test]
    fn sym_eigvecs2() {
        let a = array![[6., 2.], [2., 6.]];
        test_eigvecs(a, array![8., 4.]);
    }

    #[test]
    fn sym_eigvecs3() {
        let a = array![[1., -5., 7.], [-5., 2., -9.], [7., -9., 3.]];
        test_eigvecs(a, array![16.28378, -3.41558, -6.86819]);
    }

    #[test]
    fn corner() {
        // Empty input returns empty outputs, with eigenvectors only if requested.
        assert_eq!(
            symmetric_eig(Array2::<f64>::zeros((0, 0)), false, f64::EPSILON).unwrap(),
            (Array1::zeros(0), None)
        );
        assert_eq!(
            symmetric_eig(Array2::<f64>::zeros((0, 0)), true, f64::EPSILON).unwrap(),
            (Array1::zeros(0), Some(Array2::zeros((0, 0))))
        );

        // Zero matrices must not divide by zero or loop forever.
        for &size in &[1, 4] {
            symmetric_eig(Array2::<f64>::zeros((size, size)), true, f64::EPSILON).unwrap();
        }

        let non_square = symmetric_eig(Array2::<f64>::zeros((3, 1)), true, f64::EPSILON);
        assert!(matches!(
            non_square,
            Err(LinalgError::NotSquare { rows: 3, cols: 1 })
        ));

        // Non-symmetric inputs are not rejected.
        symmetric_eig(array![[5., 4.], [3., 2.]], true, f64::EPSILON).unwrap();
        symmetric_eig(array![[-2., 3.], [-3., -2.]], true, f64::EPSILON).unwrap();
    }
}