crypto_bigint/
non_zero.rs

1//! Wrapper type for non-zero integers.
2
3use crate::{Bounded, ConstChoice, Constants, Encoding, Int, Limb, Uint, Zero};
4use core::{
5    fmt,
6    num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8},
7    ops::Deref,
8};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "hybrid-array")]
12use crate::{ArrayEncoding, ByteArray};
13
14#[cfg(feature = "rand_core")]
15use {crate::Random, rand_core::RngCore};
16
17#[cfg(feature = "serde")]
18use serdect::serde::{
19    de::{Error, Unexpected},
20    Deserialize, Deserializer, Serialize, Serializer,
21};
22
/// Wrapper type for non-zero integers.
///
/// Values produced by the checked constructors (such as [`NonZero::new`],
/// which yields `None` for zero) are guaranteed to be non-zero. The inner
/// field is `pub(crate)`, so code inside this crate may also build the wrapper
/// directly and is then responsible for upholding that invariant.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[repr(transparent)]
pub struct NonZero<T>(pub(crate) T);
27
28impl<T> NonZero<T> {
29    /// Create a new non-zero integer.
30    pub fn new(n: T) -> CtOption<Self>
31    where
32        T: Zero,
33    {
34        let is_zero = n.is_zero();
35        CtOption::new(Self(n), !is_zero)
36    }
37
38    /// Provides access to the contents of `NonZero` in a `const` context.
39    pub const fn as_ref(&self) -> &T {
40        &self.0
41    }
42
43    /// Returns the inner value.
44    pub fn get(self) -> T {
45        self.0
46    }
47}
48
49impl<T> NonZero<T>
50where
51    T: Bounded,
52{
53    /// Total size of the represented integer in bits.
54    pub const BITS: u32 = T::BITS;
55
56    /// Total size of the represented integer in bytes.
57    pub const BYTES: usize = T::BYTES;
58}
59
60impl<T> NonZero<T>
61where
62    T: Constants,
63{
64    /// The value `1`.
65    pub const ONE: Self = Self(T::ONE);
66
67    /// Maximum value this integer can express.
68    pub const MAX: Self = Self(T::MAX);
69}
70
71impl<T> NonZero<T>
72where
73    T: Encoding + Zero,
74{
75    /// Decode from big endian bytes.
76    pub fn from_be_bytes(bytes: T::Repr) -> CtOption<Self> {
77        Self::new(T::from_be_bytes(bytes))
78    }
79
80    /// Decode from little endian bytes.
81    pub fn from_le_bytes(bytes: T::Repr) -> CtOption<Self> {
82        Self::new(T::from_le_bytes(bytes))
83    }
84}
85
86impl NonZero<Limb> {
87    /// Creates a new non-zero limb in a const context.
88    /// Panics if the value is zero.
89    ///
90    /// In future versions of Rust it should be possible to replace this with
91    /// `NonZero::new(…).unwrap()`
92    // TODO: Remove when `Self::new` and `CtOption::unwrap` support `const fn`
93    pub const fn new_unwrap(n: Limb) -> Self {
94        if n.is_nonzero().is_true_vartime() {
95            Self(n)
96        } else {
97            panic!("Invalid value: zero")
98        }
99    }
100
101    /// Create a [`NonZero<Limb>`] from a [`NonZeroU8`] (const-friendly)
102    // TODO(tarcieri): replace with `const impl From<NonZeroU8>` when stable
103    pub const fn from_u8(n: NonZeroU8) -> Self {
104        Self(Limb::from_u8(n.get()))
105    }
106
107    /// Create a [`NonZero<Limb>`] from a [`NonZeroU16`] (const-friendly)
108    // TODO(tarcieri): replace with `const impl From<NonZeroU16>` when stable
109    pub const fn from_u16(n: NonZeroU16) -> Self {
110        Self(Limb::from_u16(n.get()))
111    }
112
113    /// Create a [`NonZero<Limb>`] from a [`NonZeroU32`] (const-friendly)
114    // TODO(tarcieri): replace with `const impl From<NonZeroU32>` when stable
115    pub const fn from_u32(n: NonZeroU32) -> Self {
116        Self(Limb::from_u32(n.get()))
117    }
118
119    /// Create a [`NonZero<Limb>`] from a [`NonZeroU64`] (const-friendly)
120    // TODO(tarcieri): replace with `const impl From<NonZeroU64>` when stable
121    #[cfg(target_pointer_width = "64")]
122    pub const fn from_u64(n: NonZeroU64) -> Self {
123        Self(Limb::from_u64(n.get()))
124    }
125}
126
127impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
128    /// Creates a new non-zero integer in a const context.
129    /// Panics if the value is zero.
130    ///
131    /// In future versions of Rust it should be possible to replace this with
132    /// `NonZero::new(…).unwrap()`
133    // TODO: Remove when `Self::new` and `CtOption::unwrap` support `const fn`
134    pub const fn new_unwrap(n: Uint<LIMBS>) -> Self {
135        if n.is_nonzero().is_true_vartime() {
136            Self(n)
137        } else {
138            panic!("Invalid value: zero")
139        }
140    }
141
142    /// Create a [`NonZero<Uint>`] from a [`NonZeroU8`] (const-friendly)
143    // TODO(tarcieri): replace with `const impl From<NonZeroU8>` when stable
144    pub const fn from_u8(n: NonZeroU8) -> Self {
145        Self(Uint::from_u8(n.get()))
146    }
147
148    /// Create a [`NonZero<Uint>`] from a [`NonZeroU16`] (const-friendly)
149    // TODO(tarcieri): replace with `const impl From<NonZeroU16>` when stable
150    pub const fn from_u16(n: NonZeroU16) -> Self {
151        Self(Uint::from_u16(n.get()))
152    }
153
154    /// Create a [`NonZero<Uint>`] from a [`NonZeroU32`] (const-friendly)
155    // TODO(tarcieri): replace with `const impl From<NonZeroU32>` when stable
156    pub const fn from_u32(n: NonZeroU32) -> Self {
157        Self(Uint::from_u32(n.get()))
158    }
159
160    /// Create a [`NonZero<Uint>`] from a [`NonZeroU64`] (const-friendly)
161    // TODO(tarcieri): replace with `const impl From<NonZeroU64>` when stable
162    pub const fn from_u64(n: NonZeroU64) -> Self {
163        Self(Uint::from_u64(n.get()))
164    }
165
166    /// Create a [`NonZero<Uint>`] from a [`NonZeroU128`] (const-friendly)
167    // TODO(tarcieri): replace with `const impl From<NonZeroU128>` when stable
168    pub const fn from_u128(n: NonZeroU128) -> Self {
169        Self(Uint::from_u128(n.get()))
170    }
171}
172
173impl<const LIMBS: usize> NonZero<Int<LIMBS>> {
174    /// Convert a [`NonZero<Int>`] to its sign and [`NonZero<Uint>`] magnitude.
175    pub const fn abs_sign(&self) -> (NonZero<Uint<LIMBS>>, ConstChoice) {
176        let (abs, sign) = self.0.abs_sign();
177        // Note: a NonZero<Int> always has a non-zero magnitude, so it is safe to unwrap.
178        (NonZero::<Uint<LIMBS>>::new_unwrap(abs), sign)
179    }
180}
181
#[cfg(feature = "hybrid-array")]
impl<T> NonZero<T>
where
    T: ArrayEncoding + Zero,
{
    /// Decode a non-zero integer from big endian bytes.
    ///
    /// Returns `None` (as a [`CtOption`]) if the decoded integer is zero.
    pub fn from_be_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        Self::new(T::from_be_byte_array(bytes))
    }

    /// Decode a non-zero integer from little endian bytes.
    ///
    /// Returns `None` (as a [`CtOption`]) if the decoded integer is zero.
    pub fn from_le_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        // Must use the little-endian decoder. Delegating to
        // `from_be_byte_array` would silently byte-reverse the input.
        Self::new(T::from_le_byte_array(bytes))
    }
}
197
198impl<T> AsRef<T> for NonZero<T> {
199    fn as_ref(&self) -> &T {
200        &self.0
201    }
202}
203
204impl<T> ConditionallySelectable for NonZero<T>
205where
206    T: ConditionallySelectable,
207{
208    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
209        Self(T::conditional_select(&a.0, &b.0, choice))
210    }
211}
212
213impl<T> ConstantTimeEq for NonZero<T>
214where
215    T: ConstantTimeEq,
216{
217    fn ct_eq(&self, other: &Self) -> Choice {
218        self.0.ct_eq(&other.0)
219    }
220}
221
222impl<T> Default for NonZero<T>
223where
224    T: Constants,
225{
226    fn default() -> Self {
227        Self(T::ONE)
228    }
229}
230
231impl<T> Deref for NonZero<T> {
232    type Target = T;
233
234    fn deref(&self) -> &T {
235        &self.0
236    }
237}
238
#[cfg(feature = "rand_core")]
impl<T> Random for NonZero<T>
where
    T: Random + Zero,
{
    /// Draws `T::random` repeatedly until a non-zero value comes out
    /// (rejection sampling).
    ///
    /// The number of draws depends on the RNG output, so this runs in
    /// variable time. If the generator `rng` is cryptographically secure (for
    /// example, it implements `CryptoRng`), then this is guaranteed not to
    /// leak anything about the output value.
    fn random(mut rng: &mut (impl RngCore + ?Sized)) -> Self {
        loop {
            let candidate: Option<Self> = Self::new(T::random(&mut rng)).into();
            if let Some(nonzero) = candidate {
                return nonzero;
            }
        }
    }
}
257
258impl From<NonZeroU8> for NonZero<Limb> {
259    fn from(integer: NonZeroU8) -> Self {
260        Self::from_u8(integer)
261    }
262}
263
264impl From<NonZeroU16> for NonZero<Limb> {
265    fn from(integer: NonZeroU16) -> Self {
266        Self::from_u16(integer)
267    }
268}
269
270impl From<NonZeroU32> for NonZero<Limb> {
271    fn from(integer: NonZeroU32) -> Self {
272        Self::from_u32(integer)
273    }
274}
275
276#[cfg(target_pointer_width = "64")]
277impl From<NonZeroU64> for NonZero<Limb> {
278    fn from(integer: NonZeroU64) -> Self {
279        Self::from_u64(integer)
280    }
281}
282
283impl<const LIMBS: usize> From<NonZeroU8> for NonZero<Uint<LIMBS>> {
284    fn from(integer: NonZeroU8) -> Self {
285        Self::from_u8(integer)
286    }
287}
288
289impl<const LIMBS: usize> From<NonZeroU16> for NonZero<Uint<LIMBS>> {
290    fn from(integer: NonZeroU16) -> Self {
291        Self::from_u16(integer)
292    }
293}
294
295impl<const LIMBS: usize> From<NonZeroU32> for NonZero<Uint<LIMBS>> {
296    fn from(integer: NonZeroU32) -> Self {
297        Self::from_u32(integer)
298    }
299}
300
301impl<const LIMBS: usize> From<NonZeroU64> for NonZero<Uint<LIMBS>> {
302    fn from(integer: NonZeroU64) -> Self {
303        Self::from_u64(integer)
304    }
305}
306
307impl<const LIMBS: usize> From<NonZeroU128> for NonZero<Uint<LIMBS>> {
308    fn from(integer: NonZeroU128) -> Self {
309        Self::from_u128(integer)
310    }
311}
312
313impl<T> fmt::Display for NonZero<T>
314where
315    T: fmt::Display,
316{
317    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318        fmt::Display::fmt(&self.0, f)
319    }
320}
321
322impl<T> fmt::Binary for NonZero<T>
323where
324    T: fmt::Binary,
325{
326    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327        fmt::Binary::fmt(&self.0, f)
328    }
329}
330
331impl<T> fmt::Octal for NonZero<T>
332where
333    T: fmt::Octal,
334{
335    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336        fmt::Octal::fmt(&self.0, f)
337    }
338}
339
340impl<T> fmt::LowerHex for NonZero<T>
341where
342    T: fmt::LowerHex,
343{
344    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
345        fmt::LowerHex::fmt(&self.0, f)
346    }
347}
348
349impl<T> fmt::UpperHex for NonZero<T>
350where
351    T: fmt::UpperHex,
352{
353    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354        fmt::UpperHex::fmt(&self.0, f)
355    }
356}
357
#[cfg(feature = "serde")]
impl<'de, T> Deserialize<'de> for NonZero<T>
where
    T: Deserialize<'de> + Zero,
{
    /// Deserializes a `T` exactly as `T` itself would.
    ///
    /// A decoded zero is rejected with an `invalid_value` error, so untrusted
    /// input cannot produce a `NonZero` holding zero.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let inner = T::deserialize(deserializer)?;

        let is_zero: bool = inner.is_zero().into();
        if is_zero {
            return Err(D::Error::invalid_value(
                Unexpected::Other("zero"),
                &"a non-zero value",
            ));
        }

        Ok(Self(inner))
    }
}
376
#[cfg(feature = "serde")]
impl<T> Serialize for NonZero<T>
where
    T: Serialize + Zero,
{
    /// Serializes only the wrapped integer, so `NonZero<T>` has the same
    /// serialized representation as `T`.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        T::serialize(&self.0, serializer)
    }
}
386
#[cfg(feature = "zeroize")]
impl<T> zeroize::Zeroize for NonZero<T>
where
    T: zeroize::Zeroize + Zero,
{
    /// Zeroizes the wrapped integer in place.
    ///
    /// NOTE(review): per the `Zeroize` contract, the wrapped value is
    /// overwritten with zero. The wrapper then no longer satisfies its
    /// non-zero invariant and should not be used for arithmetic afterwards.
    fn zeroize(&mut self) {
        let Self(inner) = self;
        inner.zeroize();
    }
}
393
#[cfg(test)]
mod tests {
    use crate::{ConstChoice, I128, U128};

