crypto_bigint/uint/
shr.rs

1//! [`Uint`] bitwise right shift operations.
2
3use crate::{ConstChoice, ConstCtOption, Limb, ShrVartime, Uint, WrappingShr};
4use core::ops::{Shr, ShrAssign};
5use subtle::CtOption;
6
7impl<const LIMBS: usize> Uint<LIMBS> {
8    /// Computes `self >> shift`.
9    ///
10    /// Panics if `shift >= Self::BITS`.
11    pub const fn shr(&self, shift: u32) -> Self {
12        self.overflowing_shr(shift)
13            .expect("`shift` within the bit size of the integer")
14    }
15
16    /// Computes `self >> shift` in variable time.
17    ///
18    /// Panics if `shift >= Self::BITS`.
19    pub const fn shr_vartime(&self, shift: u32) -> Self {
20        self.overflowing_shr_vartime(shift)
21            .expect("`shift` within the bit size of the integer")
22    }
23
24    /// Computes `self >> shift`.
25    ///
26    /// Returns `None` if `shift >= Self::BITS`.
27    pub const fn overflowing_shr(&self, shift: u32) -> ConstCtOption<Self> {
28        // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift`
29        // (which lies in range `0 <= shift < BITS`).
30        let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros();
31        let overflow = ConstChoice::from_u32_lt(shift, Self::BITS).not();
32        let shift = shift % Self::BITS;
33        let mut result = *self;
34        let mut i = 0;
35        while i < shift_bits {
36            let bit = ConstChoice::from_u32_lsb((shift >> i) & 1);
37            result = Uint::select(
38                &result,
39                &result
40                    .overflowing_shr_vartime(1 << i)
41                    .expect("shift within range"),
42                bit,
43            );
44            i += 1;
45        }
46
47        ConstCtOption::new(Uint::select(&result, &Self::ZERO, overflow), overflow.not())
48    }
49
50    /// Computes `self >> shift`.
51    ///
52    /// Returns `None` if `shift >= Self::BITS`.
53    ///
54    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
55    ///
56    /// When used with a fixed `shift`, this function is constant-time with respect
57    /// to `self`.
58    #[inline(always)]
59    pub const fn overflowing_shr_vartime(&self, shift: u32) -> ConstCtOption<Self> {
60        let mut limbs = [Limb::ZERO; LIMBS];
61
62        if shift >= Self::BITS {
63            return ConstCtOption::none(Self::ZERO);
64        }
65
66        let shift_num = (shift / Limb::BITS) as usize;
67        let rem = shift % Limb::BITS;
68
69        let mut i = 0;
70        while i < LIMBS - shift_num {
71            limbs[i] = self.limbs[i + shift_num];
72            i += 1;
73        }
74
75        if rem == 0 {
76            return ConstCtOption::some(Self { limbs });
77        }
78
79        let mut carry = Limb::ZERO;
80
81        while i > 0 {
82            i -= 1;
83            let shifted = limbs[i].shr(rem);
84            let new_carry = limbs[i].shl(Limb::BITS - rem);
85            limbs[i] = shifted.bitor(carry);
86            carry = new_carry;
87        }
88
89        ConstCtOption::some(Self { limbs })
90    }
91
92    /// Computes a right shift on a wide input as `(lo, hi)`.
93    ///
94    /// Returns `None` if `shift >= Self::BITS`.
95    ///
96    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
97    ///
98    /// When used with a fixed `shift`, this function is constant-time with respect
99    /// to `self`.
100    #[inline(always)]
101    pub const fn overflowing_shr_vartime_wide(
102        lower_upper: (Self, Self),
103        shift: u32,
104    ) -> ConstCtOption<(Self, Self)> {
105        let (lower, upper) = lower_upper;
106        if shift >= 2 * Self::BITS {
107            ConstCtOption::none((Self::ZERO, Self::ZERO))
108        } else if shift >= Self::BITS {
109            let lower = upper
110                .overflowing_shr_vartime(shift - Self::BITS)
111                .expect("shift within range");
112            ConstCtOption::some((lower, Self::ZERO))
113        } else {
114            let new_upper = upper
115                .overflowing_shr_vartime(shift)
116                .expect("shift within range");
117            let lower_hi = upper
118                .overflowing_shl_vartime(Self::BITS - shift)
119                .expect("shift within range");
120            let lower_lo = lower
121                .overflowing_shr_vartime(shift)
122                .expect("shift within range");
123            ConstCtOption::some((lower_lo.bitor(&lower_hi), new_upper))
124        }
125    }
126
127    /// Computes `self >> shift` in a panic-free manner, returning zero if the shift exceeds the
128    /// precision.
129    pub const fn wrapping_shr(&self, shift: u32) -> Self {
130        self.overflowing_shr(shift).unwrap_or(Self::ZERO)
131    }
132
133    /// Computes `self >> shift` in variable-time in a panic-free manner, returning zero if the
134    /// shift exceeds the precision.
135    pub const fn wrapping_shr_vartime(&self, shift: u32) -> Self {
136        self.overflowing_shr_vartime(shift).unwrap_or(Self::ZERO)
137    }
138
139    /// Computes `self >> 1` in constant-time.
140    pub(crate) const fn shr1(&self) -> Self {
141        self.shr1_with_carry().0
142    }
143
144    /// Computes `self >> 1` in constant-time, returning [`ConstChoice::TRUE`]
145    /// if the least significant bit was set, and [`ConstChoice::FALSE`] otherwise.
146    #[inline(always)]
147    pub(crate) const fn shr1_with_carry(&self) -> (Self, ConstChoice) {
148        let mut ret = Self::ZERO;
149        let mut i = LIMBS;
150        let mut carry = Limb::ZERO;
151        while i > 0 {
152            i -= 1;
153            let (shifted, new_carry) = self.limbs[i].shr1();
154            ret.limbs[i] = shifted.bitor(carry);
155            carry = new_carry;
156        }
157
158        (ret, ConstChoice::from_word_lsb(carry.0 >> Limb::HI_BIT))
159    }
160}
161
162macro_rules! impl_shr {
163    ($($shift:ty),+) => {
164        $(
165            impl<const LIMBS: usize> Shr<$shift> for Uint<LIMBS> {
166                type Output = Uint<LIMBS>;
167
168                #[inline]
169                fn shr(self, shift: $shift) -> Uint<LIMBS> {
170                    <&Self>::shr(&self, shift)
171                }
172            }
173
174            impl<const LIMBS: usize> Shr<$shift> for &Uint<LIMBS> {
175                type Output = Uint<LIMBS>;
176
177                #[inline]
178                fn shr(self, shift: $shift) -> Uint<LIMBS> {
179                    Uint::<LIMBS>::shr(self, u32::try_from(shift).expect("invalid shift"))
180                }
181            }
182
183            impl<const LIMBS: usize> ShrAssign<$shift> for Uint<LIMBS> {
184                fn shr_assign(&mut self, shift: $shift) {
185                    *self = self.shr(shift)
186                }
187            }
188        )+
189    };
190}
191
192impl_shr!(i32, u32, usize);
193
194impl<const LIMBS: usize> WrappingShr for Uint<LIMBS> {
195    fn wrapping_shr(&self, shift: u32) -> Uint<LIMBS> {
196        self.wrapping_shr(shift)
197    }
198}
199
200impl<const LIMBS: usize> ShrVartime for Uint<LIMBS> {
201    fn overflowing_shr_vartime(&self, shift: u32) -> CtOption<Self> {
202        self.overflowing_shr(shift).into()
203    }
204    fn wrapping_shr_vartime(&self, shift: u32) -> Self {
205        self.wrapping_shr(shift)
206    }
207}
208
#[cfg(test)]
mod tests {
    use crate::{Uint, U128, U256};

