cedar_policy_core/ast/
id.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use serde::{Deserialize, Deserializer, Serialize};
18use smol_str::SmolStr;
19
20use crate::{parser::err::ParseErrors, FromNormalizedStr};
21
22use super::{InternalName, ReservedNameError};
23
24const RESERVED_ID: &str = "__cedar";
25
26/// Identifiers. Anything in `Id` should be a valid identifier, this means it
27/// does not contain, for instance, spaces or characters like '+'; and also is
28/// not one of the Cedar reserved identifiers (at time of writing,
29/// `true | false | if | then | else | in | is | like | has`).
30//
31// For now, internally, `Id`s are just owned `SmolString`s.
32#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
33pub struct Id(SmolStr);
34
35impl Id {
36    /// Create a new `Id` from a `String`, where it is the caller's
37    /// responsibility to ensure that the string is indeed a valid identifier.
38    ///
39    /// When possible, callers should not use this, and instead use `s.parse()`,
40    /// which checks that `s` is a valid identifier, and returns a parse error
41    /// if not.
42    ///
43    /// This method was created for the `From<cst::Ident> for Id` impl to use.
44    /// Since `parser::parse_ident()` implicitly uses that `From` impl itself,
45    /// if we tried to make that `From` impl go through `.parse()` like everyone
46    /// else, we'd get infinite recursion.  And, we assert that `cst::Ident` is
47    /// always already checked to contain a valid identifier, otherwise it would
48    /// never have been created.
49    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> Id {
50        Id(s.into())
51    }
52
53    /// Get the underlying string
54    pub fn into_smolstr(self) -> SmolStr {
55        self.0
56    }
57
58    /// Return if the `Id` is reserved (i.e., `__cedar`)
59    /// Note that it does not test if the `Id` string is a reserved keyword
60    /// as the parser already ensures that it is not
61    pub fn is_reserved(&self) -> bool {
62        self.as_ref() == RESERVED_ID
63    }
64}
65
66impl AsRef<str> for Id {
67    fn as_ref(&self) -> &str {
68        &self.0
69    }
70}
71
72impl std::fmt::Display for Id {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        write!(f, "{}", &self.0)
75    }
76}
77
78// allow `.parse()` on a string to make an `Id`
79impl std::str::FromStr for Id {
80    type Err = ParseErrors;
81
82    fn from_str(s: &str) -> Result<Self, Self::Err> {
83        crate::parser::parse_ident(s)
84    }
85}
86
87impl FromNormalizedStr for Id {
88    fn describe_self() -> &'static str {
89        "Id"
90    }
91}
92
93/// An `Id` that is not equal to `__cedar`, as specified by RFC 52
94#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
95#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
96#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
97pub struct UnreservedId(#[cfg_attr(feature = "wasm", tsify(type = "string"))] pub(crate) Id);
98
99impl From<UnreservedId> for Id {
100    fn from(value: UnreservedId) -> Self {
101        value.0
102    }
103}
104
105impl TryFrom<Id> for UnreservedId {
106    type Error = ReservedNameError;
107    fn try_from(value: Id) -> Result<Self, Self::Error> {
108        if value.is_reserved() {
109            Err(ReservedNameError(InternalName::unqualified_name(value)))
110        } else {
111            Ok(Self(value))
112        }
113    }
114}
115
116impl AsRef<Id> for UnreservedId {
117    fn as_ref(&self) -> &Id {
118        &self.0
119    }
120}
121
122impl AsRef<str> for UnreservedId {
123    fn as_ref(&self) -> &str {
124        self.0.as_ref()
125    }
126}
127
128impl std::fmt::Display for UnreservedId {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        self.0.fmt(f)
131    }
132}
133
134impl std::str::FromStr for UnreservedId {
135    type Err = ParseErrors;
136    fn from_str(s: &str) -> Result<Self, Self::Err> {
137        Id::from_str(s).and_then(|id| id.try_into().map_err(ParseErrors::singleton))
138    }
139}
140
141impl FromNormalizedStr for UnreservedId {
142    fn describe_self() -> &'static str {
143        "Unreserved Id"
144    }
145}
146
147impl UnreservedId {
148    /// Create an [`UnreservedId`] from an empty string
149    pub(crate) fn empty() -> Self {
150        // PANIC SAFETY: "" does not contain `__cedar`
151        #[allow(clippy::unwrap_used)]
152        Id("".into()).try_into().unwrap()
153    }
154}
155
156struct IdVisitor;
157
158impl serde::de::Visitor<'_> for IdVisitor {
159    type Value = Id;
160
161    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        formatter.write_str("a valid id")
163    }
164
165    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
166    where
167        E: serde::de::Error,
168    {
169        Id::from_normalized_str(value)
170            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
171    }
172}
173
174/// Deserialize an `Id` using `from_normalized_str`.
175/// This deserialization implementation is used in the JSON schema format.
176impl<'de> Deserialize<'de> for Id {
177    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
178    where
179        D: Deserializer<'de>,
180    {
181        deserializer.deserialize_str(IdVisitor)
182    }
183}
184
185/// Deserialize a [`UnreservedId`] using `from_normalized_str`
186/// This deserialization implementation is used in the JSON schema format.
187impl<'de> Deserialize<'de> for UnreservedId {
188    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
189    where
190        D: Deserializer<'de>,
191    {
192        deserializer
193            .deserialize_str(IdVisitor)
194            .and_then(|n| n.try_into().map_err(serde::de::Error::custom))
195    }
196}
197
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Id {
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Identifier grammar:
        //   IDENT    := ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']* - RESERVED
        //   BOOL     := 'true' | 'false'
        //   RESERVED := BOOL | 'if' | 'then' | 'else' | 'in' | 'is' | 'like' | 'has'

        // Characters allowed as the first character of an identifier.
        let head_chars: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_"
            .chars()
            .collect();
        // Characters allowed after the first one: digits, then every head character.
        let tail_chars: Vec<char> = "0123456789"
            .chars()
            .chain(head_chars.iter().copied())
            .collect();

        // Number of characters that follow the first one.
        let tail_len = u.int_in_range(0..=16)?;
        let mut ident = String::new();
        ident.push(*u.choose(&head_chars)?);
        for _ in 0..tail_len {
            ident.push(*u.choose(&tail_chars)?);
        }

        // Given the character sets above, the parse should only fail when the
        // string spells a reserved keyword. Appending `_` makes it a valid Id.
        if crate::parser::parse_ident(&ident).is_err() {
            ident.push('_');
        }
        Ok(Self::new_unchecked(ident))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // One integer for the length, plus a sequence of choices. The choices
        // are approximated from below by the size hint of a `Vec<u8>`.
        arbitrary::size_hint::and_all(&[
            <usize as arbitrary::Arbitrary>::size_hint(depth),
            <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth),
        ])
    }
}
239
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for UnreservedId {
    /// Generates an arbitrary [`Id`]. If it happens to be the reserved name
    /// `__cedar`, a `_` is prepended, giving `___cedar`, which is a valid
    /// unreserved identifier.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let id: Id = u.arbitrary()?;
        if id.is_reserved() {
            // `___cedar` is a valid identifier and is not `__cedar`. Building it
            // directly avoids re-parsing it and unwrapping the result.
            Ok(Self(Id::new_unchecked(format!("_{id}"))))
        } else {
            // Checking `is_reserved` first means the generated Id is moved in
            // without the clone that a speculative `try_from` would need.
            Ok(Self(id))
        }
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        <Id as arbitrary::Arbitrary>::size_hint(depth)
    }
}
259
260/// Like `Id`, except this specifically _can_ contain Cedar reserved identifiers.
261/// (It still can't contain, for instance, spaces or characters like '+'.)
262//
263// For now, internally, `AnyId`s are just owned `SmolString`s.
264#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
265#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
266#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
267pub struct AnyId(SmolStr);
268
269impl AnyId {
270    /// Create a new `AnyId` from a `String`, where it is the caller's
271    /// responsibility to ensure that the string is indeed a valid `AnyId`.
272    ///
273    /// When possible, callers should not use this, and instead use `s.parse()`,
274    /// which checks that `s` is a valid `AnyId`, and returns a parse error
275    /// if not.
276    ///
277    /// This method was created for the `From<cst::Ident> for AnyId` impl to use.
278    /// See notes on `Id::new_unchecked()`.
279    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> AnyId {
280        AnyId(s.into())
281    }
282
283    /// Get the underlying string
284    pub fn into_smolstr(self) -> SmolStr {
285        self.0
286    }
287}
288
289struct AnyIdVisitor;
290
291impl serde::de::Visitor<'_> for AnyIdVisitor {
292    type Value = AnyId;
293
294    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        formatter.write_str("any id")
296    }
297
298    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
299    where
300        E: serde::de::Error,
301    {
302        AnyId::from_normalized_str(value)
303            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
304    }
305}
306
307/// Deserialize an `AnyId` using `from_normalized_str`.
308/// This deserialization implementation is used in the JSON policy format.
309impl<'de> Deserialize<'de> for AnyId {
310    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
311    where
312        D: Deserializer<'de>,
313    {
314        deserializer.deserialize_str(AnyIdVisitor)
315    }
316}
317
318impl AsRef<str> for AnyId {
319    fn as_ref(&self) -> &str {
320        &self.0
321    }
322}
323
324impl std::fmt::Display for AnyId {
325    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        write!(f, "{}", &self.0)
327    }
328}
329
330// allow `.parse()` on a string to make an `AnyId`
331impl std::str::FromStr for AnyId {
332    type Err = ParseErrors;
333
334    fn from_str(s: &str) -> Result<Self, Self::Err> {
335        crate::parser::parse_anyid(s)
336    }
337}
338
339impl FromNormalizedStr for AnyId {
340    fn describe_self() -> &'static str {
341        "AnyId"
342    }
343}
344
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for AnyId {
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // AnyId grammar (reserved keywords are permitted):
        //   ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']*

        // Characters allowed as the first character of an AnyId.
        let head_chars: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_"
            .chars()
            .collect();
        // Characters allowed after the first one: digits, then every head character.
        let tail_chars: Vec<char> = "0123456789"
            .chars()
            .chain(head_chars.iter().copied())
            .collect();

        // Number of characters that follow the first one.
        let tail_len = u.int_in_range(0..=16)?;
        let mut s = String::new();
        s.push(*u.choose(&head_chars)?);
        for _ in 0..tail_len {
            s.push(*u.choose(&tail_chars)?);
        }

        debug_assert!(
            crate::parser::parse_anyid(&s).is_ok(),
            "all strings constructed this way should be valid AnyIds, but this one is not: {s:?}"
        );
        Ok(Self::new_unchecked(s))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // One integer for the length, plus a sequence of choices. The choices
        // are approximated from below by the size hint of a `Vec<u8>`.
        arbitrary::size_hint::and_all(&[
            <usize as arbitrary::Arbitrary>::size_hint(depth),
            <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth),
        ])
    }
}
383
// PANIC SAFETY: unit-test code
#[allow(clippy::panic)]
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn normalized_id() {
        Id::from_normalized_str("foo").expect("should be OK");
        Id::from_normalized_str("foo::bar").expect_err("shouldn't be OK");
        Id::from_normalized_str(r#"foo::"bar""#).expect_err("shouldn't be OK");
        Id::from_normalized_str(" foo").expect_err("shouldn't be OK");
        Id::from_normalized_str("foo ").expect_err("shouldn't be OK");
        Id::from_normalized_str("foo\n").expect_err("shouldn't be OK");
        Id::from_normalized_str("foo//comment").expect_err("shouldn't be OK");
    }

    /// Reserved keywords are never valid `Id`s.
    #[test]
    fn keyword_is_not_an_id() {
        Id::from_normalized_str("if").expect_err("shouldn't be OK");
        Id::from_normalized_str("true").expect_err("shouldn't be OK");
    }

    /// `UnreservedId` follows the `Id` rules and additionally rejects `__cedar`.
    #[test]
    fn normalized_unreserved_id() {
        UnreservedId::from_normalized_str("foo").expect("should be OK");
        UnreservedId::from_normalized_str("__cedar").expect_err("shouldn't be OK");
        UnreservedId::from_normalized_str(" foo").expect_err("shouldn't be OK");
        UnreservedId::from_normalized_str("foo::bar").expect_err("shouldn't be OK");
    }

    /// Unlike `Id`, `AnyId` accepts reserved keywords, but it is still subject
    /// to normalization.
    #[test]
    fn normalized_any_id() {
        AnyId::from_normalized_str("foo").expect("should be OK");
        AnyId::from_normalized_str("if").expect("should be OK");
        AnyId::from_normalized_str(" foo").expect_err("shouldn't be OK");
        AnyId::from_normalized_str("foo ").expect_err("shouldn't be OK");
        AnyId::from_normalized_str("foo::bar").expect_err("shouldn't be OK");
    }

    #[test]
    fn empty_unreserved_id() {
        let empty = UnreservedId::empty();
        assert_eq!(AsRef::<str>::as_ref(&empty), "");
    }
}