// linfa_linalg/bidiagonal.rs

//! Compact bidiagonal decomposition for matrices
2
3use ndarray::{s, Array1, Array2, ArrayBase, DataMut, Ix2, NdFloat, RawDataClone};
4
5use crate::{
6    householder::{assemble_q, clear_column, clear_row},
7    LinalgError, Result,
8};
9
10/// Compact bidiagonal decomposition of a non-empty matrix
11pub trait Bidiagonal {
12    type Decomp;
13
14    /// Calculate the compact bidiagonal decomposition of a matrix, consisting of square bidiagonal
15    /// matrix `B` and rectangular semi-orthogonal matrices `U` and `Vt`, such that `U * B * Vt`
16    /// yields the original matrix.
17    fn bidiagonal(self) -> Result<Self::Decomp>;
18}
19
20impl<S, A> Bidiagonal for ArrayBase<S, Ix2>
21where
22    A: NdFloat,
23    S: DataMut<Elem = A>,
24{
25    type Decomp = BidiagonalDecomp<A, S>;
26
27    fn bidiagonal(mut self) -> Result<Self::Decomp> {
28        let (nrows, ncols) = self.dim();
29        let min_dim = nrows.min(ncols);
30        if min_dim == 0 {
31            return Err(LinalgError::EmptyMatrix);
32        }
33
34        // XXX diagonal and off_diagonal could be uninit
35        let mut diagonal = Array1::zeros(min_dim);
36        let mut off_diagonal = Array1::zeros(min_dim - 1);
37
38        let upper_diag = nrows >= ncols;
39        if upper_diag {
40            for i in 0..min_dim - 1 {
41                diagonal[i] = clear_column(&mut self, i, 0);
42                off_diagonal[i] = clear_row(&mut self, i, 1);
43            }
44            diagonal[min_dim - 1] = clear_column(&mut self, min_dim - 1, 0);
45        } else {
46            for i in 0..min_dim - 1 {
47                diagonal[i] = clear_row(&mut self, i, 0);
48                off_diagonal[i] = clear_column(&mut self, i, 1);
49            }
50            diagonal[min_dim - 1] = clear_row(&mut self, min_dim - 1, 0);
51        }
52
53        Ok(BidiagonalDecomp {
54            uv: self,
55            diagonal,
56            off_diagonal,
57            upper_diag,
58        })
59    }
60}
61
62#[derive(Debug)]
63/// Full bidiagonal decomposition
64pub struct BidiagonalDecomp<A, S: DataMut<Elem = A>> {
65    uv: ArrayBase<S, Ix2>,
66    off_diagonal: Array1<A>,
67    diagonal: Array1<A>,
68    upper_diag: bool,
69}
70
71impl<A: Clone, S: DataMut<Elem = A> + RawDataClone> Clone for BidiagonalDecomp<A, S> {
72    fn clone(&self) -> Self {
73        Self {
74            uv: self.uv.clone(),
75            off_diagonal: self.off_diagonal.clone(),
76            diagonal: self.diagonal.clone(),
77            upper_diag: self.upper_diag,
78        }
79    }
80}
81
82impl<A: NdFloat, S: DataMut<Elem = A>> BidiagonalDecomp<A, S> {
83    /// Whether `B` is upper-bidiagonal or not
84    pub fn is_upper_diag(&self) -> bool {
85        self.upper_diag
86    }
87
88    /// Generates `U` matrix, which is R x min(R, C), where R and C are dimensions of the original
89    /// matrix
90    pub fn generate_u(&self) -> Array2<A> {
91        let shift = !self.upper_diag as usize;
92        if self.upper_diag {
93            assemble_q(&self.uv, shift, |i| self.diagonal[i])
94        } else {
95            assemble_q(&self.uv, shift, |i| self.off_diagonal[i])
96        }
97    }
98
99    /// Generates `Vt` matrix, which is min(R, C) x C, where R and C are dimensions of the original
100    /// matrix
101    pub fn generate_vt(&self) -> Array2<A> {
102        let shift = self.upper_diag as usize;
103        if self.upper_diag {
104            assemble_q(&self.uv.t(), shift, |i| self.off_diagonal[i])
105        } else {
106            assemble_q(&self.uv.t(), shift, |i| self.diagonal[i])
107        }
108        .reversed_axes()
109    }
110
111    /// Returns `B` matrix, which is min(R, C) x min(R, C), where R and C are dimensions of the
112    /// original matrix
113    pub fn into_b(self) -> Array2<A> {
114        let d = self.diagonal.len();
115        let (r, c) = if self.upper_diag { (0, 1) } else { (1, 0) };
116        let (diagonal, off_diagonal) = self.into_diagonals();
117        let mut res = Array2::from_diag(&diagonal);
118
119        res.slice_mut(s![r..d, c..d])
120            .diag_mut()
121            .assign(&off_diagonal);
122        res
123    }
124
125    /// Returns the diagonal and off-diagonal elements of `B` as 1D arrays
126    pub fn into_diagonals(self) -> (Array1<A>, Array1<A>) {
127        (
128            self.diagonal.mapv_into(A::abs),
129            self.off_diagonal.mapv_into(A::abs),
130        )
131    }
132}
133
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    use super::*;

    #[test]
    fn bidiagonal_lower() {
        let arr = array![
            [4.0f64, 0., 2., 2.],
            [-2., 6., 3., -2.],
            [2., 7., -3.2, -1.]
        ];
        let decomp = arr.clone().bidiagonal().unwrap();
        // Wide input (fewer rows than columns) must yield a lower-bidiagonal `B`
        assert!(!decomp.is_upper_diag());
        let u = decomp.generate_u();
        let vt = decomp.generate_vt();
        let b = decomp.clone().into_b();
        let (diag, offdiag) = decomp.into_diagonals();

        assert_eq!(u.dim(), (3, 3));
        assert_eq!(b.dim(), (3, 3));
        assert_eq!(vt.dim(), (3, 4));
        assert_abs_diff_eq!(u.dot(&u.t()), Array2::eye(3), epsilon = 1e-5);
        assert_abs_diff_eq!(vt.dot(&vt.t()), Array2::eye(3), epsilon = 1e-5);
        assert_abs_diff_eq!(u.dot(&b).dot(&vt), arr, epsilon = 1e-5);

        assert_abs_diff_eq!(diag, b.diag());
        let partial = b.slice(s![1.., 0..]);
        assert_abs_diff_eq!(offdiag, partial.diag());
    }

    #[test]
    fn bidiagonal_upper() {
        let arr = array![
            [4.0f64, 0., 2.],
            [-2., 6., 3.],
            [2., 7., -3.2],
            [4., -3., 0.2]
        ];
        let decomp = arr.clone().bidiagonal().unwrap();
        // Tall input (more rows than columns) must yield an upper-bidiagonal `B`
        assert!(decomp.is_upper_diag());
        let u = decomp.generate_u();
        let vt = decomp.generate_vt();
        let b = decomp.clone().into_b();
        let (diag, offdiag) = decomp.into_diagonals();

        assert_eq!(u.dim(), (4, 3));
        assert_eq!(b.dim(), (3, 3));
        assert_eq!(vt.dim(), (3, 3));
        assert_abs_diff_eq!(u.t().dot(&u), Array2::eye(3), epsilon = 1e-5);
        assert_abs_diff_eq!(vt.dot(&vt.t()), Array2::eye(3), epsilon = 1e-5);
        assert_abs_diff_eq!(u.dot(&b).dot(&vt), arr, epsilon = 1e-5);

        assert_abs_diff_eq!(diag, b.diag());
        let partial = b.slice(s![0.., 1..]);
        assert_abs_diff_eq!(offdiag, partial.diag());
    }

    #[test]
    fn bidiagonal_square() {
        let arr = array![[4.0f64, 1., -2.], [1., 2., 0.], [-2., 0., 3.]];
        let decomp = arr.clone().bidiagonal().unwrap();
        // Square input takes the `nrows >= ncols` branch, i.e. upper-bidiagonal
        assert!(decomp.is_upper_diag());
        let u = decomp.generate_u();
        let vt = decomp.generate_vt();
        let b = decomp.into_b();

        assert_eq!(u.dim(), (3, 3));
        assert_eq!(b.dim(), (3, 3));
        assert_eq!(vt.dim(), (3, 3));
        assert_abs_diff_eq!(u.t().dot(&u), Array2::eye(3), epsilon = 1e-5);
        assert_abs_diff_eq!(vt.dot(&vt.t()), Array2::eye(3), epsilon = 1e-5);
        assert_abs_diff_eq!(u.dot(&b).dot(&vt), arr, epsilon = 1e-5);
    }

    #[test]
    fn bidiagonal_error() {
        // Any zero-sized dimension is rejected, not only the fully empty matrix
        let shapes: [(usize, usize); 3] = [(0, 0), (0, 3), (3, 0)];
        for shape in shapes {
            assert!(matches!(
                Array2::<f64>::zeros(shape).bidiagonal(),
                Err(LinalgError::EmptyMatrix)
            ));
        }
    }
}