1use crate::{check_square, index::*, LinalgError, Result};
4
5use ndarray::{s, Array, ArrayBase, Data, DataMut, Ix2, NdFloat, SliceArg};
6use num_traits::Zero;
7
/// Selects which triangle of a square matrix an operation refers to.
///
/// In both cases the main diagonal belongs to the selected triangle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UPLO {
    /// The diagonal and every entry `(i, j)` with `j > i` (above the diagonal).
    Upper,
    /// The diagonal and every entry `(i, j)` with `j < i` (below the diagonal).
    Lower,
}
14
15pub trait IntoTriangular {
17 fn triangular_inplace(&mut self, uplo: UPLO) -> Result<&mut Self>;
20
21 fn into_triangular(self, uplo: UPLO) -> Result<Self>
23 where
24 Self: Sized;
25}
26
27impl<A, S> IntoTriangular for ArrayBase<S, Ix2>
28where
29 A: Zero,
30 S: DataMut<Elem = A>,
31{
32 fn into_triangular(mut self, uplo: UPLO) -> Result<Self> {
33 self.triangular_inplace(uplo)?;
34 Ok(self)
35 }
36
37 fn triangular_inplace(&mut self, uplo: UPLO) -> Result<&mut Self> {
38 let n = check_square(self)?;
39 if uplo == UPLO::Upper {
40 for i in 0..n {
41 for j in 0..i {
42 unsafe { *self.atm((i, j)) = A::zero() };
43 }
44 }
45 } else {
46 for i in 0..n {
47 for j in i + 1..n {
48 unsafe { *self.atm((i, j)) = A::zero() };
49 }
50 }
51 }
52 Ok(self)
53 }
54}
55
56pub trait Triangular {
58 fn is_triangular(&self, uplo: UPLO) -> bool;
60}
61
62impl<A, S> Triangular for ArrayBase<S, Ix2>
63where
64 A: Zero,
65 S: Data<Elem = A>,
66{
67 fn is_triangular(&self, uplo: UPLO) -> bool {
68 if let Ok(n) = check_square(self) {
69 if uplo == UPLO::Upper {
70 for i in 0..n {
71 for j in 0..i {
72 if unsafe { !self.at((i, j)).is_zero() } {
73 return false;
74 }
75 }
76 }
77 } else {
78 for i in 0..n {
79 for j in i + 1..n {
80 if unsafe { !self.at((i, j)).is_zero() } {
81 return false;
82 }
83 }
84 }
85 }
86 true
87 } else {
88 false
89 }
90 }
91}
92
93#[inline]
94fn solve_triangular_system_common<A: NdFloat, I: Iterator<Item = usize>, S: SliceArg<Ix2>>(
96 a: &ArrayBase<impl Data<Elem = A>, Ix2>,
97 b: &mut ArrayBase<impl DataMut<Elem = A>, Ix2>,
98 row_iter_fn: impl Fn(usize) -> I,
99 row_slice_fn: impl Fn(usize, usize) -> S,
100 diag_fn: impl Fn(usize) -> A,
101) -> Result<()> {
102 let rows = check_square(a)?;
103 if b.nrows() != rows {
104 return Err(LinalgError::WrongRows {
105 expected: rows,
106 actual: b.nrows(),
107 });
108 }
109 let cols = b.ncols();
110
111 for k in 0..cols {
114 for i in row_iter_fn(rows) {
115 let coeff;
116 unsafe {
117 let diag = diag_fn(i);
118 coeff = *b.at((i, k)) / diag;
119 *b.atm((i, k)) = coeff;
120 }
121
122 b.slice_mut(row_slice_fn(i, k))
123 .scaled_add(-coeff, &a.slice(row_slice_fn(i, i)));
124 }
125 }
126
127 Ok(())
128}
129
130pub(crate) fn solve_triangular_system<A: NdFloat>(
134 a: &ArrayBase<impl Data<Elem = A>, Ix2>,
135 b: &mut ArrayBase<impl DataMut<Elem = A>, Ix2>,
136 uplo: UPLO,
137 diag_fn: impl Fn(usize) -> A,
138) -> Result<()> {
139 if uplo == UPLO::Upper {
140 solve_triangular_system_common(a, b, |rows| (0..rows).rev(), |r, c| s![..r, c], diag_fn)
141 } else {
142 solve_triangular_system_common(a, b, |rows| (0..rows), |r, c| s![r + 1.., c], diag_fn)
143 }
144}
145
146pub trait SolveTriangularInplace<B> {
148 fn solve_triangular_inplace<'a>(&self, b: &'a mut B, uplo: UPLO) -> Result<&'a mut B>;
150
151 fn solve_triangular_into(&self, mut b: B, uplo: UPLO) -> Result<B> {
153 self.solve_triangular_inplace(&mut b, uplo)?;
154 Ok(b)
155 }
156}
157
158impl<A: NdFloat, Si: Data<Elem = A>, So: DataMut<Elem = A>>
159 SolveTriangularInplace<ArrayBase<So, Ix2>> for ArrayBase<Si, Ix2>
160{
161 fn solve_triangular_inplace<'a>(
162 &self,
163 b: &'a mut ArrayBase<So, Ix2>,
164 uplo: UPLO,
165 ) -> Result<&'a mut ArrayBase<So, Ix2>> {
166 solve_triangular_system(self, b, uplo, |i| unsafe { *self.at((i, i)) })?;
167 Ok(b)
168 }
169}
170
171pub trait SolveTriangular<B> {
173 type Output;
174
175 fn solve_triangular(&self, b: &B, uplo: UPLO) -> Result<Self::Output>;
177}
178
179impl<A: NdFloat, Si: Data<Elem = A>, So: Data<Elem = A>> SolveTriangular<ArrayBase<So, Ix2>>
180 for ArrayBase<Si, Ix2>
181{
182 type Output = Array<A, Ix2>;
183
184 fn solve_triangular(&self, b: &ArrayBase<So, Ix2>, uplo: UPLO) -> Result<Self::Output> {
185 self.solve_triangular_into(b.to_owned(), uplo)
186 }
187}
188
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array2};

    use crate::LinalgError;

    use super::*;

    #[test]
    fn corner_cases() {
        let empty = Array2::<f64>::zeros((0, 0));
        assert!(empty.is_triangular(UPLO::Lower));
        assert!(empty.is_triangular(UPLO::Upper));
        assert_eq!(empty.clone().into_triangular(UPLO::Lower).unwrap(), empty);

        let one = array![[1]];
        assert!(one.is_triangular(UPLO::Lower));
        assert!(one.is_triangular(UPLO::Upper));
        assert_eq!(one.clone().into_triangular(UPLO::Upper).unwrap(), one);
        assert_eq!(one.clone().into_triangular(UPLO::Lower).unwrap(), one);
    }

    #[test]
    fn non_square() {
        let row = array![[1, 2, 3], [3, 4, 5]];
        assert!(!row.is_triangular(UPLO::Lower));
        assert!(!row.is_triangular(UPLO::Upper));
        assert!(matches!(
            row.into_triangular(UPLO::Lower),
            Err(LinalgError::NotSquare { rows: 2, cols: 3 })
        ));

        let col = array![[1, 2], [3, 5], [6, 8]];
        assert!(!col.is_triangular(UPLO::Lower));
        assert!(!col.is_triangular(UPLO::Upper));
        assert!(matches!(
            col.into_triangular(UPLO::Upper),
            Err(LinalgError::NotSquare { rows: 3, cols: 2 })
        ));
    }

    #[test]
    fn square() {
        let square = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        assert!(!square.is_triangular(UPLO::Lower));
        assert!(!square.is_triangular(UPLO::Upper));

        let upper = square.clone().into_triangular(UPLO::Upper).unwrap();
        assert_eq!(upper, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
        assert!(!upper.is_triangular(UPLO::Lower));
        assert!(upper.is_triangular(UPLO::Upper));

        let lower = square.into_triangular(UPLO::Lower).unwrap();
        assert!(lower.is_triangular(UPLO::Lower));
        assert!(!lower.is_triangular(UPLO::Upper));
        assert_eq!(lower, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
    }

    #[test]
    fn triangular_inplace_mutates_receiver() {
        let mut m = array![[1, 2], [3, 4]];
        m.triangular_inplace(UPLO::Lower).unwrap();
        assert_eq!(m, array![[1, 0], [3, 4]]);
        // Applying the other triangle afterwards leaves only the diagonal.
        m.triangular_inplace(UPLO::Upper).unwrap();
        assert_eq!(m, array![[1, 0], [0, 4]]);
    }

    #[test]
    fn solve_triangular() {
        let lower = array![[1.0, 0.0], [3.0, 4.0]];
        assert!(lower.is_triangular(UPLO::Lower));
        let expected = array![[2.2, 3.1, 2.2], [1.0, 0.0, 5.7]];
        let b = lower.dot(&expected);
        let x = lower.solve_triangular_into(b, UPLO::Lower).unwrap();
        assert_abs_diff_eq!(x, expected, epsilon = 1e-7);

        let upper = array![[4.4, 2.1], [0.0, 4.3]];
        assert!(upper.is_triangular(UPLO::Upper));
        let b = upper.dot(&expected);
        let x = upper.solve_triangular_into(b, UPLO::Upper).unwrap();
        assert_abs_diff_eq!(x, expected, epsilon = 1e-7);
    }

    #[test]
    fn solve_triangular_3x3() {
        // With three rows, the elimination slices span more than one row,
        // which the 2x2 cases above never exercise.
        let expected = array![[1.0, -2.0], [0.5, 3.0], [-1.5, 4.0]];

        let lower = array![[2.0, 0.0, 0.0], [1.0, 3.0, 0.0], [4.0, -1.0, 5.0]];
        assert!(lower.is_triangular(UPLO::Lower));
        let b = lower.dot(&expected);
        let x = lower.solve_triangular(&b, UPLO::Lower).unwrap();
        assert_abs_diff_eq!(x, expected, epsilon = 1e-7);

        let upper = array![[3.0, 1.0, -2.0], [0.0, 2.0, 4.0], [0.0, 0.0, 1.5]];
        assert!(upper.is_triangular(UPLO::Upper));
        let b = upper.dot(&expected);
        let x = upper.solve_triangular(&b, UPLO::Upper).unwrap();
        assert_abs_diff_eq!(x, expected, epsilon = 1e-7);
    }

    #[test]
    fn solve_triangular_inplace_overwrites_rhs() {
        let lower = array![[1.0, 0.0], [3.0, 4.0]];
        let expected = array![[2.2, 3.1], [1.0, 0.0]];
        let mut b = lower.dot(&expected);
        let x = lower.solve_triangular_inplace(&mut b, UPLO::Lower).unwrap();
        assert_abs_diff_eq!(*x, expected, epsilon = 1e-7);
        // The returned reference aliases `b`, so `b` itself now holds the solution.
        assert_abs_diff_eq!(b, expected, epsilon = 1e-7);
    }

    #[test]
    fn solve_corner_cases() {
        let empty = Array2::<f64>::zeros((0, 0));
        let out = empty.solve_triangular(&empty, UPLO::Upper).unwrap();
        assert_eq!(out.dim(), (0, 0));

        let one = Array2::<f64>::ones((1, 1));
        let out = one.solve_triangular(&one, UPLO::Upper).unwrap();
        assert_abs_diff_eq!(out, one);

        // A zero pivot is not reported as an error. The solve must complete without
        // panicking; with a zero right-hand side every entry becomes 0/0 = NaN.
        let diag_zero = array![[0., 3.], [2., 0.]];
        let zeros = Array2::<f64>::zeros((2, 2));
        let out = diag_zero.solve_triangular(&zeros, UPLO::Lower).unwrap();
        assert!(out.iter().all(|v| v.is_nan()));
    }

    #[test]
    fn solve_error() {
        let non_square = array![[1.2f64, 3.3]];
        assert!(matches!(
            non_square
                .solve_triangular(&non_square, UPLO::Lower)
                .unwrap_err(),
            LinalgError::NotSquare { .. }
        ));

        let square = array![[1.1, 2.2], [3.3, 2.1]];
        assert!(matches!(
            square
                .solve_triangular(&array![[2.2, 3.3]], UPLO::Upper)
                .unwrap_err(),
            LinalgError::WrongRows {
                expected: 2,
                actual: 1
            }
        ));
    }
}