1use crate::{Integer, Limb, NonZero, Uint};
4use core::{cmp::Ordering, fmt, ops::Deref};
5use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
6
7#[cfg(feature = "alloc")]
8use crate::BoxedUint;
9
10#[cfg(feature = "rand_core")]
11use {crate::Random, rand_core::RngCore};
12
13#[cfg(all(feature = "alloc", feature = "rand_core"))]
14use crate::RandomBits;
15
16#[cfg(feature = "serde")]
17use crate::Zero;
18#[cfg(feature = "serde")]
19use serdect::serde::{
20 de::{Error, Unexpected},
21 Deserialize, Deserializer, Serialize, Serializer,
22};
23
/// Wrapper type for odd integers.
///
/// The wrapped value is intended to be odd (and therefore non-zero); [`Odd::new`] is the
/// checked constructor.
///
/// NOTE(review): the derived `Default` produces `Odd(T::default())`, which for the integer
/// types used with this wrapper is presumably zero and therefore not odd — confirm that
/// deriving `Default` is intended.
#[repr(transparent)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Odd<T>(pub(crate) T);
30
31impl<T> Odd<T> {
32 pub fn new(n: T) -> CtOption<Self>
34 where
35 T: Integer,
36 {
37 let is_odd = n.is_odd();
38 CtOption::new(Self(n), is_odd)
39 }
40
41 pub const fn as_ref(&self) -> &T {
43 &self.0
44 }
45
46 pub const fn as_nz_ref(&self) -> &NonZero<T> {
48 #[allow(trivial_casts, unsafe_code)]
49 unsafe {
50 &*(&self.0 as *const T as *const NonZero<T>)
51 }
52 }
53
54 pub fn get(self) -> T {
56 self.0
57 }
58}
59
60impl<const LIMBS: usize> Odd<Uint<LIMBS>> {
61 pub const fn from_be_hex(hex: &str) -> Self {
65 let uint = Uint::<LIMBS>::from_be_hex(hex);
66 assert!(uint.is_odd().is_true_vartime(), "number must be odd");
67 Odd(uint)
68 }
69
70 pub const fn from_le_hex(hex: &str) -> Self {
74 let uint = Uint::<LIMBS>::from_be_hex(hex);
75 assert!(uint.is_odd().is_true_vartime(), "number must be odd");
76 Odd(uint)
77 }
78}
79
80impl<T> AsRef<T> for Odd<T> {
81 fn as_ref(&self) -> &T {
82 &self.0
83 }
84}
85
86impl<T> AsRef<[Limb]> for Odd<T>
87where
88 T: AsRef<[Limb]>,
89{
90 fn as_ref(&self) -> &[Limb] {
91 self.0.as_ref()
92 }
93}
94
95impl<T> AsRef<NonZero<T>> for Odd<T> {
96 fn as_ref(&self) -> &NonZero<T> {
97 self.as_nz_ref()
98 }
99}
100
101impl<T> ConditionallySelectable for Odd<T>
102where
103 T: ConditionallySelectable,
104{
105 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
106 Self(T::conditional_select(&a.0, &b.0, choice))
107 }
108}
109
110impl<T> ConstantTimeEq for Odd<T>
111where
112 T: ConstantTimeEq,
113{
114 fn ct_eq(&self, other: &Self) -> Choice {
115 self.0.ct_eq(&other.0)
116 }
117}
118
119impl<T> Deref for Odd<T> {
120 type Target = T;
121
122 fn deref(&self) -> &T {
123 &self.0
124 }
125}
126
127impl<const LIMBS: usize> PartialEq<Odd<Uint<LIMBS>>> for Uint<LIMBS> {
128 fn eq(&self, other: &Odd<Uint<LIMBS>>) -> bool {
129 self.eq(&other.0)
130 }
131}
132
133impl<const LIMBS: usize> PartialOrd<Odd<Uint<LIMBS>>> for Uint<LIMBS> {
134 fn partial_cmp(&self, other: &Odd<Uint<LIMBS>>) -> Option<Ordering> {
135 Some(self.cmp(&other.0))
136 }
137}
138
#[cfg(feature = "alloc")]
impl PartialEq<Odd<BoxedUint>> for BoxedUint {
    /// Compare a plain [`BoxedUint`] with an [`Odd`]-wrapped one by value.
    fn eq(&self, other: &Odd<BoxedUint>) -> bool {
        let Odd(rhs) = other;
        self == rhs
    }
}
145
#[cfg(feature = "alloc")]
impl PartialOrd<Odd<BoxedUint>> for BoxedUint {
    /// Order a plain [`BoxedUint`] relative to an [`Odd`]-wrapped one by value.
    ///
    /// Always returns `Some`, because the comparison uses [`BoxedUint`]'s total order.
    fn partial_cmp(&self, other: &Odd<BoxedUint>) -> Option<Ordering> {
        let Odd(rhs) = other;
        Some(Ord::cmp(self, rhs))
    }
}
152
#[cfg(feature = "rand_core")]
impl<const LIMBS: usize> Random for Odd<Uint<LIMBS>> {
    /// Generate a random odd [`Uint`].
    ///
    /// Draws a value with [`Uint::random`], then forces the low bit of its first limb to 1
    /// so the result is odd.
    fn random(rng: &mut (impl RngCore + ?Sized)) -> Self {
        let mut candidate: Uint<LIMBS> = Uint::random(rng);
        // `limbs[0]` is presumably the least significant limb, so OR-ing in `Limb::ONE`
        // sets bit 0 of the integer.
        // NOTE(review): indexing `limbs[0]` panics when `LIMBS == 0`.
        candidate.limbs[0] |= Limb::ONE;
        Self(candidate)
    }
}
162
#[cfg(all(feature = "alloc", feature = "rand_core"))]
impl Odd<BoxedUint> {
    /// Generate a random odd [`BoxedUint`].
    ///
    /// Draws a value via [`RandomBits::random_bits`] with the requested `bit_length`, then
    /// forces the low bit of its first limb to 1 so the result is odd.
    pub fn random(rng: &mut (impl RngCore + ?Sized), bit_length: u32) -> Self {
        let mut candidate = BoxedUint::random_bits(rng, bit_length);
        // NOTE(review): indexing `limbs[0]` panics if `random_bits` returns an integer
        // with no limbs (possibly when `bit_length == 0`) — confirm.
        candidate.limbs[0] |= Limb::ONE;
        Self(candidate)
    }
}
172
173impl<T> fmt::Display for Odd<T>
174where
175 T: fmt::Display,
176{
177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178 fmt::Display::fmt(&self.0, f)
179 }
180}
181
182impl<T> fmt::Binary for Odd<T>
183where
184 T: fmt::Binary,
185{
186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 fmt::Binary::fmt(&self.0, f)
188 }
189}
190
191impl<T> fmt::Octal for Odd<T>
192where
193 T: fmt::Octal,
194{
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 fmt::Octal::fmt(&self.0, f)
197 }
198}
199
200impl<T> fmt::LowerHex for Odd<T>
201where
202 T: fmt::LowerHex,
203{
204 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205 fmt::LowerHex::fmt(&self.0, f)
206 }
207}
208
209impl<T> fmt::UpperHex for Odd<T>
210where
211 T: fmt::UpperHex,
212{
213 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214 fmt::UpperHex::fmt(&self.0, f)
215 }
216}
217
#[cfg(feature = "serde")]
impl<'de, T: Deserialize<'de> + Integer + Zero> Deserialize<'de> for Odd<T> {
    /// Deserialize a `T` and accept it only if it is odd.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error if `T` itself fails to deserialize. If the decoded
    /// value is even (which includes zero), returns an `invalid_value` error with
    /// unexpected `"even"` and expected `"a non-zero odd value"`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value: T = T::deserialize(deserializer)?;
        // Build the error lazily. Constructing a custom error typically formats a message,
        // and that work is only needed on the rejection path.
        Option::<Self>::from(Self::new(value)).ok_or_else(|| {
            D::Error::invalid_value(Unexpected::Other("even"), &"a non-zero odd value")
        })
    }
}
231
#[cfg(feature = "serde")]
impl<T: Serialize> Serialize for Odd<T> {
    /// Serialize exactly as the wrapped `T`, so `Odd<T>` shares `T`'s wire format.
    ///
    /// Only `T: Serialize` is required; no other capability of `T` is used here.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        T::serialize(&self.0, serializer)
    }
}
#[cfg(feature = "zeroize")]
impl<T: zeroize::Zeroize> zeroize::Zeroize for Odd<T> {
    /// Zeroize the wrapped value in place by delegating to `T`'s implementation.
    ///
    /// NOTE(review): afterwards the wrapper presumably holds zero, which is not odd, so the
    /// value should not be used as an `Odd` again — confirm with callers.
    fn zeroize(&mut self) {
        let Self(inner) = self;
        zeroize::Zeroize::zeroize(inner);
    }
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "alloc")]
    use super::BoxedUint;
    use super::{Odd, Uint};

    /// `Odd::new` must reject even inputs, including zero.
    #[test]
    fn not_odd_numbers() {
        for even in [Uint::<4>::ZERO, Uint::<4>::from(2u8)] {
            let rejected = Odd::new(even).is_none();
            assert!(bool::from(rejected));
        }
    }

    /// Same check as `not_odd_numbers`, for heap-allocated integers.
    #[cfg(feature = "alloc")]
    #[test]
    fn not_odd_numbers_boxed() {
        for even in [BoxedUint::zero(), BoxedUint::from(2u8)] {
            let rejected = Odd::new(even).is_none();
            assert!(bool::from(rejected));
        }
    }

    #[cfg(feature = "serde")]
    mod serde_tests {
        use crate::{Odd, U128, U64};
        use bincode::ErrorKind;

        /// Message produced when `Odd`'s `Deserialize` impl rejects a value.
        const REJECTED: &str = "invalid value: even, expected a non-zero odd value";

        #[test]
        fn roundtrip() {
            let original = Odd::new(U64::from_u64(0x00123)).unwrap();
            let bytes = bincode::serialize(&original).unwrap();
            let decoded: Odd<U64> = bincode::deserialize(&bytes).unwrap();

            assert_eq!(original, decoded);
        }

        #[test]
        fn even_values_do_not_deserialize() {
            let bytes = bincode::serialize(&U128::from_u64(0x2)).unwrap();
            let err = bincode::deserialize::<Odd<U128>>(&bytes).unwrap_err();
            assert!(matches!(*err, ErrorKind::Custom(msg) if msg == REJECTED))
        }

        #[test]
        fn zero_does_not_deserialize() {
            let bytes = bincode::serialize(&U64::ZERO).unwrap();
            let err = bincode::deserialize::<Odd<U64>>(&bytes).unwrap_err();
            assert!(matches!(*err, ErrorKind::Custom(msg) if msg == REJECTED))
        }
    }
}