1use ndarray::{s, Array1, Array2, ArrayBase, DataMut, Ix2, NdFloat, RawDataClone};
4
5use crate::{
6 householder::{assemble_q, clear_column, clear_row},
7 LinalgError, Result,
8};
9
/// Bidiagonal decomposition of a matrix.
///
/// The unit tests in this file check that the factors obtained from the result satisfy
/// `U * B * Vt == original` (within a small tolerance).
pub trait Bidiagonal {
    /// Type that holds the outcome of the decomposition.
    type Decomp;

    /// Consumes `self` and computes its bidiagonal decomposition.
    ///
    /// # Errors
    ///
    /// Failure is reported through the crate's `Result` type. The `ArrayBase` implementation
    /// in this file returns `LinalgError::EmptyMatrix` when either dimension of the matrix is
    /// zero.
    fn bidiagonal(self) -> Result<Self::Decomp>;
}
19
20impl<S, A> Bidiagonal for ArrayBase<S, Ix2>
21where
22 A: NdFloat,
23 S: DataMut<Elem = A>,
24{
25 type Decomp = BidiagonalDecomp<A, S>;
26
27 fn bidiagonal(mut self) -> Result<Self::Decomp> {
28 let (nrows, ncols) = self.dim();
29 let min_dim = nrows.min(ncols);
30 if min_dim == 0 {
31 return Err(LinalgError::EmptyMatrix);
32 }
33
34 let mut diagonal = Array1::zeros(min_dim);
36 let mut off_diagonal = Array1::zeros(min_dim - 1);
37
38 let upper_diag = nrows >= ncols;
39 if upper_diag {
40 for i in 0..min_dim - 1 {
41 diagonal[i] = clear_column(&mut self, i, 0);
42 off_diagonal[i] = clear_row(&mut self, i, 1);
43 }
44 diagonal[min_dim - 1] = clear_column(&mut self, min_dim - 1, 0);
45 } else {
46 for i in 0..min_dim - 1 {
47 diagonal[i] = clear_row(&mut self, i, 0);
48 off_diagonal[i] = clear_column(&mut self, i, 1);
49 }
50 diagonal[min_dim - 1] = clear_row(&mut self, min_dim - 1, 0);
51 }
52
53 Ok(BidiagonalDecomp {
54 uv: self,
55 diagonal,
56 off_diagonal,
57 upper_diag,
58 })
59 }
60}
61
/// Outcome of a bidiagonal decomposition, as built by the `Bidiagonal` implementation for
/// `ArrayBase` in this file.
#[derive(Debug)]
pub struct BidiagonalDecomp<A, S: DataMut<Elem = A>> {
    /// The input matrix after the in-place `clear_column` / `clear_row` passes; it is handed
    /// to `assemble_q` (directly or transposed) to build `U` and `Vt`.
    uv: ArrayBase<S, Ix2>,
    /// Values returned by the clearing passes invoked with `1` as their last argument; holds
    /// one element fewer than `diagonal`.
    off_diagonal: Array1<A>,
    /// Values returned by the clearing passes invoked with `0` as their last argument; its
    /// length is the smaller of the two matrix dimensions.
    diagonal: Array1<A>,
    /// `true` when the input had at least as many rows as columns; `into_b` then places the
    /// off-diagonal above the main diagonal, otherwise below it.
    upper_diag: bool,
}
70
71impl<A: Clone, S: DataMut<Elem = A> + RawDataClone> Clone for BidiagonalDecomp<A, S> {
72 fn clone(&self) -> Self {
73 Self {
74 uv: self.uv.clone(),
75 off_diagonal: self.off_diagonal.clone(),
76 diagonal: self.diagonal.clone(),
77 upper_diag: self.upper_diag,
78 }
79 }
80}
81
82impl<A: NdFloat, S: DataMut<Elem = A>> BidiagonalDecomp<A, S> {
83 pub fn is_upper_diag(&self) -> bool {
85 self.upper_diag
86 }
87
88 pub fn generate_u(&self) -> Array2<A> {
91 let shift = !self.upper_diag as usize;
92 if self.upper_diag {
93 assemble_q(&self.uv, shift, |i| self.diagonal[i])
94 } else {
95 assemble_q(&self.uv, shift, |i| self.off_diagonal[i])
96 }
97 }
98
99 pub fn generate_vt(&self) -> Array2<A> {
102 let shift = self.upper_diag as usize;
103 if self.upper_diag {
104 assemble_q(&self.uv.t(), shift, |i| self.off_diagonal[i])
105 } else {
106 assemble_q(&self.uv.t(), shift, |i| self.diagonal[i])
107 }
108 .reversed_axes()
109 }
110
111 pub fn into_b(self) -> Array2<A> {
114 let d = self.diagonal.len();
115 let (r, c) = if self.upper_diag { (0, 1) } else { (1, 0) };
116 let (diagonal, off_diagonal) = self.into_diagonals();
117 let mut res = Array2::from_diag(&diagonal);
118
119 res.slice_mut(s![r..d, c..d])
120 .diag_mut()
121 .assign(&off_diagonal);
122 res
123 }
124
125 pub fn into_diagonals(self) -> (Array1<A>, Array1<A>) {
127 (
128 self.diagonal.mapv_into(A::abs),
129 self.off_diagonal.mapv_into(A::abs),
130 )
131 }
132}
133
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    use super::*;

    /// 3 x 4 input (fewer rows than columns): the off-diagonal of `B` sits below the main
    /// diagonal.
    #[test]
    fn bidiagonal_lower() {
        let input = array![
            [4.0f64, 0., 2., 2.],
            [-2., 6., 3., -2.],
            [2., 7., -3.2, -1.]
        ];
        let factors = input.clone().bidiagonal().unwrap();
        let u = factors.generate_u();
        let vt = factors.generate_vt();
        let b = factors.clone().into_b();
        let (main_diag, off_diag) = factors.into_diagonals();

        // Shapes of the three factors.
        assert_eq!(u.dim(), (3, 3));
        assert_eq!(b.dim(), (3, 3));
        assert_eq!(vt.dim(), (3, 4));

        // Orthogonality of `U` and `Vt`, then reconstruction of the input.
        assert_abs_diff_eq!(u.dot(&u.t()), Array2::eye(3), epsilon = 1e-5);
        assert_abs_diff_eq!(vt.dot(&vt.t()), Array2::eye(3), epsilon = 1e-5);
        assert_abs_diff_eq!(u.dot(&b).dot(&vt), input, epsilon = 1e-5);

        // `into_diagonals` must agree with the entries stored in `B`.
        assert_abs_diff_eq!(main_diag, b.diag());
        let below_main = b.slice(s![1.., 0..]);
        assert_abs_diff_eq!(off_diag, below_main.diag());
    }

    /// 4 x 3 input (at least as many rows as columns): the off-diagonal of `B` sits above the
    /// main diagonal.
    #[test]
    fn bidiagonal_upper() {
        let input = array![
            [4.0f64, 0., 2.],
            [-2., 6., 3.],
            [2., 7., -3.2],
            [4., -3., 0.2]
        ];
        let factors = input.clone().bidiagonal().unwrap();
        let u = factors.generate_u();
        let vt = factors.generate_vt();
        let b = factors.clone().into_b();
        let (main_diag, off_diag) = factors.into_diagonals();

        // Shapes of the three factors.
        assert_eq!(u.dim(), (4, 3));
        assert_eq!(b.dim(), (3, 3));
        assert_eq!(vt.dim(), (3, 3));

        // `U` is tall here, so only `Ut * U` is the identity.
        assert_abs_diff_eq!(u.t().dot(&u), Array2::eye(3), epsilon = 1e-5);
        assert_abs_diff_eq!(vt.dot(&vt.t()), Array2::eye(3), epsilon = 1e-5);
        assert_abs_diff_eq!(u.dot(&b).dot(&vt), input, epsilon = 1e-5);

        // `into_diagonals` must agree with the entries stored in `B`.
        assert_abs_diff_eq!(main_diag, b.diag());
        let above_main = b.slice(s![0.., 1..]);
        assert_abs_diff_eq!(off_diag, above_main.diag());
    }

    /// A matrix with a zero-length dimension is rejected.
    #[test]
    fn bidiagonal_error() {
        assert!(matches!(
            Array2::<f64>::zeros((0, 0)).bidiagonal(),
            Err(LinalgError::EmptyMatrix)
        ));
    }
}