const_str/__ctfe/
net.rs

1use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};
2
/// A cursor over the bytes of a string slice.
///
/// All methods are `const fn`. Methods that consume input take the parser by
/// value and return it, together with the parsed value, as a tuple.
struct Parser<'a> {
    /// The bytes being parsed.
    s: &'a [u8],
    /// Index into `s` of the next unread byte.
    i: usize,
}

impl<'a> Parser<'a> {
    /// Creates a parser positioned at the first byte of `s`.
    const fn new(s: &'a str) -> Self {
        let bytes = s.as_bytes();
        Self { s: bytes, i: 0 }
    }

    /// Returns a second parser over the same input at the same position.
    const fn const_clone(&self) -> Self {
        let &Self { s, i } = self;
        Self { s, i }
    }

    /// Returns `true` when the position equals the input length.
    const fn is_end(&self) -> bool {
        self.s.len() == self.i
    }

    /// Returns the next byte without consuming it, or `None` at the end.
    const fn peek(&self) -> Option<u8> {
        if self.i >= self.s.len() {
            return None;
        }
        Some(self.s[self.i])
    }

    /// Returns the next byte and, if there is one, the byte after it,
    /// without consuming anything. `None` means no bytes remain.
    const fn peek2(&self) -> Option<(u8, Option<u8>)> {
        let first = match self.peek() {
            Some(byte) => byte,
            None => return None,
        };
        let second = if self.i + 1 < self.s.len() {
            Some(self.s[self.i + 1])
        } else {
            None
        };
        Some((first, second))
    }

    /// Consumes and returns the next byte; at the end of input the parser is
    /// returned unchanged with `None`.
    const fn read_byte(mut self) -> (Self, Option<u8>) {
        if self.i >= self.s.len() {
            return (self, None);
        }
        let byte = self.s[self.i];
        self.i += 1;
        (self, Some(byte))
    }

    /// Consumes the next byte only if it equals `byte`; otherwise the parser
    /// is returned unchanged with `None`.
    const fn read_given_byte(self, byte: u8) -> (Self, Option<()>) {
        match self.peek() {
            Some(found) if found == byte => (self.advance(1), Some(())),
            _ => (self, None),
        }
    }

