1use core::ops::{Shr, ShrAssign};
4
5use subtle::CtOption;
6
7use crate::{ConstChoice, ConstCtOption, Int, Limb, ShrVartime, Uint, WrappingShr};
8
9impl<const LIMBS: usize> Int<LIMBS> {
10 pub const fn shr(&self, shift: u32) -> Self {
17 self.overflowing_shr(shift)
18 .expect("`shift` within the bit size of the integer")
19 }
20
21 pub const fn shr_vartime(&self, shift: u32) -> Self {
28 self.overflowing_shr_vartime(shift)
29 .expect("`shift` within the bit size of the integer")
30 }
31
32 pub const fn overflowing_shr(&self, shift: u32) -> ConstCtOption<Self> {
39 let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros();
42 let overflow = ConstChoice::from_u32_lt(shift, Self::BITS).not();
43 let shift = shift % Self::BITS;
44 let mut result = *self;
45 let mut i = 0;
46 while i < shift_bits {
47 let bit = ConstChoice::from_u32_lsb((shift >> i) & 1);
48 result = Int::select(
49 &result,
50 &result
51 .overflowing_shr_vartime(1 << i)
52 .expect("shift within range"),
53 bit,
54 );
55 i += 1;
56 }
57
58 ConstCtOption::new(result, overflow.not())
59 }
60
61 #[inline(always)]
73 pub const fn overflowing_shr_vartime(&self, shift: u32) -> ConstCtOption<Self> {
74 let is_negative = self.is_negative();
75
76 if shift >= Self::BITS {
77 return ConstCtOption::none(Self::select(&Self::ZERO, &Self::MINUS_ONE, is_negative));
78 }
79
80 let base = Limb::select(Limb::ZERO, Limb::MAX, is_negative);
82 let mut limbs = [base; LIMBS];
83
84 let shift_num = (shift / Limb::BITS) as usize;
85 let rem = shift % Limb::BITS;
86
87 let mut i = 0;
88 while i < LIMBS - shift_num {
89 limbs[i] = self.0.limbs[i + shift_num];
90 i += 1;
91 }
92
93 if rem == 0 {
94 return ConstCtOption::some(Self(Uint::new(limbs)));
95 }
96
97 let mut carry = Limb::select(Limb::ZERO, Limb::MAX, is_negative);
100 carry = carry.bitxor(carry.shr(rem)); while i > 0 {
103 i -= 1;
104 let shifted = limbs[i].shr(rem);
105 let new_carry = limbs[i].shl(Limb::BITS - rem);
106 limbs[i] = shifted.bitor(carry);
107 carry = new_carry;
108 }
109
110 ConstCtOption::some(Self(Uint::new(limbs)))
111 }
112
113 pub const fn wrapping_shr(&self, shift: u32) -> Self {
119 let default = Self::select(&Self::ZERO, &Self::MINUS_ONE, self.is_negative());
120 self.overflowing_shr(shift).unwrap_or(default)
121 }
122
123 pub const fn wrapping_shr_vartime(&self, shift: u32) -> Self {
129 let default = Self::select(&Self::ZERO, &Self::MINUS_ONE, self.is_negative());
130 self.overflowing_shr_vartime(shift).unwrap_or(default)
131 }
132}
133
134macro_rules! impl_shr {
135 ($($shift:ty),+) => {
136 $(
137 impl<const LIMBS: usize> Shr<$shift> for Int<LIMBS> {
138 type Output = Int<LIMBS>;
139
140 #[inline]
141 fn shr(self, shift: $shift) -> Int<LIMBS> {
142 <&Self>::shr(&self, shift)
143 }
144 }
145
146 impl<const LIMBS: usize> Shr<$shift> for &Int<LIMBS> {
147 type Output = Int<LIMBS>;
148
149 #[inline]
150 fn shr(self, shift: $shift) -> Int<LIMBS> {
151 Int::<LIMBS>::shr(self, u32::try_from(shift).expect("invalid shift"))
152 }
153 }
154
155 impl<const LIMBS: usize> ShrAssign<$shift> for Int<LIMBS> {
156 fn shr_assign(&mut self, shift: $shift) {
157 *self = self.shr(shift)
158 }
159 }
160 )+
161 };
162}
163
164impl_shr!(i32, u32, usize);
165
166impl<const LIMBS: usize> WrappingShr for Int<LIMBS> {
167 fn wrapping_shr(&self, shift: u32) -> Int<LIMBS> {
168 self.wrapping_shr(shift)
169 }
170}
171
172impl<const LIMBS: usize> ShrVartime for Int<LIMBS> {
173 fn overflowing_shr_vartime(&self, shift: u32) -> CtOption<Self> {
174 self.overflowing_shr(shift).into()
175 }
176 fn wrapping_shr_vartime(&self, shift: u32) -> Self {
177 self.wrapping_shr(shift)
178 }
179}
180
#[cfg(test)]
mod tests {
    use core::ops::Div;

