1use crate::{Bounded, ConstChoice, Constants, Encoding, Int, Limb, Uint, Zero};
4use core::{
5 fmt,
6 num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8},
7 ops::Deref,
8};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "hybrid-array")]
12use crate::{ArrayEncoding, ByteArray};
13
14#[cfg(feature = "rand_core")]
15use {crate::Random, rand_core::RngCore};
16
17#[cfg(feature = "serde")]
18use serdect::serde::{
19 de::{Error, Unexpected},
20 Deserialize, Deserializer, Serialize, Serializer,
21};
22
/// Wrapper type around a value `T` that is intended to never be zero.
///
/// The inner field is `pub(crate)`, so code outside this crate has to go
/// through a constructor such as [`NonZero::new`], which checks the value with
/// `Zero::is_zero` before wrapping it.
///
/// `#[repr(transparent)]` gives the wrapper the same layout as `T`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct NonZero<T>(pub(crate) T);
27
28impl<T> NonZero<T> {
29 pub fn new(n: T) -> CtOption<Self>
31 where
32 T: Zero,
33 {
34 let is_zero = n.is_zero();
35 CtOption::new(Self(n), !is_zero)
36 }
37
38 pub const fn as_ref(&self) -> &T {
40 &self.0
41 }
42
43 pub fn get(self) -> T {
45 self.0
46 }
47}
48
49impl<T> NonZero<T>
50where
51 T: Bounded,
52{
53 pub const BITS: u32 = T::BITS;
55
56 pub const BYTES: usize = T::BYTES;
58}
59
60impl<T> NonZero<T>
61where
62 T: Constants,
63{
64 pub const ONE: Self = Self(T::ONE);
66
67 pub const MAX: Self = Self(T::MAX);
69}
70
71impl<T> NonZero<T>
72where
73 T: Encoding + Zero,
74{
75 pub fn from_be_bytes(bytes: T::Repr) -> CtOption<Self> {
77 Self::new(T::from_be_bytes(bytes))
78 }
79
80 pub fn from_le_bytes(bytes: T::Repr) -> CtOption<Self> {
82 Self::new(T::from_le_bytes(bytes))
83 }
84}
85
86impl NonZero<Limb> {
87 pub const fn new_unwrap(n: Limb) -> Self {
94 if n.is_nonzero().is_true_vartime() {
95 Self(n)
96 } else {
97 panic!("Invalid value: zero")
98 }
99 }
100
101 pub const fn from_u8(n: NonZeroU8) -> Self {
104 Self(Limb::from_u8(n.get()))
105 }
106
107 pub const fn from_u16(n: NonZeroU16) -> Self {
110 Self(Limb::from_u16(n.get()))
111 }
112
113 pub const fn from_u32(n: NonZeroU32) -> Self {
116 Self(Limb::from_u32(n.get()))
117 }
118
119 #[cfg(target_pointer_width = "64")]
122 pub const fn from_u64(n: NonZeroU64) -> Self {
123 Self(Limb::from_u64(n.get()))
124 }
125}
126
127impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
128 pub const fn new_unwrap(n: Uint<LIMBS>) -> Self {
135 if n.is_nonzero().is_true_vartime() {
136 Self(n)
137 } else {
138 panic!("Invalid value: zero")
139 }
140 }
141
142 pub const fn from_u8(n: NonZeroU8) -> Self {
145 Self(Uint::from_u8(n.get()))
146 }
147
148 pub const fn from_u16(n: NonZeroU16) -> Self {
151 Self(Uint::from_u16(n.get()))
152 }
153
154 pub const fn from_u32(n: NonZeroU32) -> Self {
157 Self(Uint::from_u32(n.get()))
158 }
159
160 pub const fn from_u64(n: NonZeroU64) -> Self {
163 Self(Uint::from_u64(n.get()))
164 }
165
166 pub const fn from_u128(n: NonZeroU128) -> Self {
169 Self(Uint::from_u128(n.get()))
170 }
171}
172
173impl<const LIMBS: usize> NonZero<Int<LIMBS>> {
174 pub const fn abs_sign(&self) -> (NonZero<Uint<LIMBS>>, ConstChoice) {
176 let (abs, sign) = self.0.abs_sign();
177 (NonZero::<Uint<LIMBS>>::new_unwrap(abs), sign)
179 }
180}
181
#[cfg(feature = "hybrid-array")]
impl<T> NonZero<T>
where
    T: ArrayEncoding + Zero,
{
    /// Decode a non-zero integer from a big-endian byte array.
    ///
    /// The returned [`CtOption`] is "none" when the decoded value is zero.
    pub fn from_be_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        Self::new(T::from_be_byte_array(bytes))
    }

    /// Decode a non-zero integer from a little-endian byte array.
    ///
    /// The returned [`CtOption`] is "none" when the decoded value is zero.
    pub fn from_le_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        // Must use the little-endian decoder: this previously delegated to
        // `from_be_byte_array`, which read the input in the wrong byte order.
        Self::new(T::from_le_byte_array(bytes))
    }
}
197
198impl<T> AsRef<T> for NonZero<T> {
199 fn as_ref(&self) -> &T {
200 &self.0
201 }
202}
203
204impl<T> ConditionallySelectable for NonZero<T>
205where
206 T: ConditionallySelectable,
207{
208 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
209 Self(T::conditional_select(&a.0, &b.0, choice))
210 }
211}
212
213impl<T> ConstantTimeEq for NonZero<T>
214where
215 T: ConstantTimeEq,
216{
217 fn ct_eq(&self, other: &Self) -> Choice {
218 self.0.ct_eq(&other.0)
219 }
220}
221
222impl<T> Default for NonZero<T>
223where
224 T: Constants,
225{
226 fn default() -> Self {
227 Self(T::ONE)
228 }
229}
230
231impl<T> Deref for NonZero<T> {
232 type Target = T;
233
234 fn deref(&self) -> &T {
235 &self.0
236 }
237}
238
#[cfg(feature = "rand_core")]
impl<T> Random for NonZero<T>
where
    T: Random + Zero,
{
    /// Draw values of `T` from `rng` until one passes [`NonZero::new`],
    /// then return it.
    ///
    /// Zero draws are discarded and retried, so the number of RNG calls is
    /// not fixed.
    fn random(mut rng: &mut (impl RngCore + ?Sized)) -> Self {
        loop {
            let candidate: Option<Self> = Self::new(T::random(&mut rng)).into();
            if let Some(non_zero) = candidate {
                return non_zero;
            }
        }
    }
}
257
258impl From<NonZeroU8> for NonZero<Limb> {
259 fn from(integer: NonZeroU8) -> Self {
260 Self::from_u8(integer)
261 }
262}
263
264impl From<NonZeroU16> for NonZero<Limb> {
265 fn from(integer: NonZeroU16) -> Self {
266 Self::from_u16(integer)
267 }
268}
269
270impl From<NonZeroU32> for NonZero<Limb> {
271 fn from(integer: NonZeroU32) -> Self {
272 Self::from_u32(integer)
273 }
274}
275
276#[cfg(target_pointer_width = "64")]
277impl From<NonZeroU64> for NonZero<Limb> {
278 fn from(integer: NonZeroU64) -> Self {
279 Self::from_u64(integer)
280 }
281}
282
283impl<const LIMBS: usize> From<NonZeroU8> for NonZero<Uint<LIMBS>> {
284 fn from(integer: NonZeroU8) -> Self {
285 Self::from_u8(integer)
286 }
287}
288
289impl<const LIMBS: usize> From<NonZeroU16> for NonZero<Uint<LIMBS>> {
290 fn from(integer: NonZeroU16) -> Self {
291 Self::from_u16(integer)
292 }
293}
294
295impl<const LIMBS: usize> From<NonZeroU32> for NonZero<Uint<LIMBS>> {
296 fn from(integer: NonZeroU32) -> Self {
297 Self::from_u32(integer)
298 }
299}
300
301impl<const LIMBS: usize> From<NonZeroU64> for NonZero<Uint<LIMBS>> {
302 fn from(integer: NonZeroU64) -> Self {
303 Self::from_u64(integer)
304 }
305}
306
307impl<const LIMBS: usize> From<NonZeroU128> for NonZero<Uint<LIMBS>> {
308 fn from(integer: NonZeroU128) -> Self {
309 Self::from_u128(integer)
310 }
311}
312
313impl<T> fmt::Display for NonZero<T>
314where
315 T: fmt::Display,
316{
317 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318 fmt::Display::fmt(&self.0, f)
319 }
320}
321
322impl<T> fmt::Binary for NonZero<T>
323where
324 T: fmt::Binary,
325{
326 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327 fmt::Binary::fmt(&self.0, f)
328 }
329}
330
331impl<T> fmt::Octal for NonZero<T>
332where
333 T: fmt::Octal,
334{
335 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336 fmt::Octal::fmt(&self.0, f)
337 }
338}
339
340impl<T> fmt::LowerHex for NonZero<T>
341where
342 T: fmt::LowerHex,
343{
344 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
345 fmt::LowerHex::fmt(&self.0, f)
346 }
347}
348
349impl<T> fmt::UpperHex for NonZero<T>
350where
351 T: fmt::UpperHex,
352{
353 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354 fmt::UpperHex::fmt(&self.0, f)
355 }
356}
357
#[cfg(feature = "serde")]
impl<'de, T> Deserialize<'de> for NonZero<T>
where
    T: Deserialize<'de> + Zero,
{
    /// Deserialize the inner `T`, then reject it if it is zero.
    ///
    /// # Errors
    ///
    /// Propagates any error from `T::deserialize`, and reports an
    /// `invalid_value` error when the decoded value is zero.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let inner: T = T::deserialize(deserializer)?;

        // Accept early on the non-zero path; only zero falls through to the error.
        if !bool::from(inner.is_zero()) {
            return Ok(NonZero(inner));
        }

        Err(D::Error::invalid_value(
            Unexpected::Other("zero"),
            &"a non-zero value",
        ))
    }
}
376
#[cfg(feature = "serde")]
impl<T> Serialize for NonZero<T>
where
    T: Serialize + Zero,
{
    /// Serialize as the wrapped value; the wrapper itself adds nothing to the output.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        <T as Serialize>::serialize(&self.0, serializer)
    }
}
386
#[cfg(feature = "zeroize")]
impl<T> zeroize::Zeroize for NonZero<T>
where
    T: zeroize::Zeroize + Zero,
{
    /// Zeroize the wrapped value in place by delegating to `T`'s implementation.
    // NOTE(review): presumably the inner value is zero afterwards, so the
    // wrapper would no longer hold a non-zero value; confirm that callers do
    // not use it after zeroizing.
    fn zeroize(&mut self) {
        zeroize::Zeroize::zeroize(&mut self.0);
    }
}
393
#[cfg(test)]
mod tests {
    use crate::{ConstChoice, I128, U128};

