1use core::{
2 alloc::{Layout, LayoutError},
3 num::NonZeroUsize,
4};
5
/// Size and alignment, both in bytes, of a request for stack memory.
///
/// An `align` of `None` marks a requirement that could not be represented (a size
/// computation overflowed, or an invalid alignment was requested); such values are
/// produced as [`StackReq::OVERFLOW`], whose `size` is 0.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct StackReq {
    /// Required alignment in bytes, or `None` for an overflowed/invalid requirement.
    align: Option<NonZeroUsize>,
    /// Required size in bytes.
    size: usize,
}
12
13impl Default for StackReq {
14 #[inline]
15 fn default() -> Self {
16 Self::empty()
17 }
18}
19
/// Rounds `a` up to the next multiple of `b`, or returns `None` if that overflows `usize`.
///
/// `b` must be a power of two. The rounding is done with bit masks, so any other `b`
/// gives a meaningless result. Callers in this file pass alignment values.
#[inline(always)]
const fn try_round_up_pow2(a: usize, b: usize) -> Option<usize> {
    // For a power-of-two `b`, `b - 1` has exactly the bits below `b` set.
    let low_mask = b.wrapping_sub(1);
    if let Some(bumped) = a.checked_add(low_mask) {
        // Clearing the low bits snaps the bumped value down to a multiple of `b`.
        Some(bumped & !low_mask)
    } else {
        None
    }
}
27
/// Returns the larger of `a` and `b`.
///
/// A hand-written version is used because `Ord::max` cannot be called from a `const fn`
/// on stable Rust.
#[inline(always)]
const fn max(a: usize, b: usize) -> usize {
    if b > a {
        b
    } else {
        a
    }
}
36
37impl StackReq {
38 pub const EMPTY: Self = Self {
40 align: unsafe { Some(NonZeroUsize::new_unchecked(1)) },
41 size: 0,
42 };
43
44 pub const OVERFLOW: Self = Self {
46 align: None,
47 size: 0,
48 };
49
50 #[inline]
52 pub const fn empty() -> StackReq {
53 Self::EMPTY
54 }
55
56 #[inline]
65 pub const fn new_aligned<T>(n: usize, align: usize) -> StackReq {
66 if align >= core::mem::align_of::<T>() && align.is_power_of_two() {
67 StackReq {
68 align: unsafe { Some(NonZeroUsize::new_unchecked(align)) },
69 size: core::mem::size_of::<T>(),
70 }
71 .array(n)
72 } else {
73 StackReq {
74 align: None,
75 size: 0,
76 }
77 }
78 }
79
80 #[inline]
86 pub const fn new<T>(n: usize) -> StackReq {
87 StackReq::new_aligned::<T>(n, core::mem::align_of::<T>())
88 }
89
90 #[inline]
92 pub const fn size_bytes(&self) -> usize {
93 self.size
94 }
95
96 #[inline]
98 pub const fn align_bytes(&self) -> usize {
99 match self.align {
100 Some(align) => align.get(),
101 None => 0,
102 }
103 }
104
105 #[inline]
111 pub const fn unaligned_bytes_required(&self) -> usize {
112 match self.layout() {
113 Ok(layout) => layout.size() + (layout.align() - 1),
114 Err(_) => usize::MAX,
115 }
116 }
117
118 #[inline]
120 pub const fn layout(self) -> Result<Layout, LayoutError> {
121 Layout::from_size_align(self.size_bytes(), self.align_bytes())
122 }
123
124 #[inline]
131 pub const fn and(self, other: StackReq) -> StackReq {
132 match (self.align, other.align) {
133 (Some(left), Some(right)) => {
134 let align = max(left.get(), right.get());
135 let left = try_round_up_pow2(self.size, align);
136 let right = try_round_up_pow2(other.size, align);
137
138 match (left, right) {
139 (Some(left), Some(right)) => {
140 match left.checked_add(right) {
141 Some(size) => StackReq {
142 align: unsafe { Some(NonZeroUsize::new_unchecked(align)) },
144 size,
145 },
146 _ => StackReq::OVERFLOW,
147 }
148 }
149 _ => StackReq::OVERFLOW,
150 }
151 }
152 _ => StackReq::OVERFLOW,
153 }
154 }
155
156 #[inline]
163 pub const fn all_of(reqs: &[Self]) -> Self {
164 let mut total = StackReq::EMPTY;
165 let mut reqs = reqs;
166 while let Some((req, next)) = reqs.split_first() {
167 total = total.and(*req);
168 reqs = next;
169 }
170 total
171 }
172
173 #[inline]
180 pub const fn or(self, other: StackReq) -> StackReq {
181 match (self.align, other.align) {
182 (Some(left), Some(right)) => {
183 let align = max(left.get(), right.get());
184 let left = try_round_up_pow2(self.size, align);
185 let right = try_round_up_pow2(other.size, align);
186
187 match (left, right) {
188 (Some(left), Some(right)) => {
189 let size = max(left, right);
190 StackReq {
191 align: unsafe { Some(NonZeroUsize::new_unchecked(align)) },
193 size,
194 }
195 }
196 _ => StackReq::OVERFLOW,
197 }
198 }
199 _ => StackReq::OVERFLOW,
200 }
201 }
202
203 #[inline]
210 pub fn any_of(reqs: &[StackReq]) -> StackReq {
211 let mut total = StackReq::EMPTY;
212 let mut reqs = reqs;
213 while let Some((req, next)) = reqs.split_first() {
214 total = total.or(*req);
215 reqs = next;
216 }
217 total
218 }
219
220 #[inline]
222 pub const fn array(self, n: usize) -> StackReq {
223 match self.align {
224 Some(align) => {
225 let size = self.size.checked_mul(n);
226 match size {
227 Some(size) => StackReq {
228 size,
229 align: Some(align),
230 },
231 None => StackReq::OVERFLOW,
232 }
233 }
234 None => StackReq::OVERFLOW,
235 }
236 }
237}
238
// Unit tests for the rounding helper and for how `StackReq` values are built, combined,
// and how overflow propagates through them.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_up() {
        assert_eq!(try_round_up_pow2(0, 4), Some(0));
        assert_eq!(try_round_up_pow2(1, 4), Some(4));
        assert_eq!(try_round_up_pow2(2, 4), Some(4));
        assert_eq!(try_round_up_pow2(3, 4), Some(4));
        assert_eq!(try_round_up_pow2(4, 4), Some(4));
    }

    #[test]
    fn round_up_edges() {
        // Alignment 1 is the identity.
        assert_eq!(try_round_up_pow2(7, 1), Some(7));
        // Already a multiple right at the top of the range: no overflow.
        assert_eq!(try_round_up_pow2(usize::MAX - 3, 4), Some(usize::MAX - 3));
        // Rounding past `usize::MAX` is reported, not wrapped.
        assert_eq!(try_round_up_pow2(usize::MAX, 4), None);
    }

    #[test]
    fn overflow() {
        assert_eq!(StackReq::new::<u32>(usize::MAX).align_bytes(), 0);
    }

    #[test]
    fn and_overflow() {
        assert_eq!(
            StackReq::new::<u8>(usize::MAX)
                .and(StackReq::new::<u8>(1))
                .align_bytes(),
            0,
        );
    }

    #[test]
    fn invalid_alignment_is_overflow() {
        assert_eq!(StackReq::new_aligned::<u8>(4, 3), StackReq::OVERFLOW);
        assert_eq!(StackReq::new_aligned::<u8>(4, 0), StackReq::OVERFLOW);
    }

    #[test]
    fn and_or_combine() {
        // Explicit alignments on `u8` keep the expectations platform-independent.
        let bytes = StackReq::new_aligned::<u8>(3, 1);
        let words = StackReq::new_aligned::<u8>(4, 4);

        let both = bytes.and(words);
        assert_eq!((both.size_bytes(), both.align_bytes()), (8, 4));

        let either = bytes.or(words);
        assert_eq!((either.size_bytes(), either.align_bytes()), (4, 4));
    }

    #[test]
    fn all_of_any_of_fold() {
        let bytes = StackReq::new_aligned::<u8>(3, 1);
        let words = StackReq::new_aligned::<u8>(4, 4);

        assert_eq!(StackReq::all_of(&[]), StackReq::EMPTY);
        assert_eq!(StackReq::any_of(&[]), StackReq::EMPTY);
        assert_eq!(StackReq::all_of(&[bytes, words]), bytes.and(words));
        assert_eq!(StackReq::any_of(&[bytes, words]), bytes.or(words));
    }

    #[test]
    fn overflow_propagates() {
        let ok = StackReq::new_aligned::<u8>(1, 1);
        assert_eq!(ok.and(StackReq::OVERFLOW), StackReq::OVERFLOW);
        assert_eq!(StackReq::OVERFLOW.or(ok), StackReq::OVERFLOW);
        assert_eq!(StackReq::OVERFLOW.array(2), StackReq::OVERFLOW);
        assert_eq!(StackReq::OVERFLOW.unaligned_bytes_required(), usize::MAX);
    }

    #[test]
    fn unaligned_bytes_required_adds_padding() {
        // 10 bytes at alignment 8 need up to 7 extra bytes in an arbitrarily aligned buffer.
        assert_eq!(
            StackReq::new_aligned::<u8>(10, 8).unaligned_bytes_required(),
            17
        );
        assert_eq!(StackReq::EMPTY.unaligned_bytes_required(), 0);
    }
}