cedar_policy_core/parser/
unescape.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use crate::ast::PatternElem;
18use itertools::Itertools;
19use miette::Diagnostic;
20use nonempty::NonEmpty;
21use rustc_lexer::unescape::{unescape_str, EscapeError};
22use smol_str::SmolStr;
23use std::ops::Range;
24use thiserror::Error;
25
26/// Unescape a string following Cedar's string escape rules
27pub fn to_unescaped_string(s: &str) -> Result<SmolStr, NonEmpty<UnescapeError>> {
28    let mut unescaped_str = String::new();
29    let mut errs = Vec::new();
30    let mut callback = |range, r| match r {
31        Ok(c) => unescaped_str.push(c),
32        Err(err) => errs.push(UnescapeError {
33            err,
34            input: s.to_owned(),
35            range,
36        }),
37    };
38    unescape_str(s, &mut callback);
39    if let Some((head, tails)) = errs.split_first() {
40        Err(NonEmpty {
41            head: head.clone(),
42            tail: tails.iter().cloned().collect_vec(),
43        })
44    } else {
45        Ok(unescaped_str.into())
46    }
47}
48
49pub(crate) fn to_pattern(s: &str) -> Result<Vec<PatternElem>, NonEmpty<UnescapeError>> {
50    let mut unescaped_str = Vec::new();
51    let mut errs = Vec::new();
52    let bytes = s.as_bytes(); // to inspect string element in O(1) time
53    let mut callback = |range: Range<usize>, r| match r {
54        Ok(c) => unescaped_str.push(if c == '*' { PatternElem::Wildcard }else { PatternElem::Char(c) }),
55        // PANIC SAFETY By invariant, all passed in ranges must be in range
56        #[allow(clippy::indexing_slicing)]
57        Err(EscapeError::InvalidEscape)
58        // note that the range argument refers to the *byte* offset into the string.
59        // so we can compare the byte slice against the bytes of the ``star'' escape sequence.
60        if &bytes[range.clone()] == br"\*"
61            =>
62        {
63            unescaped_str.push(PatternElem::Char('*'))
64        }
65        Err(err) => errs.push(UnescapeError { err, input: s.to_owned(), range }),
66    };
67    unescape_str(s, &mut callback);
68    if let Some((head, tails)) = errs.split_first() {
69        Err(NonEmpty {
70            head: head.clone(),
71            tail: tails.iter().cloned().collect_vec(),
72        })
73    } else {
74        Ok(unescaped_str)
75    }
76}
77
78/// Errors generated when processing escapes
79#[derive(Debug, Diagnostic, Error, PartialEq, Eq)]
80pub struct UnescapeError {
81    /// underlying EscapeError
82    err: EscapeError,
83    /// copy of the input string which had the error
84    #[source_code]
85    input: String,
86    /// Range of the input string where the error occurred
87    /// This range must be within the length of `input`
88    #[label]
89    range: Range<usize>,
90}
91
92impl Clone for UnescapeError {
93    fn clone(&self) -> Self {
94        Self {
95            err: clone_escape_error(&self.err),
96            input: self.input.clone(),
97            range: self.range.clone(),
98        }
99    }
100}
101
102/// [`EscapeError`] doesn't implement clone or copy
103fn clone_escape_error(e: &EscapeError) -> EscapeError {
104    match e {
105        EscapeError::ZeroChars => EscapeError::ZeroChars,
106        EscapeError::MoreThanOneChar => EscapeError::MoreThanOneChar,
107        EscapeError::LoneSlash => EscapeError::LoneSlash,
108        EscapeError::InvalidEscape => EscapeError::InvalidEscape,
109        EscapeError::BareCarriageReturn => EscapeError::BareCarriageReturn,
110        EscapeError::BareCarriageReturnInRawString => EscapeError::BareCarriageReturnInRawString,
111        EscapeError::EscapeOnlyChar => EscapeError::EscapeOnlyChar,
112        EscapeError::TooShortHexEscape => EscapeError::TooShortHexEscape,
113        EscapeError::InvalidCharInHexEscape => EscapeError::InvalidCharInHexEscape,
114        EscapeError::OutOfRangeHexEscape => EscapeError::OutOfRangeHexEscape,
115        EscapeError::NoBraceInUnicodeEscape => EscapeError::NoBraceInUnicodeEscape,
116        EscapeError::InvalidCharInUnicodeEscape => EscapeError::InvalidCharInUnicodeEscape,
117        EscapeError::EmptyUnicodeEscape => EscapeError::EmptyUnicodeEscape,
118        EscapeError::UnclosedUnicodeEscape => EscapeError::UnclosedUnicodeEscape,
119        EscapeError::LeadingUnderscoreUnicodeEscape => EscapeError::LeadingUnderscoreUnicodeEscape,
120        EscapeError::OverlongUnicodeEscape => EscapeError::OverlongUnicodeEscape,
121        EscapeError::LoneSurrogateUnicodeEscape => EscapeError::LoneSurrogateUnicodeEscape,
122        EscapeError::OutOfRangeUnicodeEscape => EscapeError::OutOfRangeUnicodeEscape,
123        EscapeError::UnicodeEscapeInByte => EscapeError::UnicodeEscapeInByte,
124        EscapeError::NonAsciiCharInByte => EscapeError::NonAsciiCharInByte,
125        EscapeError::NonAsciiCharInByteString => EscapeError::NonAsciiCharInByteString,
126    }
127}
128
129impl std::fmt::Display for UnescapeError {
130    // PANIC SAFETY By invariant, the range will always be within the bounds of `input`
131    #[allow(clippy::indexing_slicing)]
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        write!(
134            f,
135            "the input `{}` is not a valid escape",
136            &self.input[self.range.clone()],
137        )
138    }
139}
140
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use super::{to_pattern, to_unescaped_string};
    use crate::ast;
    use crate::parser::err::{ParseError, ToASTErrorKind};
    use crate::parser::text_to_cst;

    #[test]
    fn test_string_escape() {
        // Escapes are decoded by `rustc_lexer`, i.e. with Rust's string-literal
        // escape rules: https://doc.rust-lang.org/reference/tokens.html

        // valid ASCII escapes
        assert_eq!(
            to_unescaped_string(r"\t\r\n\\\0\x42").expect("valid string"),
            "\t\r\n\\\0\x42"
        );

        // invalid ASCII escapes
        let errs = to_unescaped_string(r"abc\xFFdef").expect_err("should be an invalid escape");
        assert_eq!(errs.len(), 1);

        // valid unicode escapes
        assert_eq!(
            to_unescaped_string(r"\u{0}\u{1}\u{1234}\u{12345}\u{054321}\u{123}\u{42}",)
                .expect("valid string"),
            "\u{000000}\u{001}\u{001234}\u{012345}\u{054321}\u{123}\u{00042}"
        );

        // invalid unicode escapes
        let errs = to_unescaped_string(r"abc\u{1111111}\u{222222222}FFdef")
            .expect_err("should be invalid escapes");
        assert_eq!(errs.len(), 2);

        // invalid escapes
        let errs = to_unescaped_string(r"abc\*\bdef").expect_err("should be invalid escapes");
        assert_eq!(errs.len(), 2);
    }

    #[test]
    fn test_to_pattern_direct() {
        // an unescaped `*` is a wildcard, while `\*` is a literal star
        assert!(matches!(
            to_pattern(r"a*\*").expect("valid pattern").as_slice(),
            [
                ast::PatternElem::Char('a'),
                ast::PatternElem::Wildcard,
                ast::PatternElem::Char('*'),
            ]
        ));

        // `\*` is the only extra escape patterns accept; other invalid escapes
        // are still reported, one error each
        let errs = to_pattern(r"\*\d\b").expect_err("should be invalid escapes");
        assert_eq!(errs.len(), 2);
    }

    #[test]
    fn test_unescape_error_display() {
        let errs = to_unescaped_string(r"abc\*def").expect_err("should be an invalid escape");
        assert_eq!(errs.len(), 1);
        assert_eq!(
            errs.head.to_string(),
            r"the input `\*` is not a valid escape"
        );
    }

    // PANIC SAFETY: testing
    #[allow(clippy::indexing_slicing)]
    #[test]
    fn test_pattern_escape() {
        // valid ASCII escapes
        assert!(
            matches!(text_to_cst::parse_expr(r#""aa" like "\t\r\n\\\0\x42\*""#)
            .expect("failed parsing")
            .to_expr::<ast::ExprBuilder<()>>()
            .expect("failed conversion").expr_kind(),
            ast::ExprKind::Like {
                expr: _,
                pattern,
            } if
                pattern.to_string() ==
                format!("{}{}", "\t\r\n\\\0\x42".escape_debug(), r"\*")
            )
        );

        // invalid ASCII escapes
        let errs = text_to_cst::parse_expr(r#""abc" like "abc\xFF\xFEdef""#)
            .expect("failed parsing")
            .to_expr::<ast::ExprBuilder<()>>()
            .unwrap_err();
        assert_eq!(errs.len(), 2);
        assert_matches!(&errs[0], ParseError::ToAST(e) => assert_matches!(e.kind(), ToASTErrorKind::Unescape(_)));
        assert_matches!(&errs[1], ParseError::ToAST(e) => assert_matches!(e.kind(), ToASTErrorKind::Unescape(_)));

        // valid `\*` surrounded by chars
        assert!(
            matches!(text_to_cst::parse_expr(r#""aaa" like "👀👀\*🤞🤞\*🤝""#)
            .expect("failed parsing")
            .to_expr::<ast::ExprBuilder<()>>()
            .expect("failed conversion").expr_kind(),
            ast::ExprKind::Like { expr: _, pattern} if pattern.to_string() == *r"👀👀\*🤞🤞\*🤝")
        );

        // invalid escapes
        let errs = text_to_cst::parse_expr(r#""aaa" like "abc\d\bdef""#)
            .expect("failed parsing")
            .to_expr::<ast::ExprBuilder<()>>()
            .unwrap_err();
        assert_eq!(errs.len(), 2);
        assert_matches!(&errs[0], ParseError::ToAST(e) => assert_matches!(e.kind(), ToASTErrorKind::Unescape(_)));
        assert_matches!(&errs[1], ParseError::ToAST(e) => assert_matches!(e.kind(), ToASTErrorKind::Unescape(_)));
    }
}