crypto_bigint/uint/boxed/shr.rs

1//! [`BoxedUint`] bitwise right shift operations.
2
3use crate::{BoxedUint, ConstantTimeSelect, Limb, ShrVartime, WrappingShr, Zero};
4use core::ops::{Shr, ShrAssign};
5use subtle::{Choice, ConstantTimeLess, CtOption};
6
7impl BoxedUint {
8    /// Computes `self >> shift`.
9    ///
10    /// Panics if `shift >= Self::BITS`.
11    pub fn shr(&self, shift: u32) -> BoxedUint {
12        let (result, overflow) = self.overflowing_shr(shift);
13        assert!(
14            !bool::from(overflow),
15            "attempt to shift right with overflow"
16        );
17        result
18    }
19
20    /// Computes `self >>= shift`.
21    ///
22    /// Panics if `shift >= Self::BITS`.
23    pub fn shr_assign(&mut self, shift: u32) {
24        let overflow = self.overflowing_shr_assign(shift);
25        assert!(
26            !bool::from(overflow),
27            "attempt to shift right with overflow"
28        );
29    }
30
31    /// Computes `self >> shift`.
32    ///
33    /// Returns a zero and a truthy `Choice` if `shift >= self.bits_precision()`,
34    /// or the result and a falsy `Choice` otherwise.
35    pub fn overflowing_shr(&self, shift: u32) -> (Self, Choice) {
36        let mut result = self.clone();
37        let overflow = result.overflowing_shr_assign(shift);
38        (result, overflow)
39    }
40
41    /// Computes `self >>= shift`.
42    ///
43    /// Returns a truthy `Choice` if `shift >= self.bits_precision()` or a falsy `Choice` otherwise.
44    pub fn overflowing_shr_assign(&mut self, shift: u32) -> Choice {
45        // `floor(log2(bits_precision - 1))` is the number of bits in the representation of `shift`
46        // (which lies in range `0 <= shift < bits_precision`).
47        let shift_bits = u32::BITS - (self.bits_precision() - 1).leading_zeros();
48        let overflow = !shift.ct_lt(&self.bits_precision());
49        let shift = shift % self.bits_precision();
50        let mut temp = self.clone();
51
52        for i in 0..shift_bits {
53            let bit = Choice::from(((shift >> i) & 1) as u8);
54            temp.set_zero();
55            // Will not overflow by construction
56            self.shr_vartime_into(&mut temp, 1 << i)
57                .expect("shift within range");
58            self.ct_assign(&temp, bit);
59        }
60
61        #[cfg(feature = "zeroize")]
62        zeroize::Zeroize::zeroize(&mut temp);
63
64        self.conditional_set_zero(overflow);
65        overflow
66    }
67
68    /// Computes `self >> shift` in a panic-free manner, masking off bits of `shift` which would cause the shift to
69    /// exceed the type's width.
70    pub fn wrapping_shr(&self, shift: u32) -> Self {
71        self.overflowing_shr(shift).0
72    }
73
74    /// Computes `self >> shift` in variable-time in a panic-free manner, masking off bits of `shift` which would cause
75    /// the shift to exceed the type's width.
76    pub fn wrapping_shr_vartime(&self, shift: u32) -> Self {
77        let mut result = Self::zero_with_precision(self.bits_precision());
78        let _ = self.shr_vartime_into(&mut result, shift);
79        result
80    }
81
82    /// Computes `self >> shift`.
83    /// Returns `None` if `shift >= self.bits_precision()`.
84    ///
85    /// WARNING: for performance reasons, `dest` is assumed to be pre-zeroized.
86    ///
87    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
88    ///
89    /// When used with a fixed `shift`, this function is constant-time with respect to `self`.
90    #[inline(always)]
91    fn shr_vartime_into(&self, dest: &mut Self, shift: u32) -> Option<()> {
92        if shift >= self.bits_precision() {
93            return None;
94        }
95
96        let nlimbs = self.nlimbs();
97        let shift_num = (shift / Limb::BITS) as usize;
98        let rem = shift % Limb::BITS;
99
100        for i in 0..nlimbs - shift_num {
101            dest.limbs[i] = self.limbs[i + shift_num];
102        }
103
104        if rem == 0 {
105            return Some(());
106        }
107
108        for i in 0..nlimbs - shift_num - 1 {
109            let shifted = dest.limbs[i].shr(rem);
110            let carry = dest.limbs[i + 1].shl(Limb::BITS - rem);
111            dest.limbs[i] = shifted.bitor(carry);
112        }
113        dest.limbs[nlimbs - shift_num - 1] = dest.limbs[nlimbs - shift_num - 1].shr(rem);
114
115        Some(())
116    }
117
118    /// Computes `self >> shift`.
119    /// Returns `None` if `shift >= self.bits_precision()`.
120    ///
121    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
122    ///
123    /// When used with a fixed `shift`, this function is constant-time with respect to `self`.
124    #[inline(always)]
125    pub fn shr_vartime(&self, shift: u32) -> Option<Self> {
126        let mut result = Self::zero_with_precision(self.bits_precision());
127        let success = self.shr_vartime_into(&mut result, shift);
128        success.map(|_| result)
129    }
130
131    /// Computes `self >> 1` in constant-time, returning a true [`Choice`]
132    /// if the least significant bit was set, and a false [`Choice::FALSE`] otherwise.
133    pub(crate) fn shr1_with_carry(&self) -> (Self, Choice) {
134        let carry = self.limbs[0].0 & 1;
135        (self.shr1(), Choice::from(carry as u8))
136    }
137
138    /// Computes `self >> 1` in constant-time.
139    pub(crate) fn shr1(&self) -> Self {
140        let mut ret = self.clone();
141        ret.shr1_assign();
142        ret
143    }
144
145    /// Computes `self >> 1` in-place in constant-time.
146    pub(crate) fn shr1_assign(&mut self) {
147        self.limbs[0].shr_assign(1);
148
149        for i in 1..self.limbs.len() {
150            // set carry bit
151            self.limbs[i - 1].0 |= (self.limbs[i].0 & 1) << Limb::HI_BIT;
152            self.limbs[i].shr_assign(1);
153        }
154    }
155}
156
157macro_rules! impl_shr {
158    ($($shift:ty),+) => {
159        $(
160            impl Shr<$shift> for BoxedUint {
161                type Output = BoxedUint;
162
163                #[inline]
164                fn shr(self, shift: $shift) -> BoxedUint {
165                    <&Self>::shr(&self, shift)
166                }
167            }
168
169            impl Shr<$shift> for &BoxedUint {
170                type Output = BoxedUint;
171
172                #[inline]
173                fn shr(self, shift: $shift) -> BoxedUint {
174                    BoxedUint::shr(self, u32::try_from(shift).expect("invalid shift"))
175                }
176            }
177
178            impl ShrAssign<$shift> for BoxedUint {
179                fn shr_assign(&mut self, shift: $shift) {
180                    BoxedUint::shr_assign(self, u32::try_from(shift).expect("invalid shift"))
181                }
182            }
183        )+
184    };
185}
186
187impl_shr!(i32, u32, usize);
188
189impl WrappingShr for BoxedUint {
190    fn wrapping_shr(&self, shift: u32) -> BoxedUint {
191        self.wrapping_shr(shift)
192    }
193}
194
195impl ShrVartime for BoxedUint {
196    fn overflowing_shr_vartime(&self, shift: u32) -> CtOption<Self> {
197        let (result, overflow) = self.overflowing_shr(shift);
198        CtOption::new(result, !overflow)
199    }
200    fn wrapping_shr_vartime(&self, shift: u32) -> Self {
201        self.wrapping_shr(shift)
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::BoxedUint;
    use crate::ShrVartime;

    #[test]
    fn shr1_assign() {
        let mut n = BoxedUint::from(0x3c442b21f19185fe433f0a65af902b8fu128);
        let n_shr1 = BoxedUint::from(0x1e221590f8c8c2ff219f8532d7c815c7u128);
        n.shr1_assign();
        assert_eq!(n, n_shr1);
    }

    #[test]
    fn shr1_with_carry() {
        // Odd input: the shifted-out bit is reported as set.
        let n = BoxedUint::from(0x3c442b21f19185fe433f0a65af902b8fu128);
        let (shifted, carry) = n.shr1_with_carry();
        assert_eq!(
            shifted,
            BoxedUint::from(0x1e221590f8c8c2ff219f8532d7c815c7u128)
        );
        assert!(bool::from(carry));

        // Even input: no carry.
        let (shifted, carry) = BoxedUint::from(4u8).shr1_with_carry();
        assert_eq!(shifted, BoxedUint::from(2u8));
        assert!(!bool::from(carry));
    }

    #[test]
    fn shr() {
        let n = BoxedUint::from(0x80000000000000000u128);
        assert_eq!(BoxedUint::zero(), &n >> 68);
        assert_eq!(BoxedUint::one(), &n >> 67);
        assert_eq!(BoxedUint::from(2u8), &n >> 66);
        assert_eq!(BoxedUint::from(4u8), &n >> 65);
    }

    #[test]
    fn shr_vartime() {
        let n = BoxedUint::from(0x80000000000000000u128);
        assert_eq!(BoxedUint::zero(), n.shr_vartime(68).unwrap());
        assert_eq!(BoxedUint::one(), n.shr_vartime(67).unwrap());
        assert_eq!(BoxedUint::from(2u8), n.shr_vartime(66).unwrap());
        assert_eq!(BoxedUint::from(4u8), n.shr_vartime(65).unwrap());
    }

    #[test]
    fn shr_precision_boundary() {
        // `From<u128>` yields 128 bits of precision: 127 is the largest in-range shift.
        let n = BoxedUint::from(1u128 << 127);

        assert_eq!(BoxedUint::one(), &n >> 127);
        assert_eq!(BoxedUint::one(), n.shr_vartime(127).unwrap());
        assert_eq!(BoxedUint::one(), n.wrapping_shr(127));
        assert_eq!(BoxedUint::one(), n.wrapping_shr_vartime(127));

        let (result, overflow) = n.overflowing_shr(127);
        assert_eq!(BoxedUint::one(), result);
        assert!(!bool::from(overflow));

        let mut m = n.clone();
        m >>= 127u32;
        assert_eq!(BoxedUint::one(), m);
    }

    #[test]
    fn shr_overflow_yields_zero() {
        let n = BoxedUint::from(1u128 << 127);

        let (result, overflow) = n.overflowing_shr(128);
        assert_eq!(BoxedUint::zero(), result);
        assert!(bool::from(overflow));

        assert!(n.shr_vartime(128).is_none());
        assert_eq!(BoxedUint::zero(), n.wrapping_shr(128));
        assert_eq!(BoxedUint::zero(), n.wrapping_shr_vartime(128));
        assert_eq!(BoxedUint::zero(), n.wrapping_shr(u32::MAX));
    }

    #[test]
    fn shr_vartime_trait() {
        let n = BoxedUint::from(1u128 << 127);

        let in_range = ShrVartime::overflowing_shr_vartime(&n, 127);
        assert!(bool::from(in_range.is_some()));
        assert_eq!(BoxedUint::one(), in_range.unwrap());
        assert!(bool::from(
            ShrVartime::overflowing_shr_vartime(&n, 128).is_none()
        ));

        assert_eq!(
            BoxedUint::one(),
            ShrVartime::wrapping_shr_vartime(&n, 127)
        );
        assert_eq!(
            BoxedUint::zero(),
            ShrVartime::wrapping_shr_vartime(&n, 128)
        );
    }

    #[test]
    #[should_panic(expected = "attempt to shift right with overflow")]
    fn shr_overflow_panics() {
        let n = BoxedUint::from(1u128 << 127);
        let _ = &n >> 128;
    }
}