1use crate::entities::json::{
18 ContextJsonDeserializationError, ContextJsonParser, NullContextSchema,
19};
20use crate::evaluator::{EvaluationError, RestrictedEvaluator};
21use crate::extensions::Extensions;
22use crate::parser::Loc;
23use miette::Diagnostic;
24use serde::{Deserialize, Serialize};
25use smol_str::SmolStr;
26use std::collections::{BTreeMap, HashMap};
27use std::sync::Arc;
28use thiserror::Error;
29
30use super::{
31 BorrowedRestrictedExpr, BoundedDisplay, EntityType, EntityUID, Expr, ExprKind,
32 ExpressionConstructionError, PartialValue, RestrictedExpr, Unknown, Value, ValueKind, Var,
33};
34
35#[derive(Debug, Clone, Serialize)]
37pub struct Request {
38 pub(crate) principal: EntityUIDEntry,
40
41 pub(crate) action: EntityUIDEntry,
43
44 pub(crate) resource: EntityUIDEntry,
46
47 pub(crate) context: Option<Context>,
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
54#[serde(rename_all = "camelCase")]
55pub struct RequestType {
56 pub principal: EntityType,
58 pub action: EntityUID,
60 pub resource: EntityType,
62}
63
64#[derive(Debug, Clone, Serialize)]
68pub enum EntityUIDEntry {
69 Known {
71 euid: Arc<EntityUID>,
73 loc: Option<Loc>,
75 },
76 Unknown {
78 ty: Option<EntityType>,
80
81 loc: Option<Loc>,
83 },
84}
85
86impl EntityUIDEntry {
87 pub fn evaluate(&self, var: Var) -> PartialValue {
91 match self {
92 EntityUIDEntry::Known { euid, loc } => {
93 Value::new(Arc::unwrap_or_clone(Arc::clone(euid)), loc.clone()).into()
94 }
95 EntityUIDEntry::Unknown { ty: None, loc } => {
96 Expr::unknown(Unknown::new_untyped(var.to_string()))
97 .with_maybe_source_loc(loc.clone())
98 .into()
99 }
100 EntityUIDEntry::Unknown {
101 ty: Some(known_type),
102 loc,
103 } => Expr::unknown(Unknown::new_with_type(
104 var.to_string(),
105 super::Type::Entity {
106 ty: known_type.clone(),
107 },
108 ))
109 .with_maybe_source_loc(loc.clone())
110 .into(),
111 }
112 }
113
114 pub fn known(euid: EntityUID, loc: Option<Loc>) -> Self {
116 Self::Known {
117 euid: Arc::new(euid),
118 loc,
119 }
120 }
121
122 pub fn unknown() -> Self {
124 Self::Unknown {
125 ty: None,
126 loc: None,
127 }
128 }
129
130 pub fn unknown_with_type(ty: EntityType, loc: Option<Loc>) -> Self {
132 Self::Unknown { ty: Some(ty), loc }
133 }
134
135 pub fn uid(&self) -> Option<&EntityUID> {
137 match self {
138 Self::Known { euid, .. } => Some(euid),
139 Self::Unknown { .. } => None,
140 }
141 }
142
143 pub fn get_type(&self) -> Option<&EntityType> {
145 match self {
146 Self::Known { euid, .. } => Some(euid.entity_type()),
147 Self::Unknown { ty, .. } => ty.as_ref(),
148 }
149 }
150}
151
152impl Request {
153 pub fn new<S: RequestSchema>(
158 principal: (EntityUID, Option<Loc>),
159 action: (EntityUID, Option<Loc>),
160 resource: (EntityUID, Option<Loc>),
161 context: Context,
162 schema: Option<&S>,
163 extensions: &Extensions<'_>,
164 ) -> Result<Self, S::Error> {
165 let req = Self {
166 principal: EntityUIDEntry::known(principal.0, principal.1),
167 action: EntityUIDEntry::known(action.0, action.1),
168 resource: EntityUIDEntry::known(resource.0, resource.1),
169 context: Some(context),
170 };
171 if let Some(schema) = schema {
172 schema.validate_request(&req, extensions)?;
173 }
174 Ok(req)
175 }
176
177 pub fn new_with_unknowns<S: RequestSchema>(
183 principal: EntityUIDEntry,
184 action: EntityUIDEntry,
185 resource: EntityUIDEntry,
186 context: Option<Context>,
187 schema: Option<&S>,
188 extensions: &Extensions<'_>,
189 ) -> Result<Self, S::Error> {
190 let req = Self {
191 principal,
192 action,
193 resource,
194 context,
195 };
196 if let Some(schema) = schema {
197 schema.validate_request(&req, extensions)?;
198 }
199 Ok(req)
200 }
201
202 pub fn new_unchecked(
205 principal: EntityUIDEntry,
206 action: EntityUIDEntry,
207 resource: EntityUIDEntry,
208 context: Option<Context>,
209 ) -> Self {
210 Self {
211 principal,
212 action,
213 resource,
214 context,
215 }
216 }
217
218 pub fn principal(&self) -> &EntityUIDEntry {
220 &self.principal
221 }
222
223 pub fn action(&self) -> &EntityUIDEntry {
225 &self.action
226 }
227
228 pub fn resource(&self) -> &EntityUIDEntry {
230 &self.resource
231 }
232
233 pub fn context(&self) -> Option<&Context> {
236 self.context.as_ref()
237 }
238
239 pub fn to_request_type(&self) -> Option<RequestType> {
245 Some(RequestType {
246 principal: self.principal().uid()?.entity_type().clone(),
247 action: self.action().uid()?.clone(),
248 resource: self.resource().uid()?.entity_type().clone(),
249 })
250 }
251}
252
253impl std::fmt::Display for Request {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 let display_euid = |maybe_euid: &EntityUIDEntry| match maybe_euid {
256 EntityUIDEntry::Known { euid, .. } => format!("{euid}"),
257 EntityUIDEntry::Unknown { ty: None, .. } => "unknown".to_string(),
258 EntityUIDEntry::Unknown {
259 ty: Some(known_type),
260 ..
261 } => format!("unknown of type {}", known_type),
262 };
263 write!(
264 f,
265 "request with principal {}, action {}, resource {}, and context {}",
266 display_euid(&self.principal),
267 display_euid(&self.action),
268 display_euid(&self.resource),
269 match &self.context {
270 Some(x) => format!("{x}"),
271 None => "unknown".to_string(),
272 }
273 )
274 }
275}
276
277#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
279#[serde(into = "RestrictedExpr")]
282pub enum Context {
283 Value(Arc<BTreeMap<SmolStr, Value>>),
285 RestrictedResidual(Arc<BTreeMap<SmolStr, Expr>>),
290}
291
292impl Context {
293 pub fn empty() -> Self {
295 Self::Value(Arc::new(BTreeMap::new()))
296 }
297
298 fn from_restricted_partial_val_unchecked(
304 value: PartialValue,
305 ) -> Result<Self, ContextCreationError> {
306 match value {
307 PartialValue::Value(v) => {
308 if let ValueKind::Record(attrs) = v.value {
309 Ok(Context::Value(attrs))
310 } else {
311 Err(ContextCreationError::not_a_record(v.into()))
312 }
313 }
314 PartialValue::Residual(e) => {
315 if let ExprKind::Record(attrs) = e.expr_kind() {
316 Ok(Context::RestrictedResidual(attrs.clone()))
323 } else {
324 Err(ContextCreationError::not_a_record(e))
325 }
326 }
327 }
328 }
329
330 pub fn from_expr(
335 expr: BorrowedRestrictedExpr<'_>,
336 extensions: &Extensions<'_>,
337 ) -> Result<Self, ContextCreationError> {
338 match expr.expr_kind() {
339 ExprKind::Record { .. } => {
340 let evaluator = RestrictedEvaluator::new(extensions);
341 let pval = evaluator.partial_interpret(expr)?;
342 #[allow(clippy::expect_used)]
350 Ok(Self::from_restricted_partial_val_unchecked(pval).expect(
351 "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
352 ))
353 }
354 _ => Err(ContextCreationError::not_a_record(expr.to_owned().into())),
355 }
356 }
357
358 pub fn from_pairs(
364 pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
365 extensions: &Extensions<'_>,
366 ) -> Result<Self, ContextCreationError> {
367 match RestrictedExpr::record(pairs) {
368 Ok(record) => Self::from_expr(record.as_borrowed(), extensions),
369 Err(ExpressionConstructionError::DuplicateKey(err)) => Err(
370 ExpressionConstructionError::DuplicateKey(err.with_context("in context")).into(),
371 ),
372 }
373 }
374
375 pub fn from_json_str(json: &str) -> Result<Self, ContextJsonDeserializationError> {
382 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
383 .from_json_str(json)
384 }
385
386 pub fn from_json_value(
393 json: serde_json::Value,
394 ) -> Result<Self, ContextJsonDeserializationError> {
395 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
396 .from_json_value(json)
397 }
398
399 pub fn from_json_file(
406 json: impl std::io::Read,
407 ) -> Result<Self, ContextJsonDeserializationError> {
408 ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
409 .from_json_file(json)
410 }
411
412 pub fn num_keys(&self) -> usize {
414 match self {
415 Context::Value(record) => record.len(),
416 Context::RestrictedResidual(record) => record.len(),
417 }
418 }
419
420 fn into_pairs(self) -> Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>> {
427 match self {
428 Context::Value(record) => Box::new(
429 Arc::unwrap_or_clone(record)
430 .into_iter()
431 .map(|(k, v)| (k, RestrictedExpr::from(v))),
432 ),
433 Context::RestrictedResidual(record) => Box::new(
434 Arc::unwrap_or_clone(record)
435 .into_iter()
436 .map(|(k, v)| (k, RestrictedExpr::new_unchecked(v))),
439 ),
440 }
441 }
442
443 pub fn substitute(self, mapping: &HashMap<SmolStr, Value>) -> Result<Self, EvaluationError> {
447 match self {
448 Context::RestrictedResidual(residual_context) => {
449 let expr = Expr::record_arc(residual_context).substitute(mapping);
454 let expr = BorrowedRestrictedExpr::new_unchecked(&expr);
455
456 let extns = Extensions::all_available();
457 let eval = RestrictedEvaluator::new(extns);
458 let partial_value = eval.partial_interpret(expr)?;
459
460 #[allow(clippy::expect_used)]
468 Ok(
469 Self::from_restricted_partial_val_unchecked(partial_value).expect(
470 "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
471 ),
472 )
473 }
474 Context::Value(_) => Ok(self),
475 }
476 }
477}
478
479mod iter {
481 use super::*;
482
483 pub struct IntoIter(pub(super) Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>>);
485
486 impl std::fmt::Debug for IntoIter {
487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488 write!(f, "IntoIter(<context>)")
489 }
490 }
491
492 impl Iterator for IntoIter {
493 type Item = (SmolStr, RestrictedExpr);
494
495 fn next(&mut self) -> Option<Self::Item> {
496 self.0.next()
497 }
498 }
499}
500
501impl IntoIterator for Context {
502 type Item = (SmolStr, RestrictedExpr);
503 type IntoIter = iter::IntoIter;
504
505 fn into_iter(self) -> Self::IntoIter {
506 iter::IntoIter(self.into_pairs())
507 }
508}
509
510impl From<Context> for RestrictedExpr {
511 fn from(value: Context) -> Self {
512 match value {
513 Context::Value(attrs) => Value::record_arc(attrs, None).into(),
514 Context::RestrictedResidual(attrs) => {
515 RestrictedExpr::new_unchecked(Expr::record_arc(attrs))
519 }
520 }
521 }
522}
523
524impl From<Context> for PartialValue {
525 fn from(ctx: Context) -> PartialValue {
526 match ctx {
527 Context::Value(attrs) => Value::record_arc(attrs, None).into(),
528 Context::RestrictedResidual(attrs) => {
529 PartialValue::Residual(Expr::record_arc(attrs))
534 }
535 }
536 }
537}
538
539impl std::default::Default for Context {
540 fn default() -> Context {
541 Context::empty()
542 }
543}
544
545impl std::fmt::Display for Context {
546 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
547 write!(f, "{}", PartialValue::from(self.clone()))
548 }
549}
550
551impl BoundedDisplay for Context {
552 fn fmt(&self, f: &mut impl std::fmt::Write, n: Option<usize>) -> std::fmt::Result {
553 BoundedDisplay::fmt(&PartialValue::from(self.clone()), f, n)
554 }
555}
556
557#[derive(Debug, Diagnostic, Error)]
559pub enum ContextCreationError {
560 #[error(transparent)]
562 #[diagnostic(transparent)]
563 NotARecord(#[from] context_creation_errors::NotARecord),
564 #[error(transparent)]
566 #[diagnostic(transparent)]
567 Evaluation(#[from] EvaluationError),
568 #[error(transparent)]
571 #[diagnostic(transparent)]
572 ExpressionConstruction(#[from] ExpressionConstructionError),
573}
574
575impl ContextCreationError {
576 pub(crate) fn not_a_record(expr: Expr) -> Self {
577 Self::NotARecord(context_creation_errors::NotARecord {
578 expr: Box::new(expr),
579 })
580 }
581}
582
583pub mod context_creation_errors {
585 use super::Expr;
586 use crate::impl_diagnostic_from_method_on_field;
587 use miette::Diagnostic;
588 use thiserror::Error;
589
590 #[derive(Debug, Error)]
596 #[error("expression is not a record: {expr}")]
597 pub struct NotARecord {
598 pub(super) expr: Box<Expr>,
600 }
601
602 impl Diagnostic for NotARecord {
604 impl_diagnostic_from_method_on_field!(expr, source_loc);
605 }
606}
607
608pub trait RequestSchema {
610 type Error: miette::Diagnostic;
612 fn validate_request(
614 &self,
615 request: &Request,
616 extensions: &Extensions<'_>,
617 ) -> Result<(), Self::Error>;
618}
619
620#[derive(Debug, Clone)]
622pub struct RequestSchemaAllPass;
623impl RequestSchema for RequestSchemaAllPass {
624 type Error = Infallible;
625 fn validate_request(
626 &self,
627 _request: &Request,
628 _extensions: &Extensions<'_>,
629 ) -> Result<(), Self::Error> {
630 Ok(())
631 }
632}
633
634#[derive(Debug, Diagnostic, Error)]
637#[error(transparent)]
638pub struct Infallible(pub std::convert::Infallible);
639
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    /// Building a context from a non-record input must fail with `NotARecord`. This is checked
    /// both through `Context::from_expr` and through JSON parsing.
    #[test]
    fn test_json_from_str_non_record() {
        let via_expr =
            Context::from_expr(RestrictedExpr::val("1").as_borrowed(), Extensions::none());
        assert_matches!(via_expr, Err(ContextCreationError::NotARecord { .. }));

        let via_json = Context::from_json_str("1");
        assert_matches!(
            via_json,
            Err(ContextJsonDeserializationError::ContextCreation(
                ContextCreationError::NotARecord { .. }
            ))
        );
    }
}