    /// A negative `NonZero<Int>` splits into its positive magnitude and a
    /// `TRUE` sign.
    #[test]
    fn int_abs_sign() {
        let negative = I128::from(-55).to_nz().unwrap();
        let expected_magnitude = U128::from(55u32).to_nz().unwrap();

        let (magnitude, sign) = negative.abs_sign();

        assert_eq!(magnitude, expected_magnitude);
        assert_eq!(sign, ConstChoice::TRUE);
    }
}
406
#[cfg(all(test, feature = "serde"))]
#[allow(clippy::unwrap_used)]
mod tests_serde {
    use bincode::ErrorKind;

    use crate::{NonZero, U64};

    /// Message produced when a zero is fed to `NonZero`'s deserializer.
    const ZERO_REJECTED: &str = "invalid value: zero, expected a non-zero value";

    /// A fixed, known-non-zero fixture value.
    fn sample() -> NonZero<U64> {
        Option::<NonZero<U64>>::from(NonZero::new(U64::from_u64(0x0011223344556677))).unwrap()
    }

    /// Round-trips through borrowed deserialization and rejects zero.
    #[test]
    fn serde() {
        let original = sample();
        let bytes = bincode::serialize(&original).unwrap();
        let round_tripped: NonZero<U64> = bincode::deserialize(&bytes).unwrap();
        assert_eq!(original, round_tripped);

        let zero_bytes = bincode::serialize(&U64::ZERO).unwrap();
        let err = bincode::deserialize::<NonZero<U64>>(&zero_bytes).unwrap_err();
        assert!(matches!(*err, ErrorKind::Custom(message) if message == ZERO_REJECTED));
    }

    /// Round-trips through reader-based (owned) deserialization and rejects zero.
    #[test]
    fn serde_owned() {
        let original = sample();
        let bytes = bincode::serialize(&original).unwrap();
        let round_tripped: NonZero<U64> = bincode::deserialize_from(bytes.as_slice()).unwrap();
        assert_eq!(original, round_tripped);

        let zero_bytes = bincode::serialize(&U64::ZERO).unwrap();
        let err =
            bincode::deserialize_from::<_, NonZero<U64>>(zero_bytes.as_slice()).unwrap_err();
        assert!(matches!(*err, ErrorKind::Custom(message) if message == ZERO_REJECTED));
    }
}