// rand_distr/inverse_gaussian.rs
use crate::{Distribution, StandardNormal, StandardUniform};
use core::fmt;
use num_traits::Float;
use rand::Rng;

/// Error type returned from [`InverseGaussian::new`] when a parameter is invalid.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// `mean <= 0` or `mean` is NaN.
    MeanNegativeOrNull,
    /// `shape <= 0` or `shape` is NaN.
    ShapeNegativeOrNull,
}

17impl fmt::Display for Error {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 f.write_str(match self {
20 Error::MeanNegativeOrNull => "mean <= 0 or is NaN in inverse Gaussian distribution",
21 Error::ShapeNegativeOrNull => "shape <= 0 or is NaN in inverse Gaussian distribution",
22 })
23 }
24}
25
/// With the `std` feature enabled, [`Error`] implements the standard error
/// trait. There is no underlying cause, so `source()` keeps its default.
#[cfg(feature = "std")]
impl std::error::Error for Error {}

29#[derive(Debug, Clone, Copy, PartialEq)]
51#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
52pub struct InverseGaussian<F>
53where
54 F: Float,
55 StandardNormal: Distribution<F>,
56 StandardUniform: Distribution<F>,
57{
58 mean: F,
59 shape: F,
60}
61
62impl<F> InverseGaussian<F>
63where
64 F: Float,
65 StandardNormal: Distribution<F>,
66 StandardUniform: Distribution<F>,
67{
68 pub fn new(mean: F, shape: F) -> Result<InverseGaussian<F>, Error> {
71 let zero = F::zero();
72 if !(mean > zero) {
73 return Err(Error::MeanNegativeOrNull);
74 }
75
76 if !(shape > zero) {
77 return Err(Error::ShapeNegativeOrNull);
78 }
79
80 Ok(Self { mean, shape })
81 }
82}
83
84impl<F> Distribution<F> for InverseGaussian<F>
85where
86 F: Float,
87 StandardNormal: Distribution<F>,
88 StandardUniform: Distribution<F>,
89{
90 #[allow(clippy::many_single_char_names)]
91 fn sample<R>(&self, rng: &mut R) -> F
92 where
93 R: Rng + ?Sized,
94 {
95 let mu = self.mean;
96 let l = self.shape;
97
98 let v: F = rng.sample(StandardNormal);
99 let y = mu * v * v;
100
101 let mu_2l = mu / (F::from(2.).unwrap() * l);
102
103 let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt());
104
105 let u: F = rng.random();
106
107 if u <= mu / (mu + x) {
108 return x;
109 }
110
111 mu * mu / x
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inverse_gaussian() {
        let inv_gauss = InverseGaussian::new(1.0f64, 1.0).unwrap();
        let mut rng = crate::test::rng(210);
        for _ in 0..1000 {
            let x = inv_gauss.sample(&mut rng);
            // The distribution is supported on the positive reals only.
            assert!(x > 0.0 && x.is_finite(), "invalid sample {x}");
        }
    }

    #[test]
    fn test_inverse_gaussian_invalid_param() {
        assert!(InverseGaussian::new(-1.0, 1.0).is_err());
        assert!(InverseGaussian::new(-1.0, -1.0).is_err());
        assert!(InverseGaussian::new(1.0, -1.0).is_err());
        assert!(InverseGaussian::new(1.0, 1.0).is_ok());
    }

    #[test]
    fn test_inverse_gaussian_rejects_zero_and_nan() {
        assert_eq!(
            InverseGaussian::new(0.0f64, 1.0),
            Err(Error::MeanNegativeOrNull)
        );
        assert_eq!(
            InverseGaussian::new(f64::NAN, 1.0),
            Err(Error::MeanNegativeOrNull)
        );
        assert_eq!(
            InverseGaussian::new(1.0f64, 0.0),
            Err(Error::ShapeNegativeOrNull)
        );
        assert_eq!(
            InverseGaussian::new(1.0f64, f64::NAN),
            Err(Error::ShapeNegativeOrNull)
        );
        // The mean is validated before the shape.
        assert_eq!(
            InverseGaussian::new(-1.0f64, -1.0),
            Err(Error::MeanNegativeOrNull)
        );
    }

    #[test]
    fn inverse_gaussian_distributions_can_be_compared() {
        assert_eq!(
            InverseGaussian::new(1.0, 2.0),
            InverseGaussian::new(1.0, 2.0)
        );
        assert_ne!(
            InverseGaussian::new(1.0, 2.0),
            InverseGaussian::new(2.0, 1.0)
        );
    }
}