crypto_bigint/int/shr.rs

1//! [`Int`] bitwise right shift operations.
2
3use core::ops::{Shr, ShrAssign};
4
5use subtle::CtOption;
6
7use crate::{ConstChoice, ConstCtOption, Int, Limb, ShrVartime, Uint, WrappingShr};
8
9impl<const LIMBS: usize> Int<LIMBS> {
10    /// Computes `self >> shift`.
11    ///
12    /// Note, this is _signed_ shift right, i.e., the value shifted in on the left is equal to
13    /// the most significant bit.
14    ///
15    /// Panics if `shift >= Self::BITS`.
16    pub const fn shr(&self, shift: u32) -> Self {
17        self.overflowing_shr(shift)
18            .expect("`shift` within the bit size of the integer")
19    }
20
21    /// Computes `self >> shift` in variable time.
22    ///
23    /// Note, this is _signed_ shift right, i.e., the value shifted in on the left is equal to
24    /// the most significant bit.
25    ///
26    /// Panics if `shift >= Self::BITS`.
27    pub const fn shr_vartime(&self, shift: u32) -> Self {
28        self.overflowing_shr_vartime(shift)
29            .expect("`shift` within the bit size of the integer")
30    }
31
32    /// Computes `self >> shift`.
33    ///
34    /// Note, this is _signed_ shift right, i.e., the value shifted in on the left is equal to
35    /// the most significant bit.
36    ///
37    /// Returns `None` if `shift >= Self::BITS`.
38    pub const fn overflowing_shr(&self, shift: u32) -> ConstCtOption<Self> {
39        // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift`
40        // (which lies in range `0 <= shift < BITS`).
41        let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros();
42        let overflow = ConstChoice::from_u32_lt(shift, Self::BITS).not();
43        let shift = shift % Self::BITS;
44        let mut result = *self;
45        let mut i = 0;
46        while i < shift_bits {
47            let bit = ConstChoice::from_u32_lsb((shift >> i) & 1);
48            result = Int::select(
49                &result,
50                &result
51                    .overflowing_shr_vartime(1 << i)
52                    .expect("shift within range"),
53                bit,
54            );
55            i += 1;
56        }
57
58        ConstCtOption::new(result, overflow.not())
59    }
60
61    /// Computes `self >> shift`.
62    ///
63    /// NOTE: this is _signed_ shift right, i.e., the value shifted in on the left is equal to
64    /// the most significant bit.
65    ///
66    /// Returns `None` if `shift >= Self::BITS`.
67    ///
68    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
69    ///
70    /// When used with a fixed `shift`, this function is constant-time with respect
71    /// to `self`.
72    #[inline(always)]
73    pub const fn overflowing_shr_vartime(&self, shift: u32) -> ConstCtOption<Self> {
74        let is_negative = self.is_negative();
75
76        if shift >= Self::BITS {
77            return ConstCtOption::none(Self::select(&Self::ZERO, &Self::MINUS_ONE, is_negative));
78        }
79
80        // Select the base limb, based on the sign of this value.
81        let base = Limb::select(Limb::ZERO, Limb::MAX, is_negative);
82        let mut limbs = [base; LIMBS];
83
84        let shift_num = (shift / Limb::BITS) as usize;
85        let rem = shift % Limb::BITS;
86
87        let mut i = 0;
88        while i < LIMBS - shift_num {
89            limbs[i] = self.0.limbs[i + shift_num];
90            i += 1;
91        }
92
93        if rem == 0 {
94            return ConstCtOption::some(Self(Uint::new(limbs)));
95        }
96
97        // construct the carry s.t. the `rem`-most significant bits of `carry` are 1 when self
98        // is negative, i.e., shift in 1s when the msb is 1.
99        let mut carry = Limb::select(Limb::ZERO, Limb::MAX, is_negative);
100        carry = carry.bitxor(carry.shr(rem)); // logical shift right; shifts in zeroes.
101
102        while i > 0 {
103            i -= 1;
104            let shifted = limbs[i].shr(rem);
105            let new_carry = limbs[i].shl(Limb::BITS - rem);
106            limbs[i] = shifted.bitor(carry);
107            carry = new_carry;
108        }
109
110        ConstCtOption::some(Self(Uint::new(limbs)))
111    }
112
113    /// Computes `self >> shift` in a panic-free manner.
114    ///
115    /// If the shift exceeds the precision, returns
116    /// - `0` when `self` is non-negative, and
117    /// - `-1` when `self` is negative.
118    pub const fn wrapping_shr(&self, shift: u32) -> Self {
119        let default = Self::select(&Self::ZERO, &Self::MINUS_ONE, self.is_negative());
120        self.overflowing_shr(shift).unwrap_or(default)
121    }
122
123    /// Computes `self >> shift` in variable-time in a panic-free manner.
124    ///
125    /// If the shift exceeds the precision, returns
126    /// - `0` when `self` is non-negative, and
127    /// - `-1` when `self` is negative.
128    pub const fn wrapping_shr_vartime(&self, shift: u32) -> Self {
129        let default = Self::select(&Self::ZERO, &Self::MINUS_ONE, self.is_negative());
130        self.overflowing_shr_vartime(shift).unwrap_or(default)
131    }
132}
133
134macro_rules! impl_shr {
135    ($($shift:ty),+) => {
136        $(
137            impl<const LIMBS: usize> Shr<$shift> for Int<LIMBS> {
138                type Output = Int<LIMBS>;
139
140                #[inline]
141                fn shr(self, shift: $shift) -> Int<LIMBS> {
142                    <&Self>::shr(&self, shift)
143                }
144            }
145
146            impl<const LIMBS: usize> Shr<$shift> for &Int<LIMBS> {
147                type Output = Int<LIMBS>;
148
149                #[inline]
150                fn shr(self, shift: $shift) -> Int<LIMBS> {
151                    Int::<LIMBS>::shr(self, u32::try_from(shift).expect("invalid shift"))
152                }
153            }
154
155            impl<const LIMBS: usize> ShrAssign<$shift> for Int<LIMBS> {
156                fn shr_assign(&mut self, shift: $shift) {
157                    *self = self.shr(shift)
158                }
159            }
160        )+
161    };
162}
163
164impl_shr!(i32, u32, usize);
165
166impl<const LIMBS: usize> WrappingShr for Int<LIMBS> {
167    fn wrapping_shr(&self, shift: u32) -> Int<LIMBS> {
168        self.wrapping_shr(shift)
169    }
170}
171
172impl<const LIMBS: usize> ShrVartime for Int<LIMBS> {
173    fn overflowing_shr_vartime(&self, shift: u32) -> CtOption<Self> {
174        self.overflowing_shr(shift).into()
175    }
176    fn wrapping_shr_vartime(&self, shift: u32) -> Self {
177        self.wrapping_shr(shift)
178    }
179}
180
#[cfg(test)]
mod tests {
    use core::ops::Div;

