crypto_bigint/uint/
shl.rs

1//! [`Uint`] bitwise left shift operations.
2
3use crate::{ConstChoice, ConstCtOption, Limb, ShlVartime, Uint, Word, WrappingShl};
4use core::ops::{Shl, ShlAssign};
5use subtle::CtOption;
6
7impl<const LIMBS: usize> Uint<LIMBS> {
8    /// Computes `self << shift`.
9    ///
10    /// Panics if `shift >= Self::BITS`.
11    pub const fn shl(&self, shift: u32) -> Self {
12        self.overflowing_shl(shift)
13            .expect("`shift` within the bit size of the integer")
14    }
15
16    /// Computes `self << shift` in variable time.
17    ///
18    /// Panics if `shift >= Self::BITS`.
19    pub const fn shl_vartime(&self, shift: u32) -> Self {
20        self.overflowing_shl_vartime(shift)
21            .expect("`shift` within the bit size of the integer")
22    }
23
24    /// Computes `self << shift`.
25    ///
26    /// Returns `None` if `shift >= Self::BITS`.
27    pub const fn overflowing_shl(&self, shift: u32) -> ConstCtOption<Self> {
28        // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift`
29        // (which lies in range `0 <= shift < BITS`).
30        let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros();
31        let overflow = ConstChoice::from_u32_lt(shift, Self::BITS).not();
32        let shift = shift % Self::BITS;
33        let mut result = *self;
34        let mut i = 0;
35        while i < shift_bits {
36            let bit = ConstChoice::from_u32_lsb((shift >> i) & 1);
37            result = Uint::select(
38                &result,
39                &result
40                    .overflowing_shl_vartime(1 << i)
41                    .expect("shift within range"),
42                bit,
43            );
44            i += 1;
45        }
46
47        ConstCtOption::new(Uint::select(&result, &Self::ZERO, overflow), overflow.not())
48    }
49
50    /// Computes `self << shift`.
51    ///
52    /// Returns `None` if `shift >= Self::BITS`.
53    ///
54    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
55    ///
56    /// When used with a fixed `shift`, this function is constant-time with respect
57    /// to `self`.
58    #[inline(always)]
59    pub const fn overflowing_shl_vartime(&self, shift: u32) -> ConstCtOption<Self> {
60        let mut limbs = [Limb::ZERO; LIMBS];
61
62        if shift >= Self::BITS {
63            return ConstCtOption::none(Self::ZERO);
64        }
65
66        let shift_num = (shift / Limb::BITS) as usize;
67        let rem = shift % Limb::BITS;
68
69        let mut i = shift_num;
70        while i < LIMBS {
71            limbs[i] = self.limbs[i - shift_num];
72            i += 1;
73        }
74
75        if rem == 0 {
76            return ConstCtOption::some(Self { limbs });
77        }
78
79        let mut carry = Limb::ZERO;
80
81        let mut i = shift_num;
82        while i < LIMBS {
83            let shifted = limbs[i].shl(rem);
84            let new_carry = limbs[i].shr(Limb::BITS - rem);
85            limbs[i] = shifted.bitor(carry);
86            carry = new_carry;
87            i += 1;
88        }
89
90        ConstCtOption::some(Self { limbs })
91    }
92
93    /// Computes a left shift on a wide input as `(lo, hi)`.
94    ///
95    /// Returns `None` if `shift >= Self::BITS`.
96    ///
97    /// NOTE: this operation is variable time with respect to `shift` *ONLY*.
98    ///
99    /// When used with a fixed `shift`, this function is constant-time with respect
100    /// to `self`.
101    #[inline(always)]
102    pub const fn overflowing_shl_vartime_wide(
103        lower_upper: (Self, Self),
104        shift: u32,
105    ) -> ConstCtOption<(Self, Self)> {
106        let (lower, upper) = lower_upper;
107        if shift >= 2 * Self::BITS {
108            ConstCtOption::none((Self::ZERO, Self::ZERO))
109        } else if shift >= Self::BITS {
110            let upper = lower
111                .overflowing_shl_vartime(shift - Self::BITS)
112                .expect("shift within range");
113            ConstCtOption::some((Self::ZERO, upper))
114        } else {
115            let new_lower = lower
116                .overflowing_shl_vartime(shift)
117                .expect("shift within range");
118            let upper_lo = lower
119                .overflowing_shr_vartime(Self::BITS - shift)
120                .expect("shift within range");
121            let upper_hi = upper
122                .overflowing_shl_vartime(shift)
123                .expect("shift within range");
124            ConstCtOption::some((new_lower, upper_lo.bitor(&upper_hi)))
125        }
126    }
127
128    /// Computes `self << shift` in a panic-free manner, returning zero if the shift exceeds the
129    /// precision.
130    pub const fn wrapping_shl(&self, shift: u32) -> Self {
131        self.overflowing_shl(shift).unwrap_or(Self::ZERO)
132    }
133
134    /// Computes `self << shift` in variable-time in a panic-free manner, returning zero if the
135    /// shift exceeds the precision.
136    pub const fn wrapping_shl_vartime(&self, shift: u32) -> Self {
137        self.overflowing_shl_vartime(shift).unwrap_or(Self::ZERO)
138    }
139
140    /// Computes `self << shift` where `0 <= shift < Limb::BITS`,
141    /// returning the result and the carry.
142    #[inline(always)]
143    pub(crate) const fn shl_limb(&self, shift: u32) -> (Self, Limb) {
144        let mut limbs = [Limb::ZERO; LIMBS];
145
146        let nz = ConstChoice::from_u32_nonzero(shift);
147        let lshift = shift;
148        let rshift = nz.if_true_u32(Limb::BITS - shift);
149        let carry = nz.if_true_word(self.limbs[LIMBS - 1].0.wrapping_shr(Word::BITS - shift));
150
151        limbs[0] = Limb(self.limbs[0].0 << lshift);
152        let mut i = 1;
153        while i < LIMBS {
154            let mut limb = self.limbs[i].0 << lshift;
155            let hi = self.limbs[i - 1].0 >> rshift;
156            limb |= nz.if_true_word(hi);
157            limbs[i] = Limb(limb);
158            i += 1
159        }
160
161        (Uint::<LIMBS>::new(limbs), Limb(carry))
162    }
163
164    /// Computes `self << 1` in constant-time, returning [`ConstChoice::TRUE`]
165    /// if the most significant bit was set, and [`ConstChoice::FALSE`] otherwise.
166    #[inline(always)]
167    pub(crate) const fn overflowing_shl1(&self) -> (Self, Limb) {
168        let mut ret = Self::ZERO;
169        let mut i = 0;
170        let mut carry = Limb::ZERO;
171        while i < LIMBS {
172            let (shifted, new_carry) = self.limbs[i].shl1();
173            ret.limbs[i] = shifted.bitor(carry);
174            carry = new_carry;
175            i += 1;
176        }
177
178        (ret, carry)
179    }
180}
181
182macro_rules! impl_shl {
183    ($($shift:ty),+) => {
184        $(
185            impl<const LIMBS: usize> Shl<$shift> for Uint<LIMBS> {
186                type Output = Uint<LIMBS>;
187
188                #[inline]
189                fn shl(self, shift: $shift) -> Uint<LIMBS> {
190                    <&Self>::shl(&self, shift)
191                }
192            }
193
194            impl<const LIMBS: usize> Shl<$shift> for &Uint<LIMBS> {
195                type Output = Uint<LIMBS>;
196
197                #[inline]
198                fn shl(self, shift: $shift) -> Uint<LIMBS> {
199                    Uint::<LIMBS>::shl(self, u32::try_from(shift).expect("invalid shift"))
200                }
201            }
202
203            impl<const LIMBS: usize> ShlAssign<$shift> for Uint<LIMBS> {
204                fn shl_assign(&mut self, shift: $shift) {
205                    *self = self.shl(shift)
206                }
207            }
208        )+
209    };
210}
211
212impl_shl!(i32, u32, usize);
213
214impl<const LIMBS: usize> WrappingShl for Uint<LIMBS> {
215    fn wrapping_shl(&self, shift: u32) -> Uint<LIMBS> {
216        self.wrapping_shl(shift)
217    }
218}
219
220impl<const LIMBS: usize> ShlVartime for Uint<LIMBS> {
221    fn overflowing_shl_vartime(&self, shift: u32) -> CtOption<Self> {
222        self.overflowing_shl(shift).into()
223    }
224    fn wrapping_shl_vartime(&self, shift: u32) -> Self {
225        self.wrapping_shl(shift)
226    }
227}
228
#[cfg(test)]
mod tests {
    use crate::{Limb, Uint, U128, U256};