    use crate::{ShrVartime, I256};

    const N: I256 =
        I256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141");

    const N_2: I256 =
        I256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF5D576E7357A4501DDFE92F46681B20A0");

    #[test]
    fn shr0() {
        assert_eq!(I256::MAX >> 0, I256::MAX);
        assert_eq!(I256::MIN >> 0, I256::MIN);
    }

    #[test]
    fn shr1() {
        assert_eq!(N >> 1, N_2);
    }

    #[test]
    fn shr5() {
        assert_eq!(
            I256::MAX >> 5,
            I256::MAX.div(I256::from(32).to_nz().unwrap()).unwrap()
        );
        assert_eq!(
            I256::MIN >> 5,
            I256::MIN.div(I256::from(32).to_nz().unwrap()).unwrap()
        );
    }

    #[test]
    fn shr7_vartime() {
        assert_eq!(
            I256::MAX.shr_vartime(7),
            I256::MAX.div(I256::from(128).to_nz().unwrap()).unwrap()
        );
        assert_eq!(
            I256::MIN.shr_vartime(7),
            I256::MIN.div(I256::from(128).to_nz().unwrap()).unwrap()
        );
    }

    #[test]
    fn shr256_const() {
        assert!(N.overflowing_shr(256).is_none().is_true_vartime());
        assert!(N.overflowing_shr_vartime(256).is_none().is_true_vartime());
    }

    #[test]
    #[should_panic(expected = "`shift` within the bit size of the integer")]
    fn shr256() {
        let _ = N >> 256;
    }

    #[test]
    fn wrapping_shr() {
        assert_eq!(I256::MAX.wrapping_shr(257), I256::ZERO);
        assert_eq!(I256::MIN.wrapping_shr(257), I256::MINUS_ONE);
    }

    #[test]
    fn wrapping_shr_vartime() {
        assert_eq!(I256::MAX.wrapping_shr_vartime(256), I256::ZERO);
        assert_eq!(I256::MAX.wrapping_shr_vartime(257), I256::ZERO);
        assert_eq!(I256::MIN.wrapping_shr_vartime(256), I256::MINUS_ONE);
        assert_eq!(I256::MIN.wrapping_shr_vartime(257), I256::MINUS_ONE);
    }

    /// An arithmetic shift floors, unlike division, which truncates toward zero.
    #[test]
    fn shr_negative_rounds_toward_negative_infinity() {
        assert_eq!(I256::from(-5i32) >> 1, I256::from(-3i32));
        assert_eq!(I256::from(-5i32) >> 2, I256::from(-2i32));
        assert_eq!(I256::from(-5i32).shr_vartime(1), I256::from(-3i32));
        assert_eq!(I256::MINUS_ONE >> 255, I256::MINUS_ONE);
        assert_eq!(I256::MIN >> 255, I256::MINUS_ONE);
        assert_eq!(I256::MAX >> 255, I256::ZERO);
    }

    /// The constant-time and variable-time shifts agree for every in-range amount,
    /// including amounts that cross limb boundaries with a non-zero remainder.
    #[test]
    fn shr_const_matches_vartime() {
        for s in 0..256u32 {
            assert_eq!(N >> s, N.shr_vartime(s));
            assert_eq!(I256::MIN >> s, I256::MIN.shr_vartime(s));
            assert_eq!(I256::MAX >> s, I256::MAX.shr_vartime(s));
        }
        for s in 0..300u32 {
            assert_eq!(N.wrapping_shr(s), N.wrapping_shr_vartime(s));
        }
    }

    /// `(x >> s) >> 1 == x >> (s + 1)` for arithmetic shifts.
    #[test]
    fn shr_composes() {
        for s in 0..255u32 {
            assert_eq!(N.shr_vartime(s).shr_vartime(1), N.shr_vartime(s + 1));
            assert_eq!((I256::MIN >> s) >> 1, I256::MIN >> (s + 1));
        }
    }

    #[test]
    fn shr_vartime_trait() {
        assert_eq!(
            <I256 as ShrVartime>::overflowing_shr_vartime(&N, 1).unwrap(),
            N_2
        );
        assert!(bool::from(
            <I256 as ShrVartime>::overflowing_shr_vartime(&N, 256).is_none()
        ));
        assert_eq!(
            <I256 as ShrVartime>::wrapping_shr_vartime(&I256::MIN, 257),
            I256::MINUS_ONE
        );
        assert_eq!(
            <I256 as ShrVartime>::wrapping_shr_vartime(&I256::MAX, 257),
            I256::ZERO
        );
    }
}