    /// Moves the position forward by `step` bytes without bounds checking.
    const fn advance(self, step: usize) -> Self {
        Self {
            s: self.s,
            i: self.i + step,
        }
    }
}
71
72macro_rules! impl_read_uint {
73    ($ty:ty, $id: ident) => {
74        impl<'a> Parser<'a> {
75            const fn $id(
76                mut self,
77                radix: u8,
78                allow_leading_zeros: bool,
79                max_digits: usize,
80            ) -> (Self, Option<$ty>) {
81                assert!(radix == 10 || radix == 16);
82                let Self { s, mut i, .. } = self;
83                let mut digit_count = 0;
84                let mut ans: $ty = 0;
85
86                loop {
87                    let b = if i < s.len() { s[i] } else { break };
88                    let x = match b {
89                        b'0'..=b'9' => b - b'0',
90                        b'a'..=b'f' if radix == 16 => b - b'a' + 10,
91                        b'A'..=b'F' if radix == 16 => b - b'A' + 10,
92                        _ => break,
93                    };
94                    if !allow_leading_zeros && (ans == 0 && digit_count == 1) {
95                        return (self, None);
96                    }
97                    ans = match ans.checked_mul(radix as $ty) {
98                        Some(x) => x,
99                        None => return (self, None),
100                    };
101                    ans = match ans.checked_add(x as $ty) {
102                        Some(x) => x,
103                        None => return (self, None),
104                    };
105                    i += 1;
106                    digit_count += 1;
107
108                    if digit_count > max_digits {
109                        return (self, None);
110                    }
111                }
112
113                if digit_count == 0 {
114                    return (self, None);
115                }
116                self.i = i;
117                (self, Some(ans))
118            }
119        }
120    };
121}
122
123impl_read_uint!(u8, read_u8);
124impl_read_uint!(u16, read_u16);
125
/// Evaluates `$ret`, which must yield a `(parser, Option<value>)` pair.
///
/// On `Some`, stores the returned parser in `$id` and evaluates to the value.
/// On `None`, returns `($orig, None)` from the enclosing function.
macro_rules! try_parse {
    ($orig:ident, $id:ident, $ret: expr) => {{
        let (advanced, outcome) = $ret;
        match outcome {
            Some(value) => {
                $id = advanced;
                value
            }
            None => return ($orig, None),
        }
    }};
}
137
/// Evaluates `$ret`, which must yield a `(parser, value)` pair, stores the
/// returned parser in `$id` unconditionally, and evaluates to the value.
macro_rules! parse {
    ($id:ident,$ret: expr) => {{
        let step = $ret;
        $id = step.0;
        step.1
    }};
}
145
146impl Parser<'_> {
147    const fn read_ipv4(self) -> (Self, Option<Ipv4Addr>) {
148        let mut p = self.const_clone();
149        let mut nums = [0; 4];
150        let mut i = 0;
151        while i < 4 {
152            if i > 0 {
153                try_parse!(self, p, p.read_given_byte(b'.'));
154            }
155            nums[i] = try_parse!(self, p, p.read_u8(10, false, 3));
156            i += 1;
157        }
158        let val = Ipv4Addr::new(nums[0], nums[1], nums[2], nums[3]);
159        (p, Some(val))
160    }
161
162    const fn read_ipv6(self) -> (Self, Option<Ipv6Addr>) {
163        let mut p = self.const_clone();
164
165        let mut nums: [u16; 8] = [0; 8];
166        let mut left_cnt = 0;
167        let mut right_cnt = 0;
168
169        let mut state: u8 = 0;
170        'dfa: loop {
171            match state {
172                0 => match p.peek2() {
173                    Some((b':', Some(b':'))) => {
174                        p = p.advance(2);
175                        state = 2;
176                        continue 'dfa;
177                    }
178                    _ => {
179                        state = 1;
180                        continue 'dfa;
181                    }
182                },
183                1 => loop {
184                    nums[left_cnt] = try_parse!(self, p, p.read_u16(16, true, 4));
185                    left_cnt += 1;
186                    if left_cnt == 8 {
187                        break 'dfa;
188                    }
189                    try_parse!(self, p, p.read_given_byte(b':'));
190                    if matches!(p.peek(), Some(b':')) {
191                        p = p.advance(1);
192                        if left_cnt == 7 {
193                            break 'dfa;
194                        }
195                        state = 2;
196                        continue 'dfa;
197                    }
198                    if left_cnt == 6 {
199                        if let Some(val) = parse!(p, p.read_ipv4()) {
200                            let [n1, n2, n3, n4] = val.octets();
201                            nums[6] = u16::from_be_bytes([n1, n2]);
202                            nums[7] = u16::from_be_bytes([n3, n4]);
203                            // left_cnt = 8;
204                            break 'dfa;
205                        }
206                    }
207                },
208                2 => loop {
209                    if left_cnt + right_cnt <= 6 {
210                        if let Some(val) = parse!(p, p.read_ipv4()) {
211                            let [n1, n2, n3, n4] = val.octets();
212                            nums[7 - right_cnt] = u16::from_be_bytes([n1, n2]);
213                            nums[6 - right_cnt] = u16::from_be_bytes([n3, n4]);
214                            right_cnt += 2;
215                            break 'dfa;
216                        }
217                    }
218                    match parse!(p, p.read_u16(16, true, 4)) {
219                        Some(val) => {
220                            nums[7 - right_cnt] = val;
221                            right_cnt += 1;
222                            if left_cnt + right_cnt == 7 {
223                                break 'dfa;
224                            }
225                            match p.peek() {
226                                Some(b':') => p = p.advance(1),
227                                _ => break 'dfa,
228                            }
229                        }
230                        None => break 'dfa,
231                    }
232                },
233                _ => unreachable!(),
234            }
235        }
236        {
237            let mut i = 8 - right_cnt;
238            let mut j = 7;
239            #[allow(clippy::manual_swap)]
240            while i < j {
241                let (lhs, rhs) = (nums[i], nums[j]);
242                nums[i] = rhs;
243                nums[j] = lhs;
244                i += 1;
245                j -= 1;
246            }
247        }
248
249        let val = Ipv6Addr::new(
250            nums[0], nums[1], nums[2], nums[3], //
251            nums[4], nums[5], nums[6], nums[7], //
252        );
253        (p, Some(val))
254    }
255}
256
/// Runs the `Parser` method `$m` over the whole of `$s`.
///
/// Evaluates to `Some(value)` only if the method succeeds and consumes the
/// entire input; otherwise `None`.
macro_rules! parse_with {
    ($s: expr, $m:ident) => {{
        let (rest, parsed) = Parser::new($s).$m();
        if rest.is_end() {
            parsed
        } else {
            None
        }
    }};
}
267
268pub const fn expect_ipv4(s: &str) -> Ipv4Addr {
269    match parse_with!(s, read_ipv4) {
270        Some(val) => val,
271        None => panic!("invalid ipv4 address"),
272    }
273}
274
275pub const fn expect_ipv6(s: &str) -> Ipv6Addr {
276    match parse_with!(s, read_ipv6) {
277        Some(val) => val,
278        None => panic!("invalid ipv6 address"),
279    }
280}
281
282pub const fn expect_ip(s: &str) -> IpAddr {
283    match parse_with!(s, read_ipv4) {
284        Some(val) => IpAddr::V4(val),
285        None => match parse_with!(s, read_ipv6) {
286            Some(val) => IpAddr::V6(val),
287            None => panic!("invalid ip address"),
288        },
289    }
290}
291
/// Parses a string slice into an IP address.
///
/// * `ip_addr!(v4, s)` produces an `Ipv4Addr`.
/// * `ip_addr!(v6, s)` produces an `Ipv6Addr`.
/// * `ip_addr!(s)` produces an `IpAddr`, trying IPv4 before IPv6.
///
/// The macro panics if the string is not a valid address of the requested
/// kind; when evaluated in a `const` context this is a compile-time error.
///
/// This macro is [const-fn compatible](./index.html#const-fn-compatible).
///
/// # Examples
/// ```
/// use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};
/// use const_str::ip_addr;
///
/// const LOCALHOST_V4: Ipv4Addr = ip_addr!(v4, "127.0.0.1");
/// const LOCALHOST_V6: Ipv6Addr = ip_addr!(v6, "::1");
///
/// const LOCALHOSTS: [IpAddr;2] = [ip_addr!("127.0.0.1"), ip_addr!("::1")];
/// ```
#[macro_export]
macro_rules! ip_addr {
    (v4, $addr:expr) => {
        $crate::__ctfe::expect_ipv4($addr)
    };
    (v6, $addr:expr) => {
        $crate::__ctfe::expect_ipv6($addr)
    };
    ($addr:expr) => {
        $crate::__ctfe::expect_ip($addr)
    };
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `text` with the standard library's `FromStr`; the second
    /// argument is only used to select the target type.
    fn std_parse<T>(text: &str, _like: &T) -> T
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Debug,
    {
        text.parse::<T>().unwrap()
    }

    /// `check_ip_addr!(v4|v6, invalid, s)` asserts that `s` is rejected;
    /// `check_ip_addr!(v4|v6, s)` asserts that both the typed and the untyped
    /// form of `ip_addr!` agree with the standard library.
    macro_rules! check_ip_addr {
        (v4, invalid, $s:expr) => {{
            assert!(parse_with!($s, read_ipv4).is_none());
        }};
        (v6, invalid, $s:expr) => {{
            assert!(parse_with!($s, read_ipv6).is_none());
        }};
        ($family:tt, $s:expr) => {{
            let typed = ip_addr!($family, $s);
            assert_eq!(typed, std_parse($s, &typed));

            let untyped = ip_addr!($s);
            assert_eq!(untyped, std_parse($s, &untyped));
        }};
    }

    #[test]
    fn test_ip_addr() {
        // IPv4: accepted inputs.
        check_ip_addr!(v4, "0.0.0.0");
        check_ip_addr!(v4, "127.0.0.1");
        check_ip_addr!(v4, "255.255.255.255");
        // IPv4: rejected inputs.
        check_ip_addr!(v4, invalid, "0");
        check_ip_addr!(v4, invalid, "0x1");
        check_ip_addr!(v4, invalid, "127.00.0.1");
        check_ip_addr!(v4, invalid, "027.0.0.1");
        check_ip_addr!(v4, invalid, "256.0.0.1");
        check_ip_addr!(v4, invalid, "255.0.0");
        check_ip_addr!(v4, invalid, "255.0.0.1.2");
        check_ip_addr!(v4, invalid, "255.0.0..1");

        // IPv6: accepted inputs.
        check_ip_addr!(v6, "::");
        check_ip_addr!(v6, "::1");
        check_ip_addr!(v6, "2001:db8::2:3:4:1");
        check_ip_addr!(v6, "::1:2:3");
        check_ip_addr!(v6, "FF01::101");
        check_ip_addr!(v6, "0:0:0:0:0:0:13.1.68.3");
        check_ip_addr!(v6, "0:0:0:0:0:FFFF:129.144.52.38");
        // IPv6: rejected inputs.
        check_ip_addr!(v6, invalid, "::::");
        check_ip_addr!(v6, invalid, "::00001");
    }
}