1use crate::*;
5use stdlib::num::NonZeroU64;
6
7use arithmetic::store_carry;
8
9
// Pull in a settings file generated at build time into OUT_DIR.
// NOTE(review): presumably written by this crate's build script; it must
// define the `DEFAULT_PRECISION` constant that `Context::default` reads —
// confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/default_precision.rs"));
12
13
14#[derive(Debug, Clone)]
31pub struct Context {
32 precision: NonZeroU64,
34 rounding: RoundingMode,
36}
37
38impl Context {
39 pub fn new(precision: NonZeroU64, rounding: RoundingMode) -> Self {
41 Context {
42 precision: precision,
43 rounding: rounding,
44 }
45 }
46
47 pub fn with_precision(&self, precision: NonZeroU64) -> Self {
49 Self {
50 precision: precision,
51 ..*self
52 }
53 }
54
55 pub fn with_prec<T: ToPrimitive>(&self, precision: T) -> Option<Self> {
57 precision
58 .to_u64()
59 .and_then(NonZeroU64::new)
60 .map(|prec| {
61 Self {
62 precision: prec,
63 ..*self
64 }
65 })
66 }
67
68 pub fn with_rounding_mode(&self, mode: RoundingMode) -> Self {
70 Self {
71 rounding: mode,
72 ..*self
73 }
74 }
75
76 pub fn precision(&self) -> NonZeroU64 {
78 self.precision
79 }
80
81 pub fn rounding_mode(&self) -> RoundingMode {
83 self.rounding
84 }
85
86 pub fn round_decimal(&self, n: BigDecimal) -> BigDecimal {
88 n.with_precision_round(self.precision(), self.rounding_mode())
89 }
90
91 pub fn round_decimal_ref<'a, D: Into<BigDecimalRef<'a>>>(&self, n: D) -> BigDecimal {
93 let d = n.into().to_owned();
94 d.with_precision_round(self.precision(), self.rounding_mode())
95 }
96
97 #[allow(dead_code)]
99 pub(crate) fn round_pair(&self, sign: Sign, x: u8, y: u8, trailing_zeros: bool) -> u8 {
100 self.rounding.round_pair(sign, (x, y), trailing_zeros)
101 }
102
103 #[allow(dead_code)]
105 pub(crate) fn round_pair_with_carry(
106 &self,
107 sign: Sign,
108 x: u8,
109 y: u8,
110 trailing_zeros: bool,
111 carry: &mut u8,
112 ) -> u8 {
113 self.rounding.round_pair_with_carry(sign, (x, y), trailing_zeros, carry)
114 }
115}
116
117impl stdlib::default::Default for Context {
118 fn default() -> Self {
119 Self {
120 precision: NonZeroU64::new(DEFAULT_PRECISION).unwrap(),
121 rounding: RoundingMode::default(),
122 }
123 }
124}
125
126impl Context {
127 pub fn add_refs<'a, 'b, A, B>(&self, a: A, b: B) -> BigDecimal
129 where
130 A: Into<BigDecimalRef<'a>>,
131 B: Into<BigDecimalRef<'b>>,
132 {
133 let mut sum = BigDecimal::zero();
134 self.add_refs_into(a, b, &mut sum);
135 sum
136 }
137
138 pub fn add_refs_into<'a, 'b, A, B>(&self, a: A, b: B, dest: &mut BigDecimal)
140 where
141 A: Into<BigDecimalRef<'a>>,
142 B: Into<BigDecimalRef<'b>>,
143 {
144 let sum = a.into() + b.into();
145 *dest = sum.with_precision_round(self.precision, self.rounding)
146 }
147}
148
149
#[cfg(test)]
mod test_context {
    use super::*;

    #[test]
    fn constructor_and_setters() {
        let ctx = Context::default();
        let c = ctx.with_prec(44).unwrap();
        assert_eq!(c.precision.get(), 44);
        assert_eq!(c.rounding, RoundingMode::HalfEven);

        let c = c.with_rounding_mode(RoundingMode::Down);
        assert_eq!(c.precision.get(), 44);
        assert_eq!(c.rounding, RoundingMode::Down);
    }