    use crate::I256;

    const N: I256 =
        I256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141");

    const N_2: I256 =
        I256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF5D576E7357A4501DDFE92F46681B20A0");

    #[test]
    fn shr0() {
        assert_eq!(I256::MAX >> 0, I256::MAX);
        assert_eq!(I256::MIN >> 0, I256::MIN);
    }

    #[test]
    fn shr1() {
        assert_eq!(N >> 1, N_2);
    }

    #[test]
    fn shr5() {
        assert_eq!(
            I256::MAX >> 5,
            I256::MAX.div(I256::from(32).to_nz().unwrap()).unwrap()
        );
        assert_eq!(
            I256::MIN >> 5,
            I256::MIN.div(I256::from(32).to_nz().unwrap()).unwrap()
        );
    }

    #[test]
    fn shr7_vartime() {
        assert_eq!(
            I256::MAX.shr_vartime(7),
            I256::MAX.div(I256::from(128).to_nz().unwrap()).unwrap()
        );
        assert_eq!(
            I256::MIN.shr_vartime(7),
            I256::MIN.div(I256::from(128).to_nz().unwrap()).unwrap()
        );
    }

    /// Shifts that cross whole limbs and leave a non-zero in-limb remainder
    /// (65 = 2*32 + 1 = 64 + 1, so this holds for both 32- and 64-bit limbs).
    #[test]
    fn shr_across_limb_boundaries() {
        assert_eq!(I256::MAX >> 255, I256::ZERO);
        assert_eq!(I256::MIN >> 255, I256::MINUS_ONE);
        assert_eq!(I256::MINUS_ONE >> 200, I256::MINUS_ONE);

        // 2^65 == 32^13; dividing MIN by powers of two is exact, and for MAX truncation
        // equals flooring, so both match an arithmetic right shift by 65.
        let mut expected_min = I256::MIN;
        let mut expected_max = I256::MAX;
        for _ in 0..13 {
            expected_min = expected_min.div(I256::from(32).to_nz().unwrap()).unwrap();
            expected_max = expected_max.div(I256::from(32).to_nz().unwrap()).unwrap();
        }
        assert_eq!(I256::MIN >> 65, expected_min);
        assert_eq!(I256::MAX >> 65, expected_max);
        assert_eq!(I256::MIN.shr_vartime(65), expected_min);
        assert_eq!(I256::MAX.shr_vartime(65), expected_max);
    }

    #[test]
    fn shr_matches_shr_vartime() {
        for shift in 0..I256::BITS {
            assert_eq!(N >> shift, N.shr_vartime(shift));
            assert_eq!(I256::MIN >> shift, I256::MIN.shr_vartime(shift));
            assert_eq!(I256::MAX >> shift, I256::MAX.shr_vartime(shift));
        }
        for shift in 0..=I256::BITS + 1 {
            assert_eq!(N.wrapping_shr(shift), N.wrapping_shr_vartime(shift));
        }
    }

    #[test]
    fn shr256_const() {
        assert!(N.overflowing_shr(256).is_none().is_true_vartime());
        assert!(N.overflowing_shr_vartime(256).is_none().is_true_vartime());
    }

    #[test]
    #[should_panic(expected = "`shift` within the bit size of the integer")]
    fn shr256() {
        let _ = N >> 256;
    }

    #[test]
    fn wrapping_shr() {
        assert_eq!(I256::MAX.wrapping_shr(257), I256::ZERO);
        assert_eq!(I256::MIN.wrapping_shr(257), I256::MINUS_ONE);
    }

    #[test]
    fn wrapping_shr_vartime() {
        assert_eq!(I256::MAX.wrapping_shr_vartime(257), I256::ZERO);
        assert_eq!(I256::MIN.wrapping_shr_vartime(257), I256::MINUS_ONE);
        assert_eq!(N.wrapping_shr_vartime(1), N_2);
    }

    #[test]
    fn trait_impls_match_inherent() {
        assert_eq!(
            <I256 as crate::WrappingShr>::wrapping_shr(&I256::MIN, 257),
            I256::MINUS_ONE
        );
        assert_eq!(<I256 as crate::WrappingShr>::wrapping_shr(&N, 1), N_2);
        assert_eq!(
            <I256 as crate::ShrVartime>::wrapping_shr_vartime(&I256::MAX, 257),
            I256::ZERO
        );
        assert_eq!(
            <I256 as crate::ShrVartime>::overflowing_shr_vartime(&N, 1).unwrap(),
            N_2
        );
        assert!(bool::from(
            <I256 as crate::ShrVartime>::overflowing_shr_vartime(&N, 256).is_none()
        ));
    }
}