cedar_policy_core/ast/
request.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::entities::json::{
18    ContextJsonDeserializationError, ContextJsonParser, NullContextSchema,
19};
20use crate::evaluator::{EvaluationError, RestrictedEvaluator};
21use crate::extensions::Extensions;
22use crate::parser::Loc;
23use miette::Diagnostic;
24use serde::{Deserialize, Serialize};
25use smol_str::SmolStr;
26use std::collections::{BTreeMap, HashMap};
27use std::sync::Arc;
28use thiserror::Error;
29
30use super::{
31    BorrowedRestrictedExpr, BoundedDisplay, EntityType, EntityUID, Expr, ExprKind,
32    ExpressionConstructionError, PartialValue, RestrictedExpr, Unknown, Value, ValueKind, Var,
33};
34
35/// Represents the request tuple <P, A, R, C> (see the Cedar design doc).
36#[derive(Debug, Clone, Serialize)]
37pub struct Request {
38    /// Principal associated with the request
39    pub(crate) principal: EntityUIDEntry,
40
41    /// Action associated with the request
42    pub(crate) action: EntityUIDEntry,
43
44    /// Resource associated with the request
45    pub(crate) resource: EntityUIDEntry,
46
47    /// Context associated with the request.
48    /// `None` means that variable will result in a residual for partial evaluation.
49    pub(crate) context: Option<Context>,
50}
51
52/// Represents the principal type, resource type, and action UID.
53#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
54#[serde(rename_all = "camelCase")]
55pub struct RequestType {
56    /// Principal type
57    pub principal: EntityType,
58    /// Action type
59    pub action: EntityUID,
60    /// Resource type
61    pub resource: EntityType,
62}
63
64/// An entry in a request for a Entity UID.
65/// It may either be a concrete EUID
66/// or an unknown in the case of partial evaluation
67#[derive(Debug, Clone, Serialize)]
68pub enum EntityUIDEntry {
69    /// A concrete EntityUID
70    Known {
71        /// The concrete `EntityUID`
72        euid: Arc<EntityUID>,
73        /// Source location associated with the `EntityUIDEntry`, if any
74        loc: Option<Loc>,
75    },
76    /// An EntityUID left as unknown for partial evaluation
77    Unknown {
78        /// The type of the unknown EntityUID, if known.
79        ty: Option<EntityType>,
80
81        /// Source location associated with the `EntityUIDEntry`, if any
82        loc: Option<Loc>,
83    },
84}
85
86impl EntityUIDEntry {
87    /// Evaluate the entry to either:
88    /// A value, if the entry is concrete
89    /// An unknown corresponding to the passed `var`
90    pub fn evaluate(&self, var: Var) -> PartialValue {
91        match self {
92            EntityUIDEntry::Known { euid, loc } => {
93                Value::new(Arc::unwrap_or_clone(Arc::clone(euid)), loc.clone()).into()
94            }
95            EntityUIDEntry::Unknown { ty: None, loc } => {
96                Expr::unknown(Unknown::new_untyped(var.to_string()))
97                    .with_maybe_source_loc(loc.clone())
98                    .into()
99            }
100            EntityUIDEntry::Unknown {
101                ty: Some(known_type),
102                loc,
103            } => Expr::unknown(Unknown::new_with_type(
104                var.to_string(),
105                super::Type::Entity {
106                    ty: known_type.clone(),
107                },
108            ))
109            .with_maybe_source_loc(loc.clone())
110            .into(),
111        }
112    }
113
114    /// Create an entry with a concrete EntityUID and the given source location
115    pub fn known(euid: EntityUID, loc: Option<Loc>) -> Self {
116        Self::Known {
117            euid: Arc::new(euid),
118            loc,
119        }
120    }
121
122    /// Create an entry with an entirely unknown EntityUID
123    pub fn unknown() -> Self {
124        Self::Unknown {
125            ty: None,
126            loc: None,
127        }
128    }
129
130    /// Create an entry with an unknown EntityUID but known EntityType
131    pub fn unknown_with_type(ty: EntityType, loc: Option<Loc>) -> Self {
132        Self::Unknown { ty: Some(ty), loc }
133    }
134
135    /// Get the UID of the entry, or `None` if it is unknown (partial evaluation)
136    pub fn uid(&self) -> Option<&EntityUID> {
137        match self {
138            Self::Known { euid, .. } => Some(euid),
139            Self::Unknown { .. } => None,
140        }
141    }
142
143    /// Get the type of the entry, or `None` if it is unknown (partial evaluation with no type annotation)
144    pub fn get_type(&self) -> Option<&EntityType> {
145        match self {
146            Self::Known { euid, .. } => Some(euid.entity_type()),
147            Self::Unknown { ty, .. } => ty.as_ref(),
148        }
149    }
150}
151
152impl Request {
153    /// Default constructor.
154    ///
155    /// If `schema` is provided, this constructor validates that this `Request`
156    /// complies with the given `schema`.
157    pub fn new<S: RequestSchema>(
158        principal: (EntityUID, Option<Loc>),
159        action: (EntityUID, Option<Loc>),
160        resource: (EntityUID, Option<Loc>),
161        context: Context,
162        schema: Option<&S>,
163        extensions: &Extensions<'_>,
164    ) -> Result<Self, S::Error> {
165        let req = Self {
166            principal: EntityUIDEntry::known(principal.0, principal.1),
167            action: EntityUIDEntry::known(action.0, action.1),
168            resource: EntityUIDEntry::known(resource.0, resource.1),
169            context: Some(context),
170        };
171        if let Some(schema) = schema {
172            schema.validate_request(&req, extensions)?;
173        }
174        Ok(req)
175    }
176
177    /// Create a new `Request` with potentially unknown (for partial eval) variables.
178    ///
179    /// If `schema` is provided, this constructor validates that this `Request`
180    /// complies with the given `schema` (at least to the extent that we can
181    /// validate with the given information)
182    pub fn new_with_unknowns<S: RequestSchema>(
183        principal: EntityUIDEntry,
184        action: EntityUIDEntry,
185        resource: EntityUIDEntry,
186        context: Option<Context>,
187        schema: Option<&S>,
188        extensions: &Extensions<'_>,
189    ) -> Result<Self, S::Error> {
190        let req = Self {
191            principal,
192            action,
193            resource,
194            context,
195        };
196        if let Some(schema) = schema {
197            schema.validate_request(&req, extensions)?;
198        }
199        Ok(req)
200    }
201
202    /// Create a new `Request` with potentially unknown (for partial eval) variables/context
203    /// and without schema validation.
204    pub fn new_unchecked(
205        principal: EntityUIDEntry,
206        action: EntityUIDEntry,
207        resource: EntityUIDEntry,
208        context: Option<Context>,
209    ) -> Self {
210        Self {
211            principal,
212            action,
213            resource,
214            context,
215        }
216    }
217
218    /// Get the principal associated with the request
219    pub fn principal(&self) -> &EntityUIDEntry {
220        &self.principal
221    }
222
223    /// Get the action associated with the request
224    pub fn action(&self) -> &EntityUIDEntry {
225        &self.action
226    }
227
228    /// Get the resource associated with the request
229    pub fn resource(&self) -> &EntityUIDEntry {
230        &self.resource
231    }
232
233    /// Get the context associated with the request
234    /// Returning `None` means the variable is unknown, and will result in a residual expression
235    pub fn context(&self) -> Option<&Context> {
236        self.context.as_ref()
237    }
238
239    /// Get the request types that correspond to this request.
240    /// This includes the types of the principal, action, and resource.
241    /// [`RequestType`] is used by the entity manifest.
242    /// The context type is implied by the action's type.
243    /// Returns `None` if the request is not fully concrete.
244    pub fn to_request_type(&self) -> Option<RequestType> {
245        Some(RequestType {
246            principal: self.principal().uid()?.entity_type().clone(),
247            action: self.action().uid()?.clone(),
248            resource: self.resource().uid()?.entity_type().clone(),
249        })
250    }
251}
252
253impl std::fmt::Display for Request {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        let display_euid = |maybe_euid: &EntityUIDEntry| match maybe_euid {
256            EntityUIDEntry::Known { euid, .. } => format!("{euid}"),
257            EntityUIDEntry::Unknown { ty: None, .. } => "unknown".to_string(),
258            EntityUIDEntry::Unknown {
259                ty: Some(known_type),
260                ..
261            } => format!("unknown of type {}", known_type),
262        };
263        write!(
264            f,
265            "request with principal {}, action {}, resource {}, and context {}",
266            display_euid(&self.principal),
267            display_euid(&self.action),
268            display_euid(&self.resource),
269            match &self.context {
270                Some(x) => format!("{x}"),
271                None => "unknown".to_string(),
272            }
273        )
274    }
275}
276
277/// `Context` field of a `Request`
278#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
279// Serialization is used for differential testing, which requires that `Context`
280// is serialized as a `RestrictedExpr`.
281#[serde(into = "RestrictedExpr")]
282pub enum Context {
283    /// The context is a concrete value.
284    Value(Arc<BTreeMap<SmolStr, Value>>),
285    /// The context is a residual expression, containing some unknown value in
286    /// the record attributes.
287    /// INVARIANT(restricted): Each `Expr` in this map must be a `RestrictedExpr`.
288    /// INVARIANT(unknown): At least one `Expr` must contain an `unknown`.
289    RestrictedResidual(Arc<BTreeMap<SmolStr, Expr>>),
290}
291
292impl Context {
293    /// Create an empty `Context`
294    pub fn empty() -> Self {
295        Self::Value(Arc::new(BTreeMap::new()))
296    }
297
298    /// Create a `Context` from a `PartialValue` without checking that the
299    /// residual is a restricted expression.  This function does check that the
300    /// value or residual is a record and returns `Err` when it is not.
301    ///
302    /// INVARIANT: if `value` is a residual, then it must be a valid restricted expression.
303    fn from_restricted_partial_val_unchecked(
304        value: PartialValue,
305    ) -> Result<Self, ContextCreationError> {
306        match value {
307            PartialValue::Value(v) => {
308                if let ValueKind::Record(attrs) = v.value {
309                    Ok(Context::Value(attrs))
310                } else {
311                    Err(ContextCreationError::not_a_record(v.into()))
312                }
313            }
314            PartialValue::Residual(e) => {
315                if let ExprKind::Record(attrs) = e.expr_kind() {
316                    // From the invariant on `PartialValue::Residual`, there is
317                    // an unknown in `e`. It is a record, so there must be an
318                    // unknown in one of the attributes expressions, satisfying
319                    // INVARIANT(unknown). From the invariant on this function,
320                    // `e` is a valid restricted expression, satisfying
321                    // INVARIANT(restricted).
322                    Ok(Context::RestrictedResidual(attrs.clone()))
323                } else {
324                    Err(ContextCreationError::not_a_record(e))
325                }
326            }
327        }
328    }
329
330    /// Create a `Context` from a `RestrictedExpr`, which must be a `Record`.
331    ///
332    /// `extensions` provides the `Extensions` which should be active for
333    /// evaluating the `RestrictedExpr`.
334    pub fn from_expr(
335        expr: BorrowedRestrictedExpr<'_>,
336        extensions: &Extensions<'_>,
337    ) -> Result<Self, ContextCreationError> {
338        match expr.expr_kind() {
339            ExprKind::Record { .. } => {
340                let evaluator = RestrictedEvaluator::new(extensions);
341                let pval = evaluator.partial_interpret(expr)?;
342                // The invariant on `from_restricted_partial_val_unchecked`
343                // is satisfied because `expr` is a restricted expression,
344                // and must still be restricted after `partial_interpret`.
345                // The function call cannot return `Err` because `expr` is a
346                // record, and partially evaluating a record expression will
347                // yield a record expression or a record value.
348                // PANIC SAFETY: See above
349                #[allow(clippy::expect_used)]
350                Ok(Self::from_restricted_partial_val_unchecked(pval).expect(
351                    "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
352                ))
353            }
354            _ => Err(ContextCreationError::not_a_record(expr.to_owned().into())),
355        }
356    }
357
358    /// Create a `Context` from a map of key to `RestrictedExpr`, or a Vec of
359    /// `(key, RestrictedExpr)` pairs, or any other iterator of `(key, RestrictedExpr)` pairs
360    ///
361    /// `extensions` provides the `Extensions` which should be active for
362    /// evaluating the `RestrictedExpr`.
363    pub fn from_pairs(
364        pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
365        extensions: &Extensions<'_>,
366    ) -> Result<Self, ContextCreationError> {
367        match RestrictedExpr::record(pairs) {
368            Ok(record) => Self::from_expr(record.as_borrowed(), extensions),
369            Err(ExpressionConstructionError::DuplicateKey(err)) => Err(
370                ExpressionConstructionError::DuplicateKey(err.with_context("in context")).into(),
371            ),
372        }
373    }
374
375    /// Create a `Context` from a string containing JSON (which must be a JSON
376    /// object, not any other JSON type, or you will get an error here).
377    /// JSON here must use the `__entity` and `__extn` escapes for entity
378    /// references, extension values, etc.
379    ///
380    /// For schema-based parsing, use `ContextJsonParser`.
381    pub fn from_json_str(json: &str) -> Result<Self, ContextJsonDeserializationError> {
382        ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
383            .from_json_str(json)
384    }
385
386    /// Create a `Context` from a `serde_json::Value` (which must be a JSON
387    /// object, not any other JSON type, or you will get an error here).
388    /// JSON here must use the `__entity` and `__extn` escapes for entity
389    /// references, extension values, etc.
390    ///
391    /// For schema-based parsing, use `ContextJsonParser`.
392    pub fn from_json_value(
393        json: serde_json::Value,
394    ) -> Result<Self, ContextJsonDeserializationError> {
395        ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
396            .from_json_value(json)
397    }
398
399    /// Create a `Context` from a JSON file.  The JSON file must contain a JSON
400    /// object, not any other JSON type, or you will get an error here.
401    /// JSON here must use the `__entity` and `__extn` escapes for entity
402    /// references, extension values, etc.
403    ///
404    /// For schema-based parsing, use `ContextJsonParser`.
405    pub fn from_json_file(
406        json: impl std::io::Read,
407    ) -> Result<Self, ContextJsonDeserializationError> {
408        ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
409            .from_json_file(json)
410    }
411
412    /// Get the number of keys in this `Context`.
413    pub fn num_keys(&self) -> usize {
414        match self {
415            Context::Value(record) => record.len(),
416            Context::RestrictedResidual(record) => record.len(),
417        }
418    }
419
420    /// Private helper function to implement `into_iter()` for `Context`.
421    /// Gets an iterator over the (key, value) pairs in the `Context`, cloning
422    /// only if necessary.
423    ///
424    /// Note that some error messages rely on this function returning keys in
425    /// sorted order, or else the error message will not be fully deterministic.
426    fn into_pairs(self) -> Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>> {
427        match self {
428            Context::Value(record) => Box::new(
429                Arc::unwrap_or_clone(record)
430                    .into_iter()
431                    .map(|(k, v)| (k, RestrictedExpr::from(v))),
432            ),
433            Context::RestrictedResidual(record) => Box::new(
434                Arc::unwrap_or_clone(record)
435                    .into_iter()
436                    // By INVARIANT(restricted), all attributes expressions are
437                    // restricted expressions.
438                    .map(|(k, v)| (k, RestrictedExpr::new_unchecked(v))),
439            ),
440        }
441    }
442
443    /// Substitute unknowns with concrete values in this context. If this is
444    /// already a `Context::Value`, then this returns `self` unchanged and will
445    /// not error. Otherwise delegate to [`Expr::substitute`].
446    pub fn substitute(self, mapping: &HashMap<SmolStr, Value>) -> Result<Self, EvaluationError> {
447        match self {
448            Context::RestrictedResidual(residual_context) => {
449                // From Invariant(Restricted), `residual_context` contains only
450                // restricted expressions, so `Expr::record_arc` of the attributes
451                // will also be a restricted expression. This doesn't change after
452                // substitution, so we know `expr` must be a restricted expression.
453                let expr = Expr::record_arc(residual_context).substitute(mapping);
454                let expr = BorrowedRestrictedExpr::new_unchecked(&expr);
455
456                let extns = Extensions::all_available();
457                let eval = RestrictedEvaluator::new(extns);
458                let partial_value = eval.partial_interpret(expr)?;
459
460                // The invariant on `from_restricted_partial_val_unchecked`
461                // is satisfied because `expr` is restricted and must still be
462                // restricted after `partial_interpret`.
463                // The function call cannot fail because because `expr` was
464                // constructed as a record, and substitution and partial
465                // evaluation does not change this.
466                // PANIC SAFETY: See above
467                #[allow(clippy::expect_used)]
468                Ok(
469                    Self::from_restricted_partial_val_unchecked(partial_value).expect(
470                        "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
471                    ),
472                )
473            }
474            Context::Value(_) => Ok(self),
475        }
476    }
477}
478
479/// Utilities for implementing `IntoIterator` for `Context`
480mod iter {
481    use super::*;
482
483    /// `IntoIter` iterator for `Context`
484    pub struct IntoIter(pub(super) Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>>);
485
486    impl std::fmt::Debug for IntoIter {
487        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488            write!(f, "IntoIter(<context>)")
489        }
490    }
491
492    impl Iterator for IntoIter {
493        type Item = (SmolStr, RestrictedExpr);
494
495        fn next(&mut self) -> Option<Self::Item> {
496            self.0.next()
497        }
498    }
499}
500
501impl IntoIterator for Context {
502    type Item = (SmolStr, RestrictedExpr);
503    type IntoIter = iter::IntoIter;
504
505    fn into_iter(self) -> Self::IntoIter {
506        iter::IntoIter(self.into_pairs())
507    }
508}
509
510impl From<Context> for RestrictedExpr {
511    fn from(value: Context) -> Self {
512        match value {
513            Context::Value(attrs) => Value::record_arc(attrs, None).into(),
514            Context::RestrictedResidual(attrs) => {
515                // By INVARIANT(restricted), all attributes expressions are
516                // restricted expressions, so the result of `record_arc` will be
517                // a restricted expression.
518                RestrictedExpr::new_unchecked(Expr::record_arc(attrs))
519            }
520        }
521    }
522}
523
524impl From<Context> for PartialValue {
525    fn from(ctx: Context) -> PartialValue {
526        match ctx {
527            Context::Value(attrs) => Value::record_arc(attrs, None).into(),
528            Context::RestrictedResidual(attrs) => {
529                // A `PartialValue::Residual` must contain an unknown in the
530                // expression. By INVARIANT(unknown), at least one expr in
531                // `attrs` contains an unknown, so the `record_arc` expression
532                // contains at least one unknown.
533                PartialValue::Residual(Expr::record_arc(attrs))
534            }
535        }
536    }
537}
538
539impl std::default::Default for Context {
540    fn default() -> Context {
541        Context::empty()
542    }
543}
544
545impl std::fmt::Display for Context {
546    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
547        write!(f, "{}", PartialValue::from(self.clone()))
548    }
549}
550
551impl BoundedDisplay for Context {
552    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
553        BoundedDisplay::fmt(&PartialValue::from(self.clone()), f, n)
554    }
555}
556
557/// Errors while trying to create a `Context`
558#[derive(Debug, Diagnostic, Error)]
559pub enum ContextCreationError {
560    /// Tried to create a `Context` out of something other than a record
561    #[error(transparent)]
562    #[diagnostic(transparent)]
563    NotARecord(#[from] context_creation_errors::NotARecord),
564    /// Error evaluating the expression given for the `Context`
565    #[error(transparent)]
566    #[diagnostic(transparent)]
567    Evaluation(#[from] EvaluationError),
568    /// Error constructing a record for the `Context`.
569    /// Only returned by `Context::from_pairs()` and `Context::merge()`
570    #[error(transparent)]
571    #[diagnostic(transparent)]
572    ExpressionConstruction(#[from] ExpressionConstructionError),
573}
574
575impl ContextCreationError {
576    pub(crate) fn not_a_record(expr: Expr) -> Self {
577        Self::NotARecord(context_creation_errors::NotARecord {
578            expr: Box::new(expr),
579        })
580    }
581}
582
583/// Error subtypes for [`ContextCreationError`]
584pub mod context_creation_errors {
585    use super::Expr;
586    use crate::impl_diagnostic_from_method_on_field;
587    use miette::Diagnostic;
588    use thiserror::Error;
589
590    /// Error type for an expression that needed to be a record, but is not
591    //
592    // CAUTION: this type is publicly exported in `cedar-policy`.
593    // Don't make fields `pub`, don't make breaking changes, and use caution
594    // when adding public methods.
595    #[derive(Debug, Error)]
596    #[error("expression is not a record: {expr}")]
597    pub struct NotARecord {
598        /// Expression which is not a record
599        pub(super) expr: Box<Expr>,
600    }
601
602    // custom impl of `Diagnostic`: take source location from the `expr` field's `.source_loc()` method
603    impl Diagnostic for NotARecord {
604        impl_diagnostic_from_method_on_field!(expr, source_loc);
605    }
606}
607
608/// Trait for schemas capable of validating `Request`s
609pub trait RequestSchema {
610    /// Error type returned when a request fails validation
611    type Error: miette::Diagnostic;
612    /// Validate the given `request`, returning `Err` if it fails validation
613    fn validate_request(
614        &self,
615        request: &Request,
616        extensions: &Extensions<'_>,
617    ) -> Result<(), Self::Error>;
618}
619
620/// A `RequestSchema` that does no validation and always reports a passing result
621#[derive(Debug, Clone)]
622pub struct RequestSchemaAllPass;
623impl RequestSchema for RequestSchemaAllPass {
624    type Error = Infallible;
625    fn validate_request(
626        &self,
627        _request: &Request,
628        _extensions: &Extensions<'_>,
629    ) -> Result<(), Self::Error> {
630        Ok(())
631    }
632}
633
634/// Wrapper around `std::convert::Infallible` which also implements
635/// `miette::Diagnostic`
636#[derive(Debug, Diagnostic, Error)]
637#[error(transparent)]
638pub struct Infallible(pub std::convert::Infallible);
639
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    /// Building a `Context` from a non-record must fail with `NotARecord`,
    /// both when starting from an expression and from a JSON string.
    #[test]
    fn test_json_from_str_non_record() {
        let not_a_record = RestrictedExpr::val("1");
        let from_expr = Context::from_expr(not_a_record.as_borrowed(), Extensions::none());
        assert_matches!(from_expr, Err(ContextCreationError::NotARecord { .. }));

        let from_json = Context::from_json_str("1");
        assert_matches!(
            from_json,
            Err(ContextJsonDeserializationError::ContextCreation(
                ContextCreationError::NotARecord { .. }
            ))
        );
    }
}