rusticata_macros/
combinator.rs

1//! General purpose combinators
2
3use nom::bytes::streaming::take;
4use nom::combinator::map_parser;
5use nom::error::{make_error, ErrorKind, ParseError};
6use nom::{IResult, Needed, Parser};
7use nom::{InputIter, InputTake};
8use nom::{InputLength, ToUsize};
9
#[deprecated(since = "3.0.1", note = "please use `be_var_u64` instead")]
/// Interpret a whole byte slice as a single big-endian unsigned integer.
///
/// # Errors
///
/// Returns `Err("empty")` when `s` contains no bytes, and `Err("overflow")` when `s` is
/// longer than 8 bytes and therefore cannot be represented as a `u64`.
pub fn bytes_to_u64(s: &[u8]) -> Result<u64, &'static str> {
    match s.len() {
        0 => Err("empty"),
        // At most 8 bytes: each step shifts the accumulator one byte left and appends the
        // next byte, so the first byte of the slice ends up most significant.
        1..=8 => Ok(s.iter().fold(0u64, |acc, &b| (acc << 8) | u64::from(b))),
        _ => Err("overflow"),
    }
}
31
32/// Read the entire slice as a big endian unsigned integer, up to 8 bytes
33#[inline]
34pub fn be_var_u64<'a, E: ParseError<&'a [u8]>>(input: &'a [u8]) -> IResult<&'a [u8], u64, E> {
35    if input.is_empty() {
36        return Err(nom::Err::Incomplete(Needed::new(1)));
37    }
38    if input.len() > 8 {
39        return Err(nom::Err::Error(make_error(input, ErrorKind::TooLarge)));
40    }
41    let mut res = 0u64;
42    for byte in input {
43        res = (res << 8) + *byte as u64;
44    }
45
46    Ok((&b""[..], res))
47}
48
49/// Read the entire slice as a little endian unsigned integer, up to 8 bytes
50#[inline]
51pub fn le_var_u64<'a, E: ParseError<&'a [u8]>>(input: &'a [u8]) -> IResult<&'a [u8], u64, E> {
52    if input.is_empty() {
53        return Err(nom::Err::Incomplete(Needed::new(1)));
54    }
55    if input.len() > 8 {
56        return Err(nom::Err::Error(make_error(input, ErrorKind::TooLarge)));
57    }
58    let mut res = 0u64;
59    for byte in input.iter().rev() {
60        res = (res << 8) + *byte as u64;
61    }
62
63    Ok((&b""[..], res))
64}
65
66/// Read a slice as a big-endian value.
67#[inline]
68pub fn parse_hex_to_u64<S>(i: &[u8], size: S) -> IResult<&[u8], u64>
69where
70    S: ToUsize + Copy,
71{
72    map_parser(take(size.to_usize()), be_var_u64)(i)
73}
74
75/// Apply combinator, automatically converts between errors if the underlying type supports it
76pub fn upgrade_error<I, O, E1: ParseError<I>, E2: ParseError<I>, F>(
77    mut f: F,
78) -> impl FnMut(I) -> IResult<I, O, E2>
79where
80    F: FnMut(I) -> IResult<I, O, E1>,
81    E2: From<E1>,
82{
83    move |i| f(i).map_err(nom::Err::convert)
84}
85
86/// Create a combinator that returns the provided value, and input unchanged
87pub fn pure<I, O, E: ParseError<I>>(val: O) -> impl Fn(I) -> IResult<I, O, E>
88where
89    O: Clone,
90{
91    move |input: I| Ok((input, val.clone()))
92}
93
94/// Return a closure that takes `len` bytes from input, and applies `parser`.
95pub fn flat_take<I, C, O, E: ParseError<I>, F>(
96    len: C,
97    mut parser: F,
98) -> impl FnMut(I) -> IResult<I, O, E>
99where
100    I: InputTake + InputLength + InputIter,
101    C: ToUsize + Copy,
102    F: Parser<I, O, E>,
103{
104    // Note: this is the same as `map_parser(take(len), parser)`
105    move |input: I| {
106        let (input, o1) = take(len.to_usize())(input)?;
107        let (_, o2) = parser.parse(o1)?;
108        Ok((input, o2))
109    }
110}
111
112/// Take `len` bytes from `input`, and apply `parser`.
113pub fn flat_takec<I, O, E: ParseError<I>, C, F>(input: I, len: C, parser: F) -> IResult<I, O, E>
114where
115    C: ToUsize + Copy,
116    F: Parser<I, O, E>,
117    I: InputTake + InputLength + InputIter,
118    O: InputLength,
119{
120    flat_take(len, parser)(input)
121}
122
123/// Helper macro for nom parsers: run first parser if condition is true, else second parser
124pub fn cond_else<I, O, E: ParseError<I>, C, F, G>(
125    cond: C,
126    mut first: F,
127    mut second: G,
128) -> impl FnMut(I) -> IResult<I, O, E>
129where
130    C: Fn() -> bool,
131    F: Parser<I, O, E>,
132    G: Parser<I, O, E>,
133{
134    move |input: I| {
135        if cond() {
136            first.parse(input)
137        } else {
138            second.parse(input)
139        }
140    }
141}
142
/// Round `x` up to the nearest multiple of `n` (values already aligned are unchanged).
///
/// The result is only meaningful when `n` is a power of 2.
pub const fn align_n2(x: usize, n: usize) -> usize {
    // For a power of two, `n - 1` has exactly the low bits set that must end up cleared.
    let mask = n - 1;
    (x + mask) & !mask
}
148
/// Round `x` up to the nearest multiple of 4 (values already aligned are unchanged).
pub const fn align32(x: usize) -> usize {
    // Low two bits that must be zero in a multiple of 4.
    const MASK: usize = 4 - 1;
    (x + MASK) & !MASK
}
153
#[cfg(test)]
mod tests {
    use super::{align32, align_n2, be_var_u64, cond_else, flat_take, le_var_u64, pure};
    use nom::bytes::streaming::take;
    use nom::error::{Error, ErrorKind};
    use nom::number::streaming::{be_u16, be_u32, be_u8};
    use nom::{Err, IResult, Needed};

    #[test]
    fn test_be_var_u64() {
        let res: IResult<&[u8], u64> = be_var_u64(b"\x12\x34\x56");
        let (_, v) = res.expect("be_var_u64 failed");
        assert_eq!(v, 0x123456);
    }

    #[test]
    fn test_be_var_u64_bounds() {
        // empty input: at least one byte is required
        let res: IResult<&[u8], u64> = be_var_u64(b"");
        assert_eq!(res, Err(Err::Incomplete(Needed::new(1))));
        // 8 bytes is the largest accepted input, and everything is consumed
        let res: IResult<&[u8], u64> = be_var_u64(&[0xff; 8]);
        assert_eq!(res, Ok((&b""[..], u64::MAX)));
        // 9 bytes cannot fit in a u64
        let input = &[0xff; 9][..];
        let res: IResult<&[u8], u64> = be_var_u64(input);
        assert_eq!(res, Err(Err::Error(Error::new(input, ErrorKind::TooLarge))));
    }

    #[test]
    fn test_le_var_u64() {
        // same bytes as the big-endian test, read in the opposite order
        let res: IResult<&[u8], u64> = le_var_u64(b"\x12\x34\x56");
        let (_, v) = res.expect("le_var_u64 failed");
        assert_eq!(v, 0x563412);
    }

    #[test]
    fn test_flat_take() {
        let input = &[0x00, 0x01, 0xff];
        // read first 2 bytes and use correct combinator: OK
        let res: IResult<&[u8], u16> = flat_take(2u8, be_u16)(input);
        assert_eq!(res, Ok((&input[2..], 0x0001)));
        // read 3 bytes and use 2: OK (some input is just lost)
        let res: IResult<&[u8], u16> = flat_take(3u8, be_u16)(input);
        assert_eq!(res, Ok((&b""[..], 0x0001)));
        // read 2 bytes and a combinator requiring more bytes
        let res: IResult<&[u8], u32> = flat_take(2u8, be_u32)(input);
        assert_eq!(res, Err(Err::Incomplete(Needed::new(2))));
    }

    #[test]
    fn test_flat_take_str() {
        let input = "abcdef";
        // read first 2 bytes and use correct combinator: OK
        let res: IResult<&str, &str> = flat_take(2u8, take(2u8))(input);
        assert_eq!(res, Ok(("cdef", "ab")));
        // read 3 bytes and use 2: OK (some input is just lost)
        let res: IResult<&str, &str> = flat_take(3u8, take(2u8))(input);
        assert_eq!(res, Ok(("def", "ab")));
        // read 2 bytes and use a combinator requiring more bytes
        let res: IResult<&str, &str> = flat_take(2u8, take(4u8))(input);
        assert_eq!(res, Err(Err::Incomplete(Needed::Unknown)));
    }

    #[test]
    fn test_cond_else() {
        let input = &[0x01][..];
        let empty = &b""[..];
        let a = 1;
        fn parse_u8(i: &[u8]) -> IResult<&[u8], u8> {
            be_u8(i)
        }
        // condition true: first parser runs and consumes the byte
        assert_eq!(
            cond_else(|| a == 1, parse_u8, pure(0x02))(input),
            Ok((empty, 0x01))
        );
        // condition false: second parser runs, input is untouched
        assert_eq!(
            cond_else(|| a == 2, parse_u8, pure(0x02))(input),
            Ok((input, 0x02))
        );
        // condition true with the parsers swapped
        assert_eq!(
            cond_else(|| a == 1, pure(0x02), parse_u8)(input),
            Ok((input, 0x02))
        );
        let res: IResult<&[u8], u8> = cond_else(|| a == 1, parse_u8, parse_u8)(input);
        assert_eq!(res, Ok((empty, 0x01)));
    }

    #[test]
    fn test_align_n2() {
        assert_eq!(align_n2(0, 8), 0);
        assert_eq!(align_n2(5, 8), 8);
        assert_eq!(align_n2(8, 8), 8);
        assert_eq!(align_n2(9, 8), 16);
    }

    #[test]
    fn test_align32() {
        assert_eq!(align32(0), 0);
        assert_eq!(align32(3), 4);
        assert_eq!(align32(4), 4);
        assert_eq!(align32(5), 8);
    }
}
231}