linfa_linalg/
svd.rs

1//! Compact singular-value decomposition of matrices
2
3#![allow(clippy::type_complexity)]
4
5use std::ops::MulAssign;
6
7use ndarray::{s, Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix2, NdFloat};
8
9use crate::{
10    bidiagonal::Bidiagonal,
11    eigh::{cmp_floats, wilkinson_shift},
12    givens::GivensRotation,
13    index::*,
14    LinalgError, Order, Result,
15};
16
17fn svd<A: NdFloat, S: DataMut<Elem = A>>(
18    mut matrix: ArrayBase<S, Ix2>,
19    compute_u: bool,
20    compute_v: bool,
21    eps: A,
22) -> Result<(Option<Array2<A>>, Array1<A>, Option<Array2<A>>)> {
23    if matrix.is_empty() {
24        return Err(LinalgError::EmptyMatrix);
25    }
26    let (nrows, ncols) = matrix.dim();
27    let dim = nrows.min(ncols);
28
29    let amax = matrix
30        .iter()
31        .map(|f| f.abs())
32        .fold(A::neg_infinity(), |a, b| a.max(b));
33
34    if amax != A::zero() {
35        matrix /= amax;
36    }
37
38    let bidiag = matrix.bidiagonal()?;
39    let is_upper_diag = bidiag.is_upper_diag();
40    let mut u = compute_u.then(|| bidiag.generate_u());
41    let mut vt = compute_v.then(|| bidiag.generate_vt());
42    let (mut diag, mut off_diag) = bidiag.into_diagonals();
43
44    let (mut start, mut end) = delimit_subproblem(
45        &mut diag,
46        &mut off_diag,
47        &mut u,
48        &mut vt,
49        is_upper_diag,
50        dim - 1,
51        eps,
52    );
53
54    #[allow(clippy::comparison_chain)]
55    while end != start {
56        let subdim = end - start + 1;
57
58        if subdim > 2 {
59            let m = end - 1;
60            let n = end;
61
62            let mut vec = unsafe {
63                let dm = *diag.at(m);
64                let dn = *diag.at(n);
65                let fm = *off_diag.at(m);
66                let fm1 = *off_diag.at(m - 1);
67
68                let tmm = dm * dm + fm1 * fm1;
69                let tmn = dm * fm;
70                let tnn = dn * dn + fm * fm;
71                let shift = wilkinson_shift(tmm, tnn, tmn);
72
73                let ds = *diag.at(start);
74                (ds * ds - shift, ds * *off_diag.at(start))
75            };
76
77            for k in start..n {
78                let mut subm = unsafe {
79                    let m12 = if k == n - 1 {
80                        A::zero()
81                    } else {
82                        *off_diag.at(k + 1)
83                    };
84                    Array2::from_shape_vec(
85                        (2, 3),
86                        vec![
87                            *diag.at(k),
88                            *off_diag.at(k),
89                            A::zero(),
90                            A::zero(),
91                            *diag.at(k + 1),
92                            m12,
93                        ],
94                    )
95                    .unwrap()
96                };
97
98                if let Some((rot1, norm1)) = GivensRotation::cancel_y(vec.0, vec.1) {
99                    rot1.inverse()
100                        .rotate_rows(&mut subm.slice_mut(s![.., 0..=1]))
101                        .unwrap();
102
103                    let (rot2, norm2);
104                    unsafe {
105                        if k > start {
106                            *off_diag.atm(k - 1) = norm1;
107                        }
108
109                        let (v1, v2) = (*subm.at((0, 0)), *subm.at((1, 0)));
110                        if let Some((rot, norm)) = GivensRotation::cancel_y(v1, v2) {
111                            rot.rotate_cols(&mut subm.slice_mut(s![.., 1..=2])).unwrap();
112                            rot2 = Some(rot);
113                            norm2 = norm;
114                        } else {
115                            rot2 = None;
116                            norm2 = v1;
117                        };
118                        *subm.atm((0, 0)) = norm2;
119                    }
120
121                    if let Some(ref mut vt) = vt {
122                        if is_upper_diag {
123                            rot1.rotate_cols(&mut vt.slice_mut(s![k..k + 2, ..]))
124                                .unwrap();
125                        } else if let Some(rot2) = &rot2 {
126                            rot2.rotate_cols(&mut vt.slice_mut(s![k..k + 2, ..]))
127                                .unwrap();
128                        }
129                    }
130
131                    if let Some(ref mut u) = u {
132                        if !is_upper_diag {
133                            rot1.inverse()
134                                .rotate_rows(&mut u.slice_mut(s![.., k..k + 2]))
135                                .unwrap();
136                        } else if let Some(rot2) = &rot2 {
137                            rot2.inverse()
138                                .rotate_rows(&mut u.slice_mut(s![.., k..k + 2]))
139                                .unwrap();
140                        }
141                    }
142
143                    unsafe {
144                        *diag.atm(k) = *subm.at((0, 0));
145                        *diag.atm(k + 1) = *subm.at((1, 1));
146                        *off_diag.atm(k) = *subm.at((0, 1));
147                        if k != n - 1 {
148                            *off_diag.atm(k + 1) = *subm.at((1, 2));
149                        }
150                        vec.0 = *subm.at((0, 1));
151                        vec.1 = *subm.at((0, 2));
152                    }
153                } else {
154                    break;
155                }
156            }
157        } else if subdim == 2 {
158            // Solve 2x2 subproblem
159            let (rot_u, rot_v) = unsafe {
160                let (s1, s2, u2, v2) = compute_2x2_uptrig_svd(
161                    *diag.at(start),
162                    *off_diag.at(start),
163                    *diag.at(start + 1),
164                    compute_u && is_upper_diag || compute_v && !is_upper_diag,
165                    compute_v && is_upper_diag || compute_u && !is_upper_diag,
166                );
167                *diag.atm(start) = s1;
168                *diag.atm(start + 1) = s2;
169                *off_diag.atm(start) = A::zero();
170
171                if is_upper_diag {
172                    (u2, v2)
173                } else {
174                    (v2, u2)
175                }
176            };
177
178            if let Some(ref mut u) = u {
179                rot_u
180                    .unwrap()
181                    .rotate_rows(&mut u.slice_mut(s![.., start..start + 2]))
182                    .unwrap();
183            }
184
185            if let Some(ref mut vt) = vt {
186                rot_v
187                    .unwrap()
188                    .inverse()
189                    .rotate_cols(&mut vt.slice_mut(s![start..start + 2, ..]))
190                    .unwrap();
191            }
192
193            end -= 1;
194        }
195
196        // Re-delimit the subproblem in case some decoupling occurred.
197        let sub = delimit_subproblem(
198            &mut diag,
199            &mut off_diag,
200            &mut u,
201            &mut vt,
202            is_upper_diag,
203            end,
204            eps,
205        );
206        start = sub.0;
207        end = sub.1;
208    }
209
210    diag *= amax;
211
212    // Ensure singular values are positive
213    for i in 0..dim {
214        let val = diag[i];
215        if val.is_sign_negative() {
216            diag[i] = -val;
217            if let Some(u) = &mut u {
218                u.column_mut(i).mul_assign(-A::zero());
219            }
220        }
221    }
222
223    Ok((u, diag, vt))
224}
225
/// Deflates negligible entries at the bottom of the bidiagonal matrix and locates the
/// trailing unreduced block on which the next QR sweep should run.
///
/// Working upward from index `end`, off-diagonal entries that are tiny relative to their
/// neighbouring diagonal entries are set to zero. Tiny diagonal entries are zeroed too. The
/// off-diagonal entries next to them are then eliminated with Givens rotations (accumulated
/// into `u` / `v_t`), so the problem splits at that point.
///
/// Returns `(start, end)`, the inclusive bounds of the last block whose off-diagonal entries
/// are all significant. Returns `(0, 0)` once the active part is fully diagonal.
///
/// NOTE(review): assumes `off_diag.len() == diag.len() - 1` and `end < diag.len()`.
fn delimit_subproblem<A: NdFloat>(
    diag: &mut Array1<A>,
    off_diag: &mut Array1<A>,
    u: &mut Option<Array2<A>>,
    v_t: &mut Option<Array2<A>>,
    is_upper_diag: bool,
    end: usize,
    eps: A,
) -> (usize, usize) {
    // Phase 1: walk upward from `end` while the bottom of the matrix is already decoupled.
    let mut n = end;
    while n > 0 {
        let m = n - 1;
        // `at` / `atm` come from `crate::index` and are only called inside `unsafe`
        // (presumably unchecked access). `m < n <= end` keeps the indices in range under the
        // length assumption above.
        unsafe {
            if off_diag.at(m).is_zero()
                || off_diag.at(m).abs() <= eps * (diag.at(n).abs() + diag.at(m).abs())
            {
                // Coupling negligible relative to its neighbours: drop it.
                *off_diag.atm(m) = A::zero();
            } else if diag.at(m).abs() <= eps {
                // Zero diagonal entry at `m`: rotate `off_diag[m]` into `diag[m + 1]`, then
                // clear `off_diag[m - 1]` by chasing it upward.
                *diag.atm(m) = A::zero();
                cancel_horizontal_off_diagonal_elt(diag, off_diag, u, v_t, is_upper_diag, m, m + 1);
                if m != 0 {
                    cancel_vertical_off_diagonal_elt(diag, off_diag, u, v_t, is_upper_diag, m - 1);
                }
            } else if diag.at(n).abs() <= eps {
                // Zero diagonal entry at `n`: clear `off_diag[m]` by chasing it upward.
                *diag.atm(n) = A::zero();
                cancel_vertical_off_diagonal_elt(diag, off_diag, u, v_t, is_upper_diag, m);
            } else {
                // `off_diag[m]` is significant, so `n` is the bottom of an unreduced block.
                break;
            }
        }

        n -= 1;
    }

    // Every off-diagonal entry in `0..end` is now zero: nothing left to iterate on.
    if n == 0 {
        return (0, 0);
    }

    // Phase 2: extend the block upward from `n` until it is split either by a negligible
    // off-diagonal entry or by a zero diagonal entry.
    let mut new_start = n - 1;
    while new_start > 0 {
        let m = new_start - 1;

        unsafe {
            if off_diag.at(m).abs() <= eps * (diag.at(new_start).abs() + diag.at(m).abs()) {
                // Negligible coupling above `new_start`: the block starts here.
                *off_diag.atm(m) = A::zero();
                break;
            }
        }

        if unsafe { diag.at(m).abs() } <= eps {
            // Zero diagonal entry just above the block: chase `off_diag[m]` to the right
            // through the block (up to `n`) and clear `off_diag[m - 1]` upward. The block
            // then starts at `new_start == m + 1`.
            unsafe { *diag.atm(m) = A::zero() };
            cancel_horizontal_off_diagonal_elt(diag, off_diag, u, v_t, is_upper_diag, m, n);
            if m != 0 {
                cancel_vertical_off_diagonal_elt(diag, off_diag, u, v_t, is_upper_diag, m - 1);
            }
            break;
        }
        new_start -= 1;
    }

    (new_start, n)
}
288
/// Eliminates `off_diag[i]` after the caller has set `diag[i]` to zero.
///
/// The element is first rotated against `diag[i + 1]`. Each rotation creates a new non-zero
/// one position further along, built from `off_diag[k + 1]`. That value is eliminated in turn
/// against `diag[k + 2]`, continuing up to index `end`. Rotation `k` acts on the index pair
/// `(i, k + 1)`: columns of `u` for an upper bidiagonal matrix, rows of `v_t` for a lower one.
///
/// NOTE(review): `GivensRotation::cancel_x` presumably returns `None` when there is nothing
/// to cancel; the chase stops early in that case.
fn cancel_horizontal_off_diagonal_elt<A: NdFloat>(
    diag: &mut Array1<A>,
    off_diag: &mut Array1<A>,
    u: &mut Option<Array2<A>>,
    v_t: &mut Option<Array2<A>>,
    is_upper_diag: bool,
    i: usize,
    end: usize,
) {
    // Pair (element being chased, diagonal entry it is rotated into).
    let mut v = (off_diag[i], diag[i + 1]);
    off_diag[i] = A::zero();

    for k in i..end {
        if let Some((rot, norm)) = GivensRotation::cancel_x(v.0, v.1) {
            unsafe { *diag.atm(k + 1) = norm };

            // The slice `i..=k+1` with step `k-i+1` selects exactly indices `i` and `k + 1`.
            if is_upper_diag {
                if let Some(u) = u {
                    rot.inverse()
                        .rotate_rows(&mut u.slice_mut(s![.., i..=k+1;k-i+1]))
                        .unwrap()
                }
            } else if let Some(v_t) = v_t {
                rot.rotate_cols(&mut v_t.slice_mut(s![i..=k+1;k-i+1, ..]))
                    .unwrap();
            }

            if k + 1 != end {
                // Fill-in from rotating into `off_diag[k + 1]` becomes the next element to
                // chase. `off_diag[k + 1]` must be read before it is scaled by `c`.
                unsafe {
                    v.0 = -rot.s() * *off_diag.at(k + 1);
                    v.1 = *diag.at(k + 2);
                    *off_diag.atm(k + 1) *= rot.c();
                }
            }
        } else {
            break;
        }
    }
}
328
/// Eliminates `off_diag[i]` after the caller has set `diag[i + 1]` to zero.
///
/// The element is first rotated against `diag[i]`. Each rotation leaves a new non-zero built
/// from `off_diag[k - 1]`, which is eliminated in turn against `diag[k - 1]`, walking up to
/// index 0. Rotation `k` acts on the index pair `(k, i + 1)`: rows of `v_t` for an upper
/// bidiagonal matrix, columns of `u` for a lower one.
///
/// NOTE(review): `GivensRotation::cancel_y` presumably returns `None` when there is nothing
/// to cancel; the chase stops early in that case.
fn cancel_vertical_off_diagonal_elt<A: NdFloat>(
    diag: &mut Array1<A>,
    off_diag: &mut Array1<A>,
    u: &mut Option<Array2<A>>,
    v_t: &mut Option<Array2<A>>,
    is_upper_diag: bool,
    i: usize,
) {
    // Pair (diagonal entry rotated into, element being chased).
    let mut v = (diag[i], off_diag[i]);
    off_diag[i] = A::zero();

    for k in (0..i + 1).rev() {
        if let Some((rot, norm)) = GivensRotation::cancel_y(v.0, v.1) {
            unsafe { *diag.atm(k) = norm };

            // The slice `k..=i+1` with step `i-k+1` selects exactly indices `k` and `i + 1`.
            if is_upper_diag {
                if let Some(v_t) = v_t {
                    rot.rotate_cols(&mut v_t.slice_mut(s![k..=i+1;i-k+1, ..]))
                        .unwrap();
                }
            } else if let Some(u) = u {
                rot.inverse()
                    .rotate_rows(&mut u.slice_mut(s![.., k..=i+1;i-k+1]))
                    .unwrap()
            }

            if k > 0 {
                // Fill-in from `off_diag[k - 1]` is chased next. It must be read before it is
                // scaled by `c`.
                unsafe {
                    v.0 = *diag.at(k - 1);
                    v.1 = rot.s() * *off_diag.at(k - 1);
                    *off_diag.atm(k - 1) *= rot.c();
                }
            }
        } else {
            break;
        }
    }
}
367
368// Explicit formulae inspired from the paper "Computing the Singular Values of 2-by-2 Complex
369// Matrices", Sanzheng Qiao and Xiaohong Wang.
370// http://www.cas.mcmaster.ca/sqrl/papers/sqrl5.pdf
371fn compute_2x2_uptrig_svd<A: NdFloat>(
372    m11: A,
373    m12: A,
374    m22: A,
375    compute_u: bool,
376    compute_v: bool,
377) -> (A, A, Option<GivensRotation<A>>, Option<GivensRotation<A>>) {
378    let two = A::from(2.0).unwrap();
379    let denom = (m11 + m22).hypot(m12) + (m11 - m22).hypot(m12);
380
381    // NOTE: v1 is the singular value that is the closest to m22.
382    // This prevents cancellation issues when constructing the vector `csv` below. If we chose
383    // otherwise, we would have v1 ~= m11 when m12 is small. This would cause catastrophic
384    // cancellation on `v1 * v1 - m11 * m11` below.
385    let mut v1 = m11 * m22 * two / denom;
386    let mut v2 = denom / two;
387
388    let mut u = None;
389    let mut v_t = None;
390
391    if compute_v || compute_u {
392        let cv = m11 * m12;
393        let sv = v1 * v1 - m11 * m11;
394        let (csv, sgn_v) = GivensRotation::new(cv, sv);
395        v1 *= sgn_v;
396        v2 *= sgn_v;
397        if compute_v {
398            v_t = Some(csv.clone());
399        }
400
401        let cu = (m11 * csv.c() + m12 * csv.s()) / v1;
402        let su = (m22 * csv.s()) / v1;
403        let (csu, sgn_u) = GivensRotation::new(cu, su);
404        v1 *= sgn_u;
405        v2 *= sgn_u;
406        if compute_u {
407            u = Some(csu);
408        }
409    }
410
411    (v1, v2, u, v_t)
412}
413
414/// Compact singular-value decomposition of a non-empty matrix
415pub trait SVDInto {
416    type U;
417    type Vt;
418    type Sigma;
419
420    /// Calculates the compact SVD of a matrix, consisting of a square non-negative diagonal
421    /// matrix `S` and rectangular semi-orthogonal matrices `U` and `Vt`, such that `U * S * Vt`
422    /// yields the original matrix. Only the diagonal elements of `S` is returned.
423    fn svd_into(
424        self,
425        calc_u: bool,
426        calc_vt: bool,
427    ) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::Vt>)>;
428}
429
430impl<A: NdFloat, S: DataMut<Elem = A>> SVDInto for ArrayBase<S, Ix2> {
431    type U = Array2<A>;
432    type Vt = Array2<A>;
433    type Sigma = Array1<A>;
434
435    fn svd_into(
436        self,
437        calc_u: bool,
438        calc_vt: bool,
439    ) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::Vt>)> {
440        // epsilon = 1e-15 for f64
441        svd(self, calc_u, calc_vt, A::epsilon() * A::from(5.).unwrap())
442    }
443}
444
445/// Compact singular-value decomposition of a non-empty matrix
446pub trait SVD {
447    type U;
448    type Vt;
449    type Sigma;
450
451    /// Calculates the compact SVD of a matrix, consisting of a square non-negative diagonal
452    /// matrix `S` and rectangular semi-orthogonal matrices `U` and `Vt`, such that `U * S * Vt`
453    /// yields the original matrix. Only the diagonal elements of `S` is returned.
454    fn svd(
455        &self,
456        calc_u: bool,
457        calc_vt: bool,
458    ) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::Vt>)>;
459}
460
461impl<A: NdFloat, S: Data<Elem = A>> SVD for ArrayBase<S, Ix2> {
462    type U = Array2<A>;
463    type Vt = Array2<A>;
464    type Sigma = Array1<A>;
465
466    fn svd(
467        &self,
468        calc_u: bool,
469        calc_vt: bool,
470    ) -> Result<(Option<Self::U>, Self::Sigma, Option<Self::Vt>)> {
471        // epsilon = 1e-15 for f64
472        svd(
473            self.to_owned(),
474            calc_u,
475            calc_vt,
476            A::epsilon() * A::from(5.).unwrap(),
477        )
478    }
479}
480
481/// Sorting of SVD decomposition by the singular values. Rearranges the columns of `U` and rows of
482/// `Vt` accordingly.
483///
484/// ## Panic
485///
486/// Will panic if shape of inputs differs from shape of SVD output, or if input contains NaN.
487pub trait SvdSort: Sized {
488    fn sort_svd(self, order: Order) -> Self;
489
490    /// Sort SVD decomposition by the singular values in ascending order
491    fn sort_svd_asc(self) -> Self {
492        self.sort_svd(Order::Smallest)
493    }
494
495    /// Sort SVD decomposition by the singular values in descending order
496    fn sort_svd_desc(self) -> Self {
497        self.sort_svd(Order::Largest)
498    }
499}
500
501/// Implemented on the output of the `SVD` traits
502impl<A: NdFloat> SvdSort for (Option<Array2<A>>, Array1<A>, Option<Array2<A>>) {
503    fn sort_svd(self, order: Order) -> Self {
504        let (u, mut s, vt) = self;
505        let mut value_idx: Vec<_> = s.iter().copied().enumerate().collect();
506        // Panic only happens with NaN values
507        match order {
508            Order::Largest => value_idx.sort_by(|a, b| cmp_floats(&b.1, &a.1)),
509            Order::Smallest => value_idx.sort_by(|a, b| cmp_floats(&a.1, &b.1)),
510        }
511
512        let apply_ordering = |arr: &Array2<A>, ax, values_idx: &Vec<_>| {
513            let mut out = Array2::zeros(arr.dim()); // Could be uninit
514            for (out_idx, &(arr_idx, _)) in values_idx.iter().enumerate() {
515                out.index_axis_mut(ax, out_idx)
516                    .assign(&arr.index_axis(ax, arr_idx));
517            }
518            out
519        };
520
521        let u = u.map(|u| apply_ordering(&u, Axis(1), &value_idx));
522        let vt = vt.map(|vt| apply_ordering(&vt, Axis(0), &value_idx));
523        s.iter_mut()
524            .zip(value_idx.iter())
525            .for_each(|(si, (_, f))| *si = *f);
526        (u, s, vt)
527    }
528}
529
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayView2};

    use super::*;

    /// Asserts that `m` has orthonormal columns, i.e. `mᵀ · m ≈ I`.
    fn assert_orthonormal_columns(m: ArrayView2<f64>) {
        let gram = m.t().dot(&m);
        assert_abs_diff_eq!(gram, Array2::<f64>::eye(m.ncols()), epsilon = 1e-9);
    }

    #[test]
    fn svd_test() {
        let d = svd(array![[3.0, 0.], [0., -2.]], true, true, 1e-15).unwrap();
        let (u, s, vt) = d.clone().sort_svd_desc();
        assert_abs_diff_eq!(s, array![3., 2.], epsilon = 1e-7);
        assert_abs_diff_eq!(u.unwrap(), array![[1., 0.], [0., -1.]], epsilon = 1e-7);
        assert_abs_diff_eq!(vt.unwrap(), array![[1., 0.], [0., 1.]], epsilon = 1e-7);
        let (u, s, vt) = d.sort_svd_asc();
        assert_abs_diff_eq!(s, array![2., 3.], epsilon = 1e-7);
        assert_abs_diff_eq!(u.unwrap(), array![[0., 1.], [-1., 0.]], epsilon = 1e-7);
        assert_abs_diff_eq!(vt.unwrap(), array![[0., 1.], [1., 0.]], epsilon = 1e-7);

        let (u, s, vt) = svd(array![[1., 0., -1.], [-2., 1., 4.]], false, false, 1e-15).unwrap();
        assert_abs_diff_eq!(s, array![0.51371, 4.76824], epsilon = 1e-5);
        assert!(u.is_none());
        assert!(vt.is_none());
    }

    /// Checks the singular values, their signs, the reconstruction `U * S * Vt == a`, the
    /// orthonormality of `U` / `Vt`, and that skipping `U` or `Vt` does not change the rest.
    fn test_svd_props(a: Array2<f64>, exp_s: Array1<f64>) {
        let (u, s, vt) = svd(a.clone(), true, true, 1e-15).unwrap().sort_svd_desc();
        let (u, vt) = (u.unwrap(), vt.unwrap());
        assert_abs_diff_eq!(s, exp_s, epsilon = 1e-5);
        assert!(s.iter().copied().all(f64::is_sign_positive));
        assert_abs_diff_eq!(u.dot(&Array2::from_diag(&s)).dot(&vt), a, epsilon = 1e-5);
        // `U` must keep orthonormal columns and `Vt` orthonormal rows. Flipping the sign of a
        // singular value must negate its column of `U`, not scale it by anything else.
        assert_orthonormal_columns(u.view());
        assert_orthonormal_columns(vt.t());

        let (u2, s2, vt2) = svd(a.clone(), false, true, 1e-15).unwrap().sort_svd_desc();
        assert!(u2.is_none());
        assert_abs_diff_eq!(s2, s, epsilon = 1e-9);
        assert_abs_diff_eq!(vt2.unwrap(), vt, epsilon = 1e-9);

        let (u3, s3, vt3) = svd(a.clone(), true, false, 1e-15).unwrap().sort_svd_desc();
        assert!(vt3.is_none());
        assert_abs_diff_eq!(s3, s, epsilon = 1e-9);
        assert_abs_diff_eq!(u3.unwrap(), u, epsilon = 1e-9);

        let (u4, s4, vt4) = svd(a, false, false, 1e-15).unwrap().sort_svd_desc();
        assert!(vt4.is_none());
        assert!(u4.is_none());
        assert_abs_diff_eq!(s4, s, epsilon = 1e-9);
    }

    #[test]
    fn svd_props() {
        test_svd_props(array![[-2., 1., 4.]], array![21f64.sqrt()]);
        test_svd_props(array![[1., 1.], [1., 1.]], array![2., 0.]);
        test_svd_props(
            array![[-3., 4.], [4.3, 2.1], [6.6, 8.7]],
            array![11.80876, 5.2633658],
        );
        test_svd_props(
            array![
                [10.74785316637712, -5.994983325167452, -6.064492921857296],
                [-4.149751381521569, 20.654504205822462, -4.470436210703133],
                [-22.772715014220207, -1.4554372570788008, 18.108113992170573]
            ]
            .reversed_axes(),
            array![3.16188022e+01, 2.23811978e+01, 0.],
        );
        // Negative diagonal entries exercise the sign fix-up on `U`.
        test_svd_props(array![[-1., 0.], [0., -4.]], array![4., 1.]);
    }

    #[test]
    fn svd_corner() {
        assert!(matches!(
            svd(Array2::zeros((0, 1)), false, false, 1e-15).unwrap_err(),
            LinalgError::EmptyMatrix
        ));

        let (u, s, vt) = svd(array![[0f64]], true, true, 1e-15).unwrap();
        assert_abs_diff_eq!(s, array![0.]);
        assert_abs_diff_eq!(u.unwrap(), array![[1.]]);
        assert_abs_diff_eq!(vt.unwrap(), array![[1.]]);
    }
}