1use ndarray::concatenate;
6use ndarray::prelude::*;
7use num_traits::NumCast;
8use std::iter::Sum;
9
10use crate::{cholesky::*, eigh::*, norm::*, triangular::*};
11use crate::{LinalgError, Order, Result};
12
13use super::{Lobpcg, LobpcgResult};
14
15fn generalized_eig<A: NdFloat>(a: Array2<A>, b: Array2<A>) -> Result<(Array1<A>, Array2<A>)> {
17 let (vals_b, vecs_b) = b.eigh_into()?;
18 let vals_b_recip = vals_b.mapv_into(|x| (x.max(A::from(1e-10f32).unwrap())).sqrt().recip());
19 let vecs_b_tilde = vecs_b * vals_b_recip;
20 let a_tilde = vecs_b_tilde.t().dot(&a.dot(&vecs_b_tilde));
21 let (vals_a, vecs_a) = a_tilde.eigh_into()?;
22 let vecs = vecs_b_tilde.dot(&vecs_a);
23
24 Ok((vals_a, vecs))
25}
26
27fn sorted_eig<A: NdFloat>(
29 a: Array2<A>,
30 b: Option<Array2<A>>,
31 size: usize,
32 order: Order,
33) -> Result<(Array1<A>, Array2<A>)> {
34 let res = match b {
35 Some(b) => generalized_eig(a, b)?,
36 _ => a.eigh_into()?,
37 };
38
39 let (vals, vecs) = res.sort_eig(order);
41 let s = vecs.row(0).mapv(|x| x.signum());
42 let vecs = vecs * s;
43 Ok((vals.slice_move(s![..size]), vecs.slice_move(s![.., ..size])))
44}
45
46fn ndarray_mask<A: NdFloat>(matrix: ArrayView2<A>, mask: &[bool]) -> Array2<A> {
48 assert_eq!(mask.len(), matrix.ncols());
49
50 let indices = mask
51 .iter()
52 .enumerate()
53 .filter(|(_, b)| **b)
54 .map(|(a, _)| a)
55 .collect::<Vec<usize>>();
56
57 matrix.select(Axis(1), &indices)
58}
59
60fn apply_constraints<A: NdFloat>(
64 mut v: ArrayViewMut<A, Ix2>,
65 cholesky_yy: &Array2<A>,
66 y: ArrayView2<A>,
67) {
68 let gram_yv = y.t().dot(&v);
69
70 let u = cholesky_yy
71 .solve_triangular_into(gram_yv, UPLO::Lower)
72 .unwrap();
73
74 ndarray::linalg::general_mat_mul(-A::one(), &y, &u, A::one(), &mut v);
76}
77
78fn orthonormalize<T: NdFloat>(v: Array2<T>) -> Result<(Array2<T>, Array2<T>)> {
82 let gram_vv = v.t().dot(&v);
83 let gram_vv_fac = gram_vv.cholesky_into()?;
84
85 let v_t = v.reversed_axes();
92 let u = gram_vv_fac
93 .solve_triangular_into(v_t, UPLO::Lower)?
94 .reversed_axes();
95
96 Ok((u, gram_vv_fac))
97}
98
99pub fn lobpcg<A: NdFloat + Sum, F: Fn(ArrayView2<A>) -> Array2<A>, G: Fn(ArrayViewMut2<A>)>(
120 a: F,
121 mut x: Array2<A>,
122 m: G,
123 y: Option<ArrayView2<A>>,
124 tol: f32,
125 maxiter: usize,
126 order: Order,
127) -> LobpcgResult<A> {
128 let (n, size_x) = (x.nrows(), x.ncols());
131 if size_x > n {
132 return Err((
133 LinalgError::NotThin {
134 rows: size_x,
135 cols: n,
136 },
137 None,
138 ));
139 }
140
141 let mut iter = usize::min(n * 10, maxiter);
152 let tol = NumCast::from(tol).unwrap();
153
154 let cholesky_yy = y.as_ref().map(|y| {
156 let cholesky_yy = y.t().dot(y).cholesky_into().unwrap();
157 apply_constraints(x.view_mut(), &cholesky_yy, y.view());
158 cholesky_yy
159 });
160
161 let (x, _) = orthonormalize(x).map_err(|err| (err, None))?;
163
164 let ax = a(x.view());
166 let xax = x.t().dot(&ax);
167
168 let (mut lambda, eig_block) =
170 sorted_eig(xax, None, size_x, order).map_err(|err| (err, None))?;
171
172 let mut x = x.dot(&eig_block);
174 let mut ax = ax.dot(&eig_block);
175
176 let mut activemask = vec![true; size_x];
178
179 let mut residual_norms_history = Vec::new();
181 let mut best_result = None;
182
183 let mut previous_block_size = size_x;
184
185 let mut ident: Array2<A> = Array2::eye(size_x);
186 let ident0: Array2<A> = Array2::eye(size_x);
187 let two = A::from(2.0).unwrap();
189
190 let mut previous_p_ap: Option<(Array2<A>, Array2<A>)> = None;
191 let mut explicit_gram_flag = true;
192
193 let final_norm = loop {
194 let lambda_diag = Array2::from_diag(&lambda);
196 let lambda_x = x.dot(&lambda_diag);
197
198 let r = &ax - &lambda_x;
200
201 let residual_norms = r
203 .columns()
204 .into_iter()
205 .map(|x| x.norm_l2())
206 .collect::<Vec<A>>();
207 residual_norms_history.push(residual_norms.clone());
208
209 let sum_rnorm = residual_norms.iter().cloned().sum();
211 if best_result
212 .as_ref()
213 .map(|x: &(_, _, Vec<A>)| x.2.iter().cloned().sum::<A>() > sum_rnorm)
214 .unwrap_or(true)
215 {
216 best_result = Some((lambda.clone(), x.clone(), residual_norms.clone()));
217 }
218
219 activemask = residual_norms
221 .iter()
222 .zip(activemask.iter())
223 .map(|(x, a)| *x > tol && *a)
224 .collect();
225
226 let current_block_size = activemask.iter().filter(|x| **x).count();
228 if current_block_size != previous_block_size {
229 previous_block_size = current_block_size;
230 ident = Array2::eye(current_block_size);
231 }
232
233 if current_block_size == 0 || iter == 0 {
236 break Ok(residual_norms);
237 }
238
239 let mut active_block_r = ndarray_mask(r.view(), &activemask);
241 m(active_block_r.view_mut());
243 if let (Some(ref y), Some(ref cholesky_yy)) = (&y, &cholesky_yy) {
245 apply_constraints(active_block_r.view_mut(), cholesky_yy, y.view());
246 }
247 ndarray::linalg::general_mat_mul(
250 -A::one(),
251 &x,
252 &x.t().dot(&active_block_r),
253 A::one(),
254 &mut active_block_r,
255 );
256
257 let (r, _) = match orthonormalize(active_block_r) {
258 Ok(x) => x,
259 Err(err) => break Err(err),
260 };
261
262 let ar = a(r.view());
263
264 let max_rnorm_float = if A::epsilon() > NumCast::from(1e-8).unwrap() {
266 NumCast::from(1.0).unwrap()
267 } else {
268 NumCast::from(1.0e-8).unwrap()
269 };
270
271 let max_norm = residual_norms.into_iter().fold(A::neg_infinity(), A::max);
273 explicit_gram_flag = max_norm <= max_rnorm_float || explicit_gram_flag;
274
275 let xar = x.t().dot(&ar);
277 let mut rar = r.t().dot(&ar);
278
279 let (xax, xx, rr, xr) = if explicit_gram_flag {
283 rar = (&rar + &rar.t()) / two;
284 let xax = x.t().dot(&ax);
285
286 (
287 (&xax + &xax.t()) / two,
288 x.t().dot(&x),
289 r.t().dot(&r),
290 x.t().dot(&r),
291 )
292 } else {
293 (
294 lambda_diag,
295 ident0.clone(),
296 ident.clone(),
297 Array2::zeros((size_x, current_block_size)),
298 )
299 };
300
301 let mut p_ap = previous_p_ap
303 .as_ref()
304 .and_then(|(p, ap)| {
305 let active_p = ndarray_mask(p.view(), &activemask);
306 let active_ap = ndarray_mask(ap.view(), &activemask);
307
308 orthonormalize(active_p).map(|x| (active_ap, x)).ok()
309 })
310 .and_then(|(active_ap, (active_p, p_r))| {
311 let active_ap = active_ap.reversed_axes();
313 p_r.solve_triangular_into(active_ap, UPLO::Lower)
314 .map(|active_ap| (active_p, active_ap.reversed_axes()))
315 .ok()
316 });
317
318 let result = p_ap
323 .as_ref()
324 .ok_or(LinalgError::NonInvertible)
325 .and_then(|(active_p, active_ap)| {
326 let xap = x.t().dot(active_ap);
327 let rap = r.t().dot(active_ap);
328 let pap = active_p.t().dot(active_ap);
329 let xp = x.t().dot(active_p);
330 let rp = r.t().dot(active_p);
331 let (pap, pp) = if explicit_gram_flag {
332 ((&pap + &pap.t()) / two, active_p.t().dot(active_p))
333 } else {
334 (pap, ident.clone())
335 };
336
337 sorted_eig(
338 concatenate![
339 Axis(0),
340 concatenate![Axis(1), xax, xar, xap],
341 concatenate![Axis(1), xar.t(), rar, rap],
342 concatenate![Axis(1), xap.t(), rap.t(), pap]
343 ],
344 Some(concatenate![
345 Axis(0),
346 concatenate![Axis(1), xx, xr, xp],
347 concatenate![Axis(1), xr.t(), rr, rp],
348 concatenate![Axis(1), xp.t(), rp.t(), pp]
349 ]),
350 size_x,
351 order,
352 )
353 })
354 .or_else(|_| {
355 p_ap = None;
356
357 sorted_eig(
358 concatenate![
359 Axis(0),
360 concatenate![Axis(1), xax, xar],
361 concatenate![Axis(1), xar.t(), rar]
362 ],
363 Some(concatenate![
364 Axis(0),
365 concatenate![Axis(1), xx, xr],
366 concatenate![Axis(1), xr.t(), rr]
367 ]),
368 size_x,
369 order,
370 )
371 });
372
373 let eig_vecs;
375 match result {
376 Ok((x, y)) => {
377 lambda = x;
378 eig_vecs = y;
379 }
380 Err(x) => break Err(x),
381 }
382
383 let (p, ap, tau) = if let Some((active_p, active_ap)) = p_ap {
385 let tau = eig_vecs.slice(s![..size_x, ..]);
387 let alpha = eig_vecs.slice(s![size_x..size_x + current_block_size, ..]);
389 let gamma = eig_vecs.slice(s![size_x + current_block_size.., ..]);
391
392 let updated_p = r.dot(&alpha) + active_p.dot(&gamma);
394 let updated_ap = ar.dot(&alpha) + active_ap.dot(&gamma);
395
396 (updated_p, updated_ap, tau)
397 } else {
398 let tau = eig_vecs.slice(s![..size_x, ..]);
400 let alpha = eig_vecs.slice(s![size_x.., ..]);
402
403 let updated_p = r.dot(&alpha);
405 let updated_ap = ar.dot(&alpha);
406
407 (updated_p, updated_ap, tau)
408 };
409
410 x = x.dot(&tau) + &p;
412 ax = ax.dot(&tau) + ≈
413
414 previous_p_ap = Some((p, ap));
415
416 iter -= 1;
417 };
418
419 let (vals, vecs, rnorm) = best_result.unwrap();
421 let res = Lobpcg {
422 eigvals: vals,
423 eigvecs: vecs,
424 rnorm,
425 };
426
427 match final_norm {
428 Ok(_) => Ok(res),
429 Err(err) => Err((err, Some(res))),
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::ndarray_mask;
    use super::orthonormalize;
    use super::sorted_eig;
    use super::Order;
    use super::{lobpcg, Lobpcg};
    use crate::qr::*;
    use approx::assert_abs_diff_eq;
    use ndarray::prelude::*;
    use rand::distributions::{Distribution, Standard};
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256Plus;

    /// Random matrix of shape `sh` from a fixed-seed generator, so tests are reproducible.
    fn random<A>(sh: (usize, usize)) -> Array2<A>
    where
        A: NdFloat,
        Standard: Distribution<A>,
    {
        let rng = Xoshiro256Plus::seed_from_u64(3);
        crate::lobpcg::random(sh, rng)
    }

    /// The sorted eigendecomposition of a symmetric matrix reconstructs the matrix.
    #[test]
    fn test_sorted_eigen() {
        let matrix: Array2<f64> = random((10, 10)) * 10.0;
        let matrix = matrix.t().dot(&matrix);

        let (vals, vecs) = sorted_eig(matrix.clone(), None, 10, Order::Largest).unwrap();

        let diag = Array2::from_diag(&vals);
        let rec = (vecs.dot(&diag)).dot(&vecs.t());

        assert_abs_diff_eq!(&matrix, &rec, epsilon = 1e-5);
    }

    /// Masking keeps only the selected columns, in order.
    #[test]
    fn test_masking() {
        let matrix: Array2<f64> = random((10, 5)) * 10.0;
        let masked_matrix = ndarray_mask(matrix.view(), &[true, true, false, true, false]);
        assert_abs_diff_eq!(
            &masked_matrix.slice(s![.., 2]),
            &matrix.slice(s![.., 3]),
            epsilon = 1e-12,
        );
    }

    /// Orthonormalization yields orthonormal columns and a factor matching QR's R (up to sign).
    #[test]
    fn test_orthonormalize() {
        let matrix: Array2<f64> = random((10, 10)) * 10.0;

        let (n, l) = orthonormalize(matrix.clone()).unwrap();

        let identity = n.dot(&n.t());
        assert_abs_diff_eq!(&identity, &Array2::eye(10), epsilon = 1e-2);

        let qr = matrix.qr().unwrap();
        assert_abs_diff_eq!(
            &qr.into_r().mapv(|x| x.abs()),
            &l.t().mapv(|x| x.abs()),
            epsilon = 1e-2
        );
    }

    /// Generalized eigenvalues: A v = λ A v gives λ = 1, and (A, I) matches (I, A⁻¹).
    #[test]
    fn test_generalized_eigenvalue() {
        let matrix: Array2<f64> = random((10, 10)) * 1.;
        let matrix = matrix.t().dot(&matrix);
        let identity = Array2::eye(10);
        let matrix_inv = matrix.qr().unwrap().inverse().unwrap();

        let (vals, _) =
            sorted_eig(matrix.clone(), Some(matrix.clone()), 10, Order::Largest).unwrap();

        assert_abs_diff_eq!(vals, Array1::from_elem(10, 1.0), epsilon = 1e-4);

        let (vals1, _) = sorted_eig(matrix, Some(identity.clone()), 10, Order::Largest).unwrap();
        let (vals2, _) = sorted_eig(identity, Some(matrix_inv), 10, Order::Largest).unwrap();

        assert_abs_diff_eq!(vals1, vals2, epsilon = 1e-5);
    }

    fn assert_symmetric(a: &Array2<f64>) {
        assert_abs_diff_eq!(a.view(), &a.t(), epsilon = 1e-5);
    }

    /// Runs LOBPCG for `num` eigenpairs of `a`, checks that every residual norm is <= 1e-5 and,
    /// when `ground_truth_eigvals` has `num` entries, compares the eigenvalues against it.
    fn check_eigenvalues(a: &Array2<f64>, order: Order, num: usize, ground_truth_eigvals: &[f64]) {
        assert_symmetric(a);

        let n = a.len_of(Axis(0));
        let x: Array2<f64> = random((n, num));

        let result = lobpcg(|y| a.dot(&y), x, |_| {}, None, 1e-6, n * 3, order);
        match result {
            Ok(Lobpcg { eigvals, rnorm, .. }) | Err((_, Some(Lobpcg { eigvals, rnorm, .. }))) => {
                for (i, norm) in rnorm.into_iter().enumerate() {
                    if norm > 1e-5 {
                        println!("==== Assertion Failed ====");
                        println!("The {}th eigenvalue estimation did not converge!", i);
                        panic!("Too large deviation of residual norm: {} > 1e-5", norm);
                    }
                }

                if ground_truth_eigvals.len() == num {
                    assert_abs_diff_eq!(
                        &Array1::from(ground_truth_eigvals.to_vec()),
                        &eigvals,
                        epsilon = num as f64 * 5e-5,
                    )
                }
            }
            Err((err, None)) => panic!("Did not converge: {:?}", err),
        }
    }

    /// Extreme eigenvalues of a diagonal matrix.
    #[test]
    fn test_eigsolver_diag() {
        let diag = arr1(&[
            1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,
            20.,
        ]);
        let a = Array2::from_diag(&diag);

        check_eigenvalues(&a, Order::Largest, 3, &[20., 19., 18.]);
        check_eigenvalues(&a, Order::Smallest, 3, &[1., 2., 3.]);
    }

    /// Extreme eigenvalues of V T Vᵀ with a random orthonormal V and known diagonal T.
    #[test]
    fn test_eigsolver_constructed() {
        let n = 50;
        let tmp = random((n, n));
        let (v, _) = orthonormalize(tmp).unwrap();

        let t = Array2::from_diag(&Array1::linspace(n as f64, -(n as f64) + 2., n));
        let a = v.dot(&t.dot(&v.t()));

        check_eigenvalues(&a, Order::Largest, 5, &[50.0, 48.0, 46.0, 44.0, 42.0]);
        check_eigenvalues(&a, Order::Smallest, 5, &[-48.0, -46.0, -44.0, -42.0, -40.0]);
    }

    /// With the first two unit vectors as constraints, the smallest remaining eigenpair of
    /// diag(1..=10) is (3, e₃).
    #[test]
    fn test_eigsolver_constrained() {
        let diag = arr1(&[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]);
        let a = Array2::from_diag(&diag);
        let x: Array2<f64> = random((10, 1));
        let y: Array2<f64> = arr2(&[
            [1.0, 0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 1.0, 0., 0., 0., 0., 0., 0., 0., 0.],
        ])
        .reversed_axes();

        let result = lobpcg(
            |y| a.dot(&y),
            x,
            |_| {},
            Some(y.view()),
            1e-10,
            50,
            Order::Smallest,
        );
        match result {
            Ok(Lobpcg {
                eigvals,
                eigvecs,
                rnorm,
            })
            | Err((
                _,
                Some(Lobpcg {
                    eigvals,
                    eigvecs,
                    rnorm,
                }),
            )) => {
                for (i, norm) in rnorm.into_iter().enumerate() {
                    if norm > 0.01 {
                        println!("==== Assertion Failed ====");
                        println!("The {}th eigenvalue estimation did not converge!", i);
                        panic!("Too large deviation of residual norm: {} > 0.01", norm);
                    }
                }

                assert_abs_diff_eq!(&eigvals, &Array1::from(vec![3.0]), epsilon = 1e-6);
                assert_abs_diff_eq!(
                    &eigvecs.column(0).mapv(|x| x.abs()),
                    &arr1(&[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
                    epsilon = 1e-5,
                );
            }
            Err((err, None)) => panic!("Did not converge: {:?}", err),
        }
    }
}