    /// 256-bit test value with bits set across every limb.
    const N: U256 =
        U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141");

    /// `N` shifted right by one bit.
    const N_SHR_1: U256 =
        U256::from_be_hex("7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF5D576E7357A4501DDFE92F46681B20A0");

    #[test]
    fn shr1() {
        let expected = N_SHR_1;
        assert_eq!(N.shr1(), expected);
        assert_eq!(N >> 1, expected);
    }

    #[test]
    fn shr256_const() {
        let constant_time = N.overflowing_shr(256);
        let variable_time = N.overflowing_shr_vartime(256);
        assert!(constant_time.is_none().is_true_vartime());
        assert!(variable_time.is_none().is_true_vartime());
    }

    #[test]
    #[should_panic(expected = "`shift` within the bit size of the integer")]
    fn shr256() {
        let _ = N >> 256;
    }

    #[test]
    fn shr_wide_1_1_128() {
        let shifted = Uint::overflowing_shr_vartime_wide((U128::ONE, U128::ONE), 128).unwrap();
        assert_eq!(shifted, (U128::ONE, U128::ZERO));
    }

    #[test]
    fn shr_wide_0_max_1() {
        let shifted = Uint::overflowing_shr_vartime_wide((U128::ZERO, U128::MAX), 1).unwrap();
        let expected = (U128::ONE << 127, U128::MAX >> 1);
        assert_eq!(shifted, expected);
    }

    #[test]
    fn shr_wide_max_max_256() {
        let shifted = Uint::overflowing_shr_vartime_wide((U128::MAX, U128::MAX), 256);
        assert!(shifted.is_none().is_true_vartime());
    }
}