    const N: U256 =
        U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141");

    const TWO_N: U256 =
        U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD755DB9CD5E9140777FA4BD19A06C8282");

    const FOUR_N: U256 =
        U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFAEABB739ABD2280EEFF497A3340D90504");

    const SIXTY_FIVE: U256 =
        U256::from_be_hex("FFFFFFFFFFFFFFFD755DB9CD5E9140777FA4BD19A06C82820000000000000000");

    const EIGHTY_EIGHT: U256 =
        U256::from_be_hex("FFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD03641410000000000000000000000");

    const SIXTY_FOUR: U256 =
        U256::from_be_hex("FFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD03641410000000000000000");

    #[test]
    fn shl_simple() {
        let mut t = U256::from(1u8);
        assert_eq!(t << 1, U256::from(2u8));
        t = U256::from(3u8);
        assert_eq!(t << 8, U256::from(0x300u16));
    }

    #[test]
    fn shl1() {
        assert_eq!(N << 1, TWO_N);
        assert_eq!(N.overflowing_shl1(), (TWO_N, Limb::ONE));
    }

    #[test]
    fn shl2() {
        assert_eq!(N << 2, FOUR_N);
    }

    #[test]
    fn shl65() {
        assert_eq!(N << 65, SIXTY_FIVE);
    }

