1#![allow(clippy::type_complexity)]
4
5use std::ops::MulAssign;
6
7use ndarray::{s, Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix2, NdFloat};
8
9use crate::{
10 bidiagonal::Bidiagonal,
11 eigh::{cmp_floats, wilkinson_shift},
12 givens::GivensRotation,
13 index::*,
14 LinalgError, Order, Result,
15};
16
17fn svd<A: NdFloat, S: DataMut<Elem = A>>(
18 mut matrix: ArrayBase<S, Ix2>,
19 compute_u: bool,
20 compute_v: bool,
21 eps: A,
22) -> Result<(Option<Array2<A>>, Array1<A>, Option<Array2<A>>)> {
23 if matrix.is_empty() {
24 return Err(LinalgError::EmptyMatrix);
25 }
26 let (nrows, ncols) = matrix.dim();
27 let dim = nrows.min(ncols);
28
29 let amax = matrix
30 .iter()
31 .map(|f| f.abs())
32 .fold(A::neg_infinity(), |a, b| a.max(b));
33
34 if amax != A::zero() {
35 matrix /= amax;
36 }
37
38 let bidiag = matrix.bidiagonal()?;
39 let is_upper_diag = bidiag.is_upper_diag();
40 let mut u = compute_u.then(|| bidiag.generate_u());
41 let mut vt = compute_v.then(|| bidiag.generate_vt());
42 let (mut diag, mut off_diag) = bidiag.into_diagonals();
43
44 let (mut start, mut end) = delimit_subproblem(
45 &mut diag,
46 &mut off_diag,
47 &mut u,
48 &mut vt,
49 is_upper_diag,
50 dim - 1,
51 eps,
52 );
53
54 #[allow(clippy::comparison_chain)]
55 while end != start {
56 let subdim = end - start + 1;
57
58 if subdim > 2 {
59 let m = end - 1;
60 let n = end;
61
62 let mut vec = unsafe {
63 let dm = *diag.at(m);
64 let dn = *diag.at(n);
65 let fm = *off_diag.at(m);
66 let fm1 = *off_diag.at(m - 1);
67
68 let tmm = dm * dm + fm1 * fm1;
69 let tmn = dm * fm;
70 let tnn = dn * dn + fm * fm;
71 let shift = wilkinson_shift(tmm, tnn, tmn);
72
73 let ds = *diag.at(start);
74 (ds * ds - shift, ds * *off_diag.at(start))
75 };
76
77 for k in start..n {
78 let mut subm = unsafe {
79 let m12 = if k == n - 1 {
80 A::zero()
81 } else {
82 *off_diag.at(k + 1)
83 };
84 Array2::from_shape_vec(
85 (2, 3),
86 vec![
87 *diag.at(k),
88 *off_diag.at(k),
89 A::zero(),
90 A::zero(),
91 *diag.at(k + 1),
92 m12,
93 ],
94 )
95 .unwrap()
96 };
97
98 if let Some((rot1, norm1)) = GivensRotation::cancel_y(vec.0, vec.1) {
99 rot1.inverse()
100 .rotate_rows(&mut subm.slice_mut(s![.., 0..=1]))
101 .unwrap();
102
103 let (rot2, norm2);
104 unsafe {
105 if k > start {
106 *off_diag.atm(k - 1) = norm1;
107 }
108
109 let (v1, v2) = (*subm.at((0, 0)), *subm.at((1, 0)));
110 if let Some((rot, norm)) = GivensRotation::cancel_y(v1, v2) {
111 rot.rotate_cols(&mut subm.slice_mut(s![.., 1..=2])).unwrap();
112 rot2 = Some(rot);
113 norm2 = norm;
114 } else {
115 rot2 = None;
116 norm2 = v1;
117 };
118 *subm.atm((0, 0)) = norm2;
119 }
120
121 if let Some(ref mut vt) = vt {
122 if is_upper_diag {
123 rot1.rotate_cols(&mut vt.slice_mut(s![k..k + 2, ..]))
124 .unwrap();
125 } else if let Some(rot2) = &rot2 {
126 rot2.rotate_cols(&mut vt.slice_mut(s![k..k + 2, ..]))
127 .unwrap();
128 }
129 }
130
131 if let Some(ref mut u) = u {
132 if !is_upper_diag {
133 rot1.inverse()
134 .rotate_rows(&mut u.slice_mut(s![.., k..k + 2]))
135 .unwrap();
136 } else if let Some(rot2) = &rot2 {
137 rot2.inverse()
138 .rotate_rows(&mut u.slice_mut(s![.., k..k + 2]))
139 .unwrap();
140 }
141 }
142
143 unsafe {
144 *diag.atm(k) = *subm.at((0, 0));
145 *diag.atm(k + 1) = *subm.at((1, 1));
146 *off_diag.atm(k) = *subm.at((0, 1));
147 if k != n - 1 {
148 *off_diag.atm(k + 1) = *subm.at((1, 2));
149 }
150 vec.0 = *subm.at((0, 1));
151 vec.1 = *subm.at((0, 2));
152 }
153 } else {
154 break;
155 }
156 }
157 } else if subdim == 2 {
158 let (rot_u, rot_v) = unsafe {
160 let (s1, s2, u2, v2) = compute_2x2_uptrig_svd(
161 *diag.at(start),
162 *off_diag.at(start),
163 *diag.at(start + 1),
164 compute_u && is_upper_diag || compute_v && !is_upper_diag,
165 compute_v && is_upper_diag || compute_u && !is_upper_diag,
166 );
167 *diag.atm(start) = s1;
168 *diag.atm(start + 1) = s2;
169 *off_diag.atm(start) = A::zero();
170
171 if is_upper_diag {
172 (u2, v2)
173 } else {
174 (v2, u2)
175 }
176 };
177
178 if let Some(ref mut u) = u {
179 rot_u
180 .unwrap()
181 .rotate_rows(&mut u.slice_mut(s![.., start..start + 2]))
182 .unwrap();
183 }
184
185 if let Some(ref mut vt) = vt {
186 rot_v
187 .unwrap()
188 .inverse()
189 .rotate_cols(&mut vt.slice_mut(s![start..start + 2, ..]))
190 .unwrap();
191 }
192
193 end -= 1;
194 }
195
196 let sub = delimit_subproblem(
198 &mut diag,
199 &mut off_diag,
200 &mut u,
201 &mut vt,
202 is_upper_diag,
203 end,
204 eps,
205 );
206 start = sub.0;
207 end = sub.1;
208 }
209
210 diag *= amax;
211
212 for i in 0..dim {
214 let val = diag[i];
215 if val.is_sign_negative() {
216 diag[i] = -val;
217 if let Some(u) = &mut u {
218 u.column_mut(i).mul_assign(-A::zero());
219 }
220 }
221 }
222
223 Ok((u, diag, vt))
224}
225
226fn delimit_subproblem<A: NdFloat>(
227 diag: &mut Array1<A>,
228 off_diag: &mut Array1<A>,
229 u: &mut Option<Array2<A>>,
230 v_t: &mut Option<Array2<A>>,
231 is_upper_diag: bool,
232 end: usize,
233 eps: A,
234) -> (usize, usize) {
235 let mut n = end;
236 while n > 0 {
237 let m = n - 1;
238 unsafe {
239 if off_diag.at(m).is_zero()
240 || off_diag.at(m).abs() <= eps * (diag.at(n).abs() + diag.at(m).abs())
241 {
242 *off_diag.atm(m) = A::zero();
243 } else if diag.at(m).abs() <= eps {
244 *diag.atm(m) = A::zero();
245 cancel_horizontal_off_diagonal_elt(diag, off_diag, u, v_t, is_upper_diag, m, m + 1);
246 if m != 0 {
247 cancel_vertical_off_diagonal_elt(diag, off_diag, u, v_t, is_upper_diag, m - 1);
248 }
249 } else if diag.at(n).abs() <= eps {
250 *diag.atm(n) = A::zero();
251 cancel_vertical_off_diagonal_elt(diag, off_diag, u, v_t, is_upper_diag, m);
252 } else {
253 break;
254 }
255 }
256
257 n -= 1;
258 }
259
260 if n == 0 {
261 return (0, 0);
262 }
263
264 let mut new_start = n - 1;
265 while new_start > 0 {
266 let m = new_start - 1;
267
268 unsafe {
269 if off_diag.at(m).abs() <= eps * (diag.at(new_start).abs() + diag.at(m).abs()) {
270 *off_diag.atm(m) = A::zero();
271 break;
272 }
273 }
274
275 if unsafe { diag.at(m).abs() } <= eps {
276 unsafe { *diag.atm(m) = A::zero() };
277 cancel_horizontal_off_diagonal_elt(diag, off_diag, u, v_t, is_upper_diag, m, n);
278 if m != 0 {
279 cancel_vertical_off_diagonal_elt(diag, off_diag, u, v_t, is_upper_diag, m - 1);
280 }
281 break;
282 }
283 new_start -= 1;
284 }
285
286 (new_start, n)
287}
288
289fn cancel_horizontal_off_diagonal_elt<A: NdFloat>(
290 diag: &mut Array1<A>,
291 off_diag: &mut Array1<A>,
292 u: &mut Option<Array2<A>>,
293 v_t: &mut Option<Array2<A>>,
294 is_upper_diag: bool,
295 i: usize,
296 end: usize,
297) {
298 let mut v = (off_diag[i], diag[i + 1]);
299 off_diag[i] = A::zero();
300
301 for k in i..end {
302 if let Some((rot, norm)) = GivensRotation::cancel_x(v.0, v.1) {
303 unsafe { *diag.atm(k + 1) = norm };
304
305 if is_upper_diag {
306 if let Some(u) = u {
307 rot.inverse()
308 .rotate_rows(&mut u.slice_mut(s![.., i..=k+1;k-i+1]))
309 .unwrap()
310 }
311 } else if let Some(v_t) = v_t {
312 rot.rotate_cols(&mut v_t.slice_mut(s![i..=k+1;k-i+1, ..]))
313 .unwrap();
314 }
315
316 if k + 1 != end {
317 unsafe {
318 v.0 = -rot.s() * *off_diag.at(k + 1);
319 v.1 = *diag.at(k + 2);
320 *off_diag.atm(k + 1) *= rot.c();
321 }
322 }
323 } else {
324 break;
325 }
326 }
327}
328
329fn cancel_vertical_off_diagonal_elt<A: NdFloat>(
330 diag: &mut Array1<A>,
331 off_diag: &mut Array1<A>,
332 u: &mut Option<Array2<A>>,
333 v_t: &mut Option<Array2<A>>,
334 is_upper_diag: bool,
335 i: usize,
336) {
337 let mut v = (diag[i], off_diag[i]);
338 off_diag[i] = A::zero();
339
340 for k in (0..i + 1).rev() {
341 if let Some((rot, norm)) = GivensRotation::cancel_y(v.0, v.1) {
342 unsafe { *diag.atm(k) = norm };
343
344 if is_upper_diag {
345 if let Some(v_t) = v_t {
346 rot.rotate_cols(&mut v_t.slice_mut(s![k..=i+1;i-k+1, ..]))
347 .unwrap();
348 }
349 } else if let Some(u) = u {
350 rot.inverse()
351 .rotate_rows(&mut u.slice_mut(s![.., k..=i+1;i-k+1]))
352 .unwrap()
353 }
354
355 if k > 0 {
356 unsafe {
357 v.0 = *diag.at(k - 1);
358 v.1 = rot.s() * *off_diag.at(k - 1);
359 *off_diag.atm(k - 1) *= rot.c();
360 }
361 }
362 } else {
363 break;
364 }
365 }
366}
367
368fn compute_2x2_uptrig_svd<A: NdFloat>(
372 m11: A,
373 m12: A,
374 m22: A,
375 compute_u: bool,
376 compute_v: bool,
377) -> (A, A, Option<GivensRotation<A>>, Option<GivensRotation<A>>) {
378 let two = A::from(2.0).unwrap();
379 let denom = (m11 + m22).hypot(m12) + (m11 - m22).hypot(m12);
380
381 let mut v1 = m11 * m22 * two / denom;
386 let mut v2 = denom / two;
387
388 let mut u = None;
389 let mut v_t = None;
390
391 if compute_v || compute_u {
392 let cv = m11 * m12;
393 let sv = v1 * v1 - m11 * m11;
394 let (csv, sgn_v) = GivensRotation::new(cv, sv);
395 v1 *= sgn_v;
396 v2 *= sgn_v;
397 if compute_v {
398 v_t = Some(csv.clone());
399 }
400
401 let cu = (m11 * csv.c() + m12 * csv.s()) / v1;
402 let su = (m22 * csv.s()) / v1;
403 let (csu, sgn_u) = GivensRotation::new(cu, su);
404 v1 *= sgn_u;
405 v2 *= sgn_u;
406 if compute_u {
407 u = Some(csu);
408 }
409 }
410
411 (v1, v2, u, v_t)
412}
413
414pub trait SVDInto {
416 type U;
417 type Vt;
418 type Sigma;
419
420 fn svd_into(
424 self,
425 calc_u: bool,
426 calc_vt: bool,
427 ) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::Vt>)>;
428}
429
430impl<A: NdFloat, S: DataMut<Elem = A>> SVDInto for ArrayBase<S, Ix2> {
431 type U = Array2<A>;
432 type Vt = Array2<A>;
433 type Sigma = Array1<A>;
434
435 fn svd_into(
436 self,
437 calc_u: bool,
438 calc_vt: bool,
439 ) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::Vt>)> {
440 svd(self, calc_u, calc_vt, A::epsilon() * A::from(5.).unwrap())
442 }
443}
444
445pub trait SVD {
447 type U;
448 type Vt;
449 type Sigma;
450
451 fn svd(
455 &self,
456 calc_u: bool,
457 calc_vt: bool,
458 ) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::Vt>)>;
459}
460
461impl<A: NdFloat, S: Data<Elem = A>> SVD for ArrayBase<S, Ix2> {
462 type U = Array2<A>;
463 type Vt = Array2<A>;
464 type Sigma = Array1<A>;
465
466 fn svd(
467 &self,
468 calc_u: bool,
469 calc_vt: bool,
470 ) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::Vt>)> {
471 svd(
473 self.to_owned(),
474 calc_u,
475 calc_vt,
476 A::epsilon() * A::from(5.).unwrap(),
477 )
478 }
479}
480
481pub trait SvdSort: Sized {
488 fn sort_svd(self, order: Order) -> Self;
489
490 fn sort_svd_asc(self) -> Self {
492 self.sort_svd(Order::Smallest)
493 }
494
495 fn sort_svd_desc(self) -> Self {
497 self.sort_svd(Order::Largest)
498 }
499}
500
501impl<A: NdFloat> SvdSort for (Option<Array2<A>>, Array1<A>, Option<Array2<A>>) {
503 fn sort_svd(self, order: Order) -> Self {
504 let (u, mut s, vt) = self;
505 let mut value_idx: Vec<_> = s.iter().copied().enumerate().collect();
506 match order {
508 Order::Largest => value_idx.sort_by(|a, b| cmp_floats(&b.1, &a.1)),
509 Order::Smallest => value_idx.sort_by(|a, b| cmp_floats(&a.1, &b.1)),
510 }
511
512 let apply_ordering = |arr: &Array2<A>, ax, values_idx: &Vec<_>| {
513 let mut out = Array2::zeros(arr.dim()); for (out_idx, &(arr_idx, _)) in values_idx.iter().enumerate() {
515 out.index_axis_mut(ax, out_idx)
516 .assign(&arr.index_axis(ax, arr_idx));
517 }
518 out
519 };
520
521 let u = u.map(|u| apply_ordering(&u, Axis(1), &value_idx));
522 let vt = vt.map(|vt| apply_ordering(&vt, Axis(0), &value_idx));
523 s.iter_mut()
524 .zip(value_idx.iter())
525 .for_each(|(si, (_, f))| *si = *f);
526 (u, s, vt)
527 }
528}
529
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    use super::*;

    /// Checks values and factors for a diagonal matrix with a negative entry, in both sort
    /// orders, and a values-only decomposition of a wide matrix.
    #[test]
    fn svd_test() {
        let decomp = svd(array![[3.0, 0.], [0., -2.]], true, true, 1e-15).unwrap();

        let (u, s, vt) = decomp.clone().sort_svd_desc();
        assert_abs_diff_eq!(s, array![3., 2.], epsilon = 1e-7);
        assert_abs_diff_eq!(u.unwrap(), array![[1., 0.], [0., -1.]], epsilon = 1e-7);
        assert_abs_diff_eq!(vt.unwrap(), array![[1., 0.], [0., 1.]], epsilon = 1e-7);

        let (u, s, vt) = decomp.sort_svd_asc();
        assert_abs_diff_eq!(s, array![2., 3.], epsilon = 1e-7);
        assert_abs_diff_eq!(u.unwrap(), array![[0., 1.], [-1., 0.]], epsilon = 1e-7);
        assert_abs_diff_eq!(vt.unwrap(), array![[0., 1.], [1., 0.]], epsilon = 1e-7);

        // Values only: neither factor may be returned.
        let (u, s, vt) = svd(array![[1., 0., -1.], [-2., 1., 4.]], false, false, 1e-15).unwrap();
        assert_abs_diff_eq!(s, array![0.51371, 4.76824], epsilon = 1e-5);
        assert!(u.is_none());
        assert!(vt.is_none());
    }

    /// Verifies the expected singular values, their sign, the reconstruction of `a`, and that
    /// every combination of the `compute_u` / `compute_v` flags yields consistent results.
    fn test_svd_props(a: Array2<f64>, exp_s: Array1<f64>) {
        // Full decomposition is the reference for the partial ones below.
        let (u, s, vt) = svd(a.clone(), true, true, 1e-15).unwrap().sort_svd_desc();
        let (u, vt) = (u.unwrap(), vt.unwrap());
        assert_abs_diff_eq!(s, exp_s, epsilon = 1e-5);
        assert!(s.iter().copied().all(f64::is_sign_positive));
        assert_abs_diff_eq!(u.dot(&Array2::from_diag(&s)).dot(&vt), a, epsilon = 1e-5);

        // Only Vt.
        let (u_only_vt, s_only_vt, vt_only_vt) =
            svd(a.clone(), false, true, 1e-15).unwrap().sort_svd_desc();
        assert!(u_only_vt.is_none());
        assert_abs_diff_eq!(s_only_vt, s, epsilon = 1e-9);
        assert_abs_diff_eq!(vt_only_vt.unwrap(), vt, epsilon = 1e-9);

        // Only U.
        let (u_only_u, s_only_u, vt_only_u) =
            svd(a.clone(), true, false, 1e-15).unwrap().sort_svd_desc();
        assert!(vt_only_u.is_none());
        assert_abs_diff_eq!(s_only_u, s, epsilon = 1e-9);
        assert_abs_diff_eq!(u_only_u.unwrap(), u, epsilon = 1e-9);

        // Neither factor.
        let (u_none, s_none, vt_none) = svd(a, false, false, 1e-15).unwrap().sort_svd_desc();
        assert!(vt_none.is_none());
        assert!(u_none.is_none());
        assert_abs_diff_eq!(s_none, s, epsilon = 1e-9);
    }

    /// Runs the property checks on a row vector, a rank-deficient square matrix, a tall
    /// matrix, and a transposed 3x3 matrix with a zero singular value.
    #[test]
    fn svd_props() {
        test_svd_props(array![[-2., 1., 4.]], array![21f64.sqrt()]);
        test_svd_props(array![[1., 1.], [1., 1.]], array![2., 0.]);
        test_svd_props(
            array![[-3., 4.], [4.3, 2.1], [6.6, 8.7]],
            array![11.80876, 5.2633658],
        );
        test_svd_props(
            array![
                [10.74785316637712, -5.994983325167452, -6.064492921857296],
                [-4.149751381521569, 20.654504205822462, -4.470436210703133],
                [-22.772715014220207, -1.4554372570788008, 18.108113992170573]
            ]
            .reversed_axes(),
            array![3.16188022e+01, 2.23811978e+01, 0.],
        );
    }

    /// Covers the empty-matrix error and the 1x1 zero matrix.
    #[test]
    fn svd_corner() {
        let empty_err = svd(Array2::zeros((0, 1)), false, false, 1e-15).unwrap_err();
        assert!(matches!(empty_err, LinalgError::EmptyMatrix));

        let (u, s, vt) = svd(array![[0f64]], true, true, 1e-15).unwrap();
        assert_abs_diff_eq!(s, array![0.]);
        assert_abs_diff_eq!(u.unwrap(), array![[1.]]);
        assert_abs_diff_eq!(vt.unwrap(), array![[1.]]);
    }
}