    #[test]
    fn new_and_getters() {
        let prec = NonZeroU64::new(7).unwrap();
        let ctx = Context::new(prec, RoundingMode::Down);
        assert_eq!(ctx.precision(), prec);
        assert_eq!(ctx.rounding_mode(), RoundingMode::Down);

        // with_precision must keep the rounding mode untouched
        let wider = ctx.with_precision(NonZeroU64::new(12).unwrap());
        assert_eq!(wider.precision().get(), 12);
        assert_eq!(wider.rounding_mode(), RoundingMode::Down);
    }

    #[test]
    fn with_prec_rejects_zero_and_negative() {
        let ctx = Context::default();
        assert!(ctx.with_prec(0).is_none());
        assert!(ctx.with_prec(-3).is_none());
        assert_eq!(ctx.with_prec(1u8).unwrap().precision().get(), 1);
    }

    #[test]
    fn round_decimal_owned() {
        let ctx = Context::default()
            .with_prec(3).unwrap()
            .with_rounding_mode(RoundingMode::HalfEven);
        let n: BigDecimal = "1234567".parse().unwrap();
        let d = ctx.round_decimal(n);
        assert_eq!(d.int_val, 123.into());
        assert_eq!(d.scale, -4);
    }

    #[test]
    fn sum_two_references() {
        use stdlib::ops::Neg;

        let ctx = Context::default();
        let a: BigDecimal = "209682.134972197168613072130300".parse().unwrap();
        let b: BigDecimal = "3.0782968222271332463325639E-12".parse().unwrap();

        let sum = ctx.add_refs(&a, &b);
        assert_eq!(sum, "209682.1349721971716913689525271332463325639".parse().unwrap());

        let neg_b = b.to_ref().neg();

        let sum = ctx.add_refs(&a, neg_b);
        assert_eq!(sum, "209682.1349721971655347753080728667536674361".parse().unwrap());

        let sum = ctx.with_prec(27).unwrap().with_rounding_mode(RoundingMode::Up).add_refs(&a, neg_b);
        assert_eq!(sum, "209682.134972197165534775309".parse().unwrap());
    }

    #[test]
    fn add_refs_into_overwrites_destination() {
        let ctx = Context::default()
            .with_prec(5).unwrap()
            .with_rounding_mode(RoundingMode::HalfUp);
        let a: BigDecimal = "1.23456".parse().unwrap();
        let b: BigDecimal = "100".parse().unwrap();

        // the prior contents of dest must be discarded, not accumulated
        let mut dest: BigDecimal = "-999".parse().unwrap();
        ctx.add_refs_into(&a, &b, &mut dest);
        assert_eq!(dest, "101.23".parse().unwrap());
        assert_eq!(dest, ctx.add_refs(&a, &b));
    }

    mod round_decimal_ref {
        use super::*;

        #[test]
        fn case_bigint_1234567_prec3() {
            let ctx = Context::default().with_prec(3).unwrap();
            let i = BigInt::from(1234567);
            let d = ctx.round_decimal_ref(&i);
            assert_eq!(d.int_val, 123.into());
            assert_eq!(d.scale, -4);
        }

        #[test]
        fn case_bigint_1234500_prec4_halfup() {
            let ctx = Context::default()
                .with_prec(4).unwrap()
                .with_rounding_mode(RoundingMode::HalfUp);
            let i = BigInt::from(1234500);
            let d = ctx.round_decimal_ref(&i);
            assert_eq!(d.int_val, 1235.into());
            assert_eq!(d.scale, -3);
        }

        #[test]
        fn case_bigint_1234500_prec4_halfeven() {
            let ctx = Context::default()
                .with_prec(4).unwrap()
                .with_rounding_mode(RoundingMode::HalfEven);
            let i = BigInt::from(1234500);
            let d = ctx.round_decimal_ref(&i);
            assert_eq!(d.int_val, 1234.into());
            assert_eq!(d.scale, -3);
        }

        #[test]
        fn case_bigint_1234567_prec10() {
            let ctx = Context::default().with_prec(10).unwrap();
            let i = BigInt::from(1234567);
            let d = ctx.round_decimal_ref(&i);
            assert_eq!(d.int_val, 1234567000.into());
            assert_eq!(d.scale, 3);
        }
    }
}