cedar_policy_core/ast/pattern.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use std::sync::Arc;
18
19use serde::{Deserialize, Serialize};
20
21/// Represent an element in a pattern literal (the RHS of the like operation)
22#[derive(Deserialize, Serialize, Hash, Debug, Clone, Copy, PartialEq, Eq)]
23#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
24pub enum PatternElem {
25    /// A character literal
26    Char(char),
27    /// The wildcard `*`
28    Wildcard,
29}
30
31/// Represent a pattern literal (the RHS of the like operator)
32/// Also provides an implementation of the Display trait as well as a wildcard matching method.
33#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
34#[serde(transparent)]
35pub struct Pattern {
36    /// A vector of pattern elements
37    elems: Arc<Vec<PatternElem>>,
38}
39
40impl Pattern {
41    /// Explicitly create a pattern literal out of a shared vector of pattern elements
42    fn new(elems: Arc<Vec<PatternElem>>) -> Self {
43        Self { elems }
44    }
45
46    /// Getter to the wrapped vector
47    pub fn get_elems(&self) -> &[PatternElem] {
48        &self.elems
49    }
50
51    /// Iterate over pattern elements
52    pub fn iter(&self) -> impl Iterator<Item = &PatternElem> {
53        self.elems.iter()
54    }
55
56    /// Length of elems vector
57    pub fn len(&self) -> usize {
58        self.elems.len()
59    }
60}
61
62impl From<Arc<Vec<PatternElem>>> for Pattern {
63    fn from(value: Arc<Vec<PatternElem>>) -> Self {
64        Self::new(value)
65    }
66}
67
68impl From<Vec<PatternElem>> for Pattern {
69    fn from(value: Vec<PatternElem>) -> Self {
70        Self::new(Arc::new(value))
71    }
72}
73
74impl FromIterator<PatternElem> for Pattern {
75    fn from_iter<T: IntoIterator<Item = PatternElem>>(iter: T) -> Self {
76        Self::new(Arc::new(iter.into_iter().collect()))
77    }
78}
79
80impl std::fmt::Display for Pattern {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        for pc in self.elems.as_ref() {
83            match pc {
84                PatternElem::Char('*') => write!(f, r#"\*"#)?,
85                PatternElem::Char(c) => write!(f, "{}", c.escape_debug())?,
86                PatternElem::Wildcard => write!(f, r#"*"#)?,
87            }
88        }
89        Ok(())
90    }
91}
92
93impl PatternElem {
94    fn match_char(&self, text_char: &char) -> bool {
95        match self {
96            PatternElem::Char(c) => text_char == c,
97            PatternElem::Wildcard => true,
98        }
99    }
100    fn is_wildcard(&self) -> bool {
101        matches!(self, PatternElem::Wildcard)
102    }
103}
104
105impl Pattern {
106    /// Find if the argument text matches the pattern
107    pub fn wildcard_match(&self, text: &str) -> bool {
108        let pattern = self.get_elems();
109        if pattern.is_empty() {
110            return text.is_empty();
111        }
112
113        // Copying the strings into vectors requires extra space, but has two benefits:
114        // 1. It makes accessing elements more efficient. The alternative (i.e.,
115        //    chars().nth()) needs to re-scan the string for each invocation. Note
116        //    that a simple iterator will not work here since we move both forward
117        //    and backward through the string.
118        // 2. It provides an unambiguous length. In general for a string s,
119        //    s.len() is not the same as s.chars().count(). The length of these
120        //    created vectors will match .chars().count()
121        let text: Vec<char> = text.chars().collect();
122
123        let mut i: usize = 0; // index into text
124        let mut j: usize = 0; // index into pattern
125        let mut star_idx: usize = 0; // index in pattern (j) of the most recent *
126        let mut tmp_idx: usize = 0; // index in text (i) of the most recent *
127        let mut contains_star: bool = false; // does the pattern contain *?
128
129        let text_len = text.len();
130        let pattern_len = pattern.len();
131
132        while i < text_len && (!contains_star || star_idx != pattern_len - 1) {
133            // PANIC SAFETY `j` is checked to be less than length
134            #[allow(clippy::indexing_slicing)]
135            if j < pattern_len && pattern[j].is_wildcard() {
136                contains_star = true;
137                star_idx = j;
138                tmp_idx = i;
139                j += 1;
140            } else if j < pattern_len && pattern[j].match_char(&text[i]) {
141                i += 1;
142                j += 1;
143            } else if contains_star {
144                j = star_idx + 1;
145                i = tmp_idx + 1;
146                tmp_idx = i;
147            } else {
148                return false;
149            }
150        }
151
152        // PANIC SAFETY `j` is checked to be less than length
153        #[allow(clippy::indexing_slicing)]
154        while j < pattern_len && pattern[j].is_wildcard() {
155            j += 1;
156        }
157
158        j == pattern_len
159    }
160}
161
#[cfg(test)]
mod test {
    use super::*;

    // Test-only concatenation of patterns, so test patterns can be assembled
    // from literal pieces and wildcards.
    impl std::ops::Add for Pattern {
        type Output = Pattern;
        fn add(self, rhs: Self) -> Self::Output {
            let elems = [self.get_elems(), rhs.get_elems()].concat();
            Pattern::from(elems)
        }
    }

    // Map a string into a pattern literal with `PatternElem::Char`
    fn string_map(text: &str) -> Pattern {
        text.chars().map(PatternElem::Char).collect()
    }

    // Create a star pattern literal
    fn star() -> Pattern {
        Pattern::from(vec![PatternElem::Wildcard])
    }

    // Create an empty pattern literal
    fn empty() -> Pattern {
        Pattern::from(vec![])
    }

    #[test]
    fn test_wildcard_match_basic() {
        // Patterns that match "foo bar"
        assert!((string_map("foo") + star()).wildcard_match("foo bar"));
        assert!((star() + string_map("bar")).wildcard_match("foo bar"));
        assert!((star() + string_map("o b") + star()).wildcard_match("foo bar"));
        assert!((string_map("f") + star() + string_map(" bar")).wildcard_match("foo bar"));
        assert!((string_map("f") + star() + star() + string_map("r")).wildcard_match("foo bar"));
        assert!((star() + string_map("f") + star() + star() + star()).wildcard_match("foo bar"));

        // Patterns that do not match "foo bar"
        assert!(!(star() + string_map("foo")).wildcard_match("foo bar"));
        assert!(!(string_map("bar") + star()).wildcard_match("foo bar"));
        assert!(!(star() + string_map("bo") + star()).wildcard_match("foo bar"));
        assert!(!(string_map("f") + star() + string_map("br")).wildcard_match("foo bar"));
        assert!(!(star() + string_map("x") + star() + star() + star()).wildcard_match("foo bar"));
        assert!(!empty().wildcard_match("foo bar"));

        // Patterns that match ""
        assert!(empty().wildcard_match(""));
        assert!(star().wildcard_match(""));

        // Patterns that do not match ""
        assert!(!string_map("foo bar").wildcard_match(""));

        // Patterns that match "*"
        assert!(string_map("*").wildcard_match("*"));
        assert!(star().wildcard_match("*"));

        // Patterns that do not match "*"
        assert!(!string_map("\u{0000}").wildcard_match("*"));
        assert!(!string_map(r"\u{0000}").wildcard_match("*"));
    }

    #[test]
    fn test_wildcard_match_unicode() {
        // Patterns that match "y̆"
        assert!((string_map("y") + star()).wildcard_match("y̆"));
        assert!(string_map("y̆").wildcard_match("y̆"));

        // Patterns that do not match "y̆"
        assert!(!(star() + string_map("p") + star()).wildcard_match("y̆"));

        // Patterns that match "ḛ̶͑͝x̶͔͛a̵̰̯͛m̴͉̋́p̷̠͂l̵͇̍̔ȩ̶̣͝"
        assert!((star() + string_map("p") + star()).wildcard_match("ḛ̶͑͝x̶͔͛a̵̰̯͛m̴͉̋́p̷̠͂l̵͇̍̔ȩ̶̣͝"));
        assert!((star() + string_map("a̵̰̯͛m̴͉̋́") + star()).wildcard_match("ḛ̶͑͝x̶͔͛a̵̰̯͛m̴͉̋́p̷̠͂l̵͇̍̔ȩ̶̣͝"));

        // Patterns that do not match "ḛ̶͑͝x̶͔͛a̵̰̯͛m̴͉̋́p̷̠͂l̵͇̍̔ȩ̶̣͝"
        assert!(!(string_map("y") + star()).wildcard_match("ḛ̶͑͝x̶͔͛a̵̰̯͛m̴͉̋́p̷̠͂l̵͇̍̔ȩ̶̣͝"));
    }

    #[test]
    fn test_wildcard_match_backtracking() {
        // After a partial match fails, the wildcard must be re-expanded, possibly
        // several times
        assert!((string_map("a") + star() + string_map("b")).wildcard_match("aXbYb"));
        assert!((star() + string_map("ab")).wildcard_match("aab"));
        assert!(!(string_map("a") + star() + string_map("b")).wildcard_match("aXbY"));

        // A literal `*` only matches a `*` in the text, never arbitrary text
        assert!(!string_map("*").wildcard_match("a"));
        assert!((star() + string_map("*")).wildcard_match("a*"));
        assert!(!(star() + string_map("*")).wildcard_match("a*b"));
    }

    #[test]
    fn test_display() {
        assert_eq!(empty().to_string(), "");
        assert_eq!(star().to_string(), "*");
        assert_eq!(
            (string_map("foo") + star() + string_map("bar")).to_string(),
            "foo*bar"
        );

        // A literal `*` is escaped so it cannot be confused with the wildcard
        assert_eq!(string_map("*").to_string(), r"\*");
        assert_eq!((star() + string_map("*") + star()).to_string(), r"*\**");

        // Other characters are escaped like `char::escape_debug`
        assert_eq!(string_map("\n").to_string(), r"\n");
        assert_eq!(string_map("\\").to_string(), r"\\");
    }
}