    /// `abs_sign` on a negative non-zero `Int` yields its magnitude and a set sign flag.
    #[test]
    fn int_abs_sign() {
        let negative = I128::from(-55).to_nz().unwrap();
        let expected_magnitude = U128::from(55u32).to_nz().unwrap();

        let (magnitude, sign) = negative.abs_sign();

        assert_eq!(magnitude, expected_magnitude);
        assert_eq!(sign, ConstChoice::TRUE);
    }
}
406
#[cfg(all(test, feature = "serde"))]
#[allow(clippy::unwrap_used)]
mod tests_serde {
    use bincode::ErrorKind;

    use crate::{NonZero, U64};

    /// Non-zero fixture shared by both round-trip tests.
    fn sample() -> NonZero<U64> {
        let candidate: Option<NonZero<U64>> =
            NonZero::new(U64::from_u64(0x0011223344556677)).into();
        candidate.unwrap()
    }

    /// Whether `err` is the custom error reported when a zero payload is decoded.
    fn is_zero_rejection(err: &ErrorKind) -> bool {
        matches!(
            err,
            ErrorKind::Custom(message) if message == "invalid value: zero, expected a non-zero value"
        )
    }

    /// Round-trips through the borrowed-slice API and rejects an encoded zero.
    #[test]
    fn serde() {
        let original = sample();

        let encoded = bincode::serialize(&original).unwrap();
        let decoded: NonZero<U64> = bincode::deserialize(&encoded).unwrap();

        assert_eq!(original, decoded);

        let zero_encoded = bincode::serialize(&U64::ZERO).unwrap();
        let err = bincode::deserialize::<NonZero<U64>>(&zero_encoded).unwrap_err();
        assert!(is_zero_rejection(&err));
    }

    /// Round-trips through the reader-based API and rejects an encoded zero.
    #[test]
    fn serde_owned() {
        let original = sample();

        let encoded = bincode::serialize(&original).unwrap();
        let decoded: NonZero<U64> = bincode::deserialize_from(encoded.as_slice()).unwrap();

        assert_eq!(original, decoded);

        let zero_encoded = bincode::serialize(&U64::ZERO).unwrap();
        let err =
            bincode::deserialize_from::<_, NonZero<U64>>(zero_encoded.as_slice()).unwrap_err();
        assert!(is_zero_rejection(&err));
    }
}