    #[test]
    fn shl88() {
        assert_eq!(N << 88, EIGHTY_EIGHT);
    }

    #[test]
    fn shl256_const() {
        assert!(N.overflowing_shl(256).is_none().is_true_vartime());
        assert!(N.overflowing_shl_vartime(256).is_none().is_true_vartime());
    }

    #[test]
    #[should_panic(expected = "`shift` within the bit size of the integer")]
    fn shl256() {
        let _ = N << 256;
    }

    #[test]
    fn shl64() {
        assert_eq!(N << 64, SIXTY_FOUR);
    }

    #[test]
    fn shl_matches_shl_vartime() {
        // The constant-time and variable-time paths must agree for every in-range shift.
        for shift in 0..U256::BITS {
            assert_eq!(N << shift, N.shl_vartime(shift));
        }
    }

    #[test]
    fn wrapping_shl_overflow_is_zero() {
        assert_eq!(N.wrapping_shl(256), U256::ZERO);
        assert_eq!(N.wrapping_shl_vartime(256), U256::ZERO);
    }

    #[test]
    fn shl_wide_1_1_128() {
        assert_eq!(
            Uint::overflowing_shl_vartime_wide((U128::ONE, U128::ONE), 128).unwrap(),
            (U128::ZERO, U128::ONE)
        );
    }

    #[test]
    fn shl_wide_1_1_1() {
        // `upper` receives both its own shifted bits and the bits carried out of `lower`.
        assert_eq!(
            Uint::overflowing_shl_vartime_wide((U128::ONE, U128::ONE), 1).unwrap(),
            (U128::from(2u8), U128::from(2u8))
        );
    }

    #[test]
    fn shl_wide_max_0_1() {
        assert_eq!(
            Uint::overflowing_shl_vartime_wide((U128::MAX, U128::ZERO), 1).unwrap(),
            (U128::MAX.sbb(&U128::ONE, Limb::ZERO).0, U128::ONE)
        );
    }

    #[test]
    fn shl_wide_max_max_256() {
        assert!(
            Uint::overflowing_shl_vartime_wide((U128::MAX, U128::MAX), 256)
                .is_none()
                .is_true_vartime(),
        );
    }
}