cedar_policy_core/ast/
partial_value.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use super::{BoundedDisplay, Expr, Unknown, Value};
18use crate::parser::Loc;
19use itertools::Either;
20use miette::Diagnostic;
21use thiserror::Error;
22
23/// Intermediate results of partial evaluation
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub enum PartialValue {
26    /// Fully evaluated values
27    Value(Value),
28    /// Residual expressions containing unknowns
29    /// INVARIANT: A residual _must_ have an unknown contained within
30    Residual(Expr),
31}
32
33impl PartialValue {
34    /// Create a new `PartialValue` consisting of just this single `Unknown`
35    pub fn unknown(u: Unknown) -> Self {
36        Self::Residual(Expr::unknown(u))
37    }
38
39    /// Return the `PartialValue`, but with the given `Loc` (or `None`)
40    pub fn with_maybe_source_loc(self, loc: Option<Loc>) -> Self {
41        match self {
42            Self::Value(v) => Self::Value(v.with_maybe_source_loc(loc)),
43            Self::Residual(e) => Self::Residual(e.with_maybe_source_loc(loc)),
44        }
45    }
46}
47
48impl<V: Into<Value>> From<V> for PartialValue {
49    fn from(into_v: V) -> Self {
50        PartialValue::Value(into_v.into())
51    }
52}
53
54impl From<Expr> for PartialValue {
55    fn from(e: Expr) -> Self {
56        debug_assert!(e.contains_unknown());
57        PartialValue::Residual(e)
58    }
59}
60
61/// Errors encountered when converting `PartialValue` to `Value`
62// CAUTION: this type is publicly exported in `cedar-policy`.
63#[derive(Debug, PartialEq, Eq, Diagnostic, Error)]
64pub enum PartialValueToValueError {
65    /// The `PartialValue` is a residual, i.e., contains an unknown
66    #[diagnostic(transparent)]
67    #[error(transparent)]
68    ContainsUnknown(#[from] ContainsUnknown),
69}
70
71/// The `PartialValue` is a residual, i.e., contains an unknown
72// CAUTION: this type is publicly exported in `cedar-policy`.
73// Don't make fields `pub`, don't make breaking changes, and use caution
74// when adding public methods.
75#[derive(Debug, PartialEq, Eq, Diagnostic, Error)]
76#[error("value contains a residual expression: `{residual}`")]
77pub struct ContainsUnknown {
78    /// Residual expression which contains an unknown
79    residual: Expr,
80}
81
82impl TryFrom<PartialValue> for Value {
83    type Error = PartialValueToValueError;
84
85    fn try_from(value: PartialValue) -> Result<Self, Self::Error> {
86        match value {
87            PartialValue::Value(v) => Ok(v),
88            PartialValue::Residual(e) => Err(ContainsUnknown { residual: e }.into()),
89        }
90    }
91}
92
93impl std::fmt::Display for PartialValue {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        match self {
96            PartialValue::Value(v) => write!(f, "{v}"),
97            PartialValue::Residual(r) => write!(f, "{r}"),
98        }
99    }
100}
101
102impl BoundedDisplay for PartialValue {
103    fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
104        match self {
105            PartialValue::Value(v) => BoundedDisplay::fmt(v, f, n),
106            PartialValue::Residual(r) => BoundedDisplay::fmt(r, f, n),
107        }
108    }
109}
110
111/// Collect an iterator of either residuals or values into one of the following
112///  a) An iterator over values, if everything evaluated to values
113///  b) An iterator over residuals expressions, if anything only evaluated to a residual
114/// Order is preserved.
115pub fn split<I>(i: I) -> Either<impl Iterator<Item = Value>, impl Iterator<Item = Expr>>
116where
117    I: IntoIterator<Item = PartialValue>,
118{
119    let mut values = vec![];
120    let mut residuals = vec![];
121
122    for item in i.into_iter() {
123        match item {
124            PartialValue::Value(a) => {
125                if residuals.is_empty() {
126                    values.push(a)
127                } else {
128                    residuals.push(a.into())
129                }
130            }
131            PartialValue::Residual(r) => {
132                residuals.push(r);
133            }
134        }
135    }
136
137    if residuals.is_empty() {
138        Either::Left(values.into_iter())
139    } else {
140        let mut exprs: Vec<Expr> = values.into_iter().map(|x| x.into()).collect();
141        exprs.append(&mut residuals);
142        Either::Right(exprs.into_iter())
143    }
144}
145
// PANIC SAFETY: Unit Test Code
#[allow(clippy::panic)]
#[cfg(test)]
mod test {
    use super::*;

    /// Run `split` on `input` and assert that it produces residuals equal to
    /// the expressions `1`, `2`, `3`, `4`, in that order.
    fn assert_residuals_one_to_four(input: [PartialValue; 4]) {
        let expected = vec![Expr::val(1), Expr::val(2), Expr::val(3), Expr::val(4)];
        match split(input) {
            Either::Right(residuals) => {
                assert_eq!(residuals.collect::<Vec<_>>(), expected);
            }
            Either::Left(_) => panic!("expected residuals, got values"),
        }
    }

    /// An input consisting only of values is returned as values
    #[test]
    fn split_values() {
        let input = [
            PartialValue::Value(Value::from(1)),
            PartialValue::Value(Value::from(2)),
        ];
        let expected = vec![Value::from(1), Value::from(2)];
        match split(input) {
            Either::Left(values) => {
                assert_eq!(values.collect::<Vec<_>>(), expected);
            }
            Either::Right(_) => panic!("expected values, got residuals"),
        }
    }

    /// Values and residuals interleaved
    #[test]
    fn split_residuals() {
        assert_residuals_one_to_four([
            PartialValue::Value(Value::from(1)),
            PartialValue::Residual(Expr::val(2)),
            PartialValue::Value(Value::from(3)),
            PartialValue::Residual(Expr::val(4)),
        ]);
    }

    /// Values first, then residuals
    #[test]
    fn split_residuals2() {
        assert_residuals_one_to_four([
            PartialValue::Value(Value::from(1)),
            PartialValue::Value(Value::from(2)),
            PartialValue::Residual(Expr::val(3)),
            PartialValue::Residual(Expr::val(4)),
        ]);
    }

    /// Residuals first, then values
    #[test]
    fn split_residuals3() {
        assert_residuals_one_to_four([
            PartialValue::Residual(Expr::val(1)),
            PartialValue::Residual(Expr::val(2)),
            PartialValue::Value(Value::from(3)),
            PartialValue::Value(Value::from(4)),
        ]);
    }
}