cedar_policy_core/ast/
pattern.rs1use std::sync::Arc;
18
19use serde::{Deserialize, Serialize};
20
21#[derive(Deserialize, Serialize, Hash, Debug, Clone, Copy, PartialEq, Eq)]
23#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
24pub enum PatternElem {
25 Char(char),
27 Wildcard,
29}
30
31#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
34#[serde(transparent)]
35pub struct Pattern {
36 elems: Arc<Vec<PatternElem>>,
38}
39
40impl Pattern {
41 fn new(elems: Arc<Vec<PatternElem>>) -> Self {
43 Self { elems }
44 }
45
46 pub fn get_elems(&self) -> &[PatternElem] {
48 &self.elems
49 }
50
51 pub fn iter(&self) -> impl Iterator<Item = &PatternElem> {
53 self.elems.iter()
54 }
55
56 pub fn len(&self) -> usize {
58 self.elems.len()
59 }
60}
61
62impl From<Arc<Vec<PatternElem>>> for Pattern {
63 fn from(value: Arc<Vec<PatternElem>>) -> Self {
64 Self::new(value)
65 }
66}
67
68impl From<Vec<PatternElem>> for Pattern {
69 fn from(value: Vec<PatternElem>) -> Self {
70 Self::new(Arc::new(value))
71 }
72}
73
74impl FromIterator<PatternElem> for Pattern {
75 fn from_iter<T: IntoIterator<Item = PatternElem>>(iter: T) -> Self {
76 Self::new(Arc::new(iter.into_iter().collect()))
77 }
78}
79
80impl std::fmt::Display for Pattern {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 for pc in self.elems.as_ref() {
83 match pc {
84 PatternElem::Char('*') => write!(f, r#"\*"#)?,
85 PatternElem::Char(c) => write!(f, "{}", c.escape_debug())?,
86 PatternElem::Wildcard => write!(f, r#"*"#)?,
87 }
88 }
89 Ok(())
90 }
91}
92
93impl PatternElem {
94 fn match_char(&self, text_char: &char) -> bool {
95 match self {
96 PatternElem::Char(c) => text_char == c,
97 PatternElem::Wildcard => true,
98 }
99 }
100 fn is_wildcard(&self) -> bool {
101 matches!(self, PatternElem::Wildcard)
102 }
103}
104
105impl Pattern {
106 pub fn wildcard_match(&self, text: &str) -> bool {
108 let pattern = self.get_elems();
109 if pattern.is_empty() {
110 return text.is_empty();
111 }
112
113 let text: Vec<char> = text.chars().collect();
122
123 let mut i: usize = 0; let mut j: usize = 0; let mut star_idx: usize = 0; let mut tmp_idx: usize = 0; let mut contains_star: bool = false; let text_len = text.len();
130 let pattern_len = pattern.len();
131
132 while i < text_len && (!contains_star || star_idx != pattern_len - 1) {
133 #[allow(clippy::indexing_slicing)]
135 if j < pattern_len && pattern[j].is_wildcard() {
136 contains_star = true;
137 star_idx = j;
138 tmp_idx = i;
139 j += 1;
140 } else if j < pattern_len && pattern[j].match_char(&text[i]) {
141 i += 1;
142 j += 1;
143 } else if contains_star {
144 j = star_idx + 1;
145 i = tmp_idx + 1;
146 tmp_idx = i;
147 } else {
148 return false;
149 }
150 }
151
152 #[allow(clippy::indexing_slicing)]
154 while j < pattern_len && pattern[j].is_wildcard() {
155 j += 1;
156 }
157
158 j == pattern_len
159 }
160}
161
#[cfg(test)]
mod test {
    use super::*;

    /// Concatenation, so test patterns can be assembled from small pieces.
    impl std::ops::Add for Pattern {
        type Output = Pattern;
        fn add(self, rhs: Self) -> Self::Output {
            let mut joined: Vec<PatternElem> = self.get_elems().to_vec();
            joined.extend_from_slice(rhs.get_elems());
            Pattern::from(joined)
        }
    }

    /// A pattern matching `text` literally, one `PatternElem::Char` per `char`.
    fn string_map(text: &str) -> Pattern {
        text.chars().map(PatternElem::Char).collect::<Pattern>()
    }

    /// A pattern consisting of a single wildcard.
    fn star() -> Pattern {
        vec![PatternElem::Wildcard].into()
    }

    /// A pattern with no elements.
    fn empty() -> Pattern {
        Vec::<PatternElem>::new().into()
    }

    /// Asserts that every `(pattern, text, expected)` case produces `expected`.
    fn check_cases(cases: Vec<(Pattern, &str, bool)>) {
        for (pattern, text, expected) in cases {
            assert_eq!(
                pattern.wildcard_match(text),
                expected,
                "pattern `{}` matched against {:?}",
                pattern,
                text
            );
        }
    }

    #[test]
    fn test_wildcard_match_basic() {
        check_cases(vec![
            // Patterns that should match.
            (string_map("foo") + star(), "foo bar", true),
            (star() + string_map("bar"), "foo bar", true),
            (star() + string_map("o b") + star(), "foo bar", true),
            (string_map("f") + star() + string_map(" bar"), "foo bar", true),
            (string_map("f") + star() + star() + string_map("r"), "foo bar", true),
            (star() + string_map("f") + star() + star() + star(), "foo bar", true),
            // Patterns that should not match.
            (star() + string_map("foo"), "foo bar", false),
            (string_map("bar") + star(), "foo bar", false),
            (star() + string_map("bo") + star(), "foo bar", false),
            (string_map("f") + star() + string_map("br"), "foo bar", false),
            (star() + string_map("x") + star() + star() + star(), "foo bar", false),
            (empty(), "foo bar", false),
            // Empty text.
            (empty(), "", true),
            (star(), "", true),
            (string_map("foo bar"), "", false),
            // Literal star vs. wildcard.
            (string_map("*"), "*", true),
            (star(), "*", true),
            (string_map("\u{0000}"), "*", false),
            (string_map(r"\u{0000}"), "*", false),
        ]);
    }

    #[test]
    fn test_wildcard_match_unicode() {
        let y_breve = "y̆";
        let zalgo = "ḛ̶͑͝x̶͔͛a̵̰̯͛m̴͉̋́p̷̠͂l̵͇̍̔ȩ̶̣͝";
        check_cases(vec![
            (string_map("y") + star(), y_breve, true),
            (string_map(y_breve), y_breve, true),
            (star() + string_map("p") + star(), y_breve, false),
            (star() + string_map("p") + star(), zalgo, true),
            (star() + string_map("a̵̰̯͛m̴͉̋́") + star(), zalgo, true),
            (string_map("y") + star(), zalgo, false),
        ]);
    }
}