use crate::ast::*;
use crate::entities::{EntitiesError, EntityJson, JsonSerializationError};
use crate::evaluator::{EvaluationError, RestrictedEvaluator};
use crate::extensions::Extensions;
use crate::parser::err::ParseErrors;
use crate::parser::Loc;
use crate::transitive_closure::TCNode;
use crate::FromNormalizedStr;
use itertools::Itertools;
use miette::Diagnostic;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, TryFromInto};
use smol_str::SmolStr;
use std::collections::{BTreeMap, HashMap, HashSet};
use thiserror::Error;
/// The type component of an [`EntityUID`].
///
/// Either a concrete, named type (`Specified`) or the placeholder
/// `Unspecified`, which `Display` renders as `<Unspecified>`.
///
/// Variant order is significant: the derived `Ord` sorts every `Specified`
/// value before `Unspecified`.
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum EntityType {
    /// A concrete entity type identified by its `Name`.
    Specified(Name),
    /// No entity type was given.
    Unspecified,
}
impl EntityType {
    /// Reports whether this type denotes an action type.
    ///
    /// A type qualifies when it is `Specified` and the basename (last path
    /// component) of its name is `Action`. For example, both `Action` and
    /// `Foo::Action` qualify, while `Action::Foo` does not. `Unspecified` never
    /// qualifies.
    pub fn is_action(&self) -> bool {
        match self {
            Self::Unspecified => false,
            Self::Specified(name) => {
                let action_id = Id::new_unchecked("Action");
                name.basename() == &action_id
            }
        }
    }
}
impl std::fmt::Display for EntityType {
    /// Writes the type's name, or the literal placeholder `<Unspecified>`
    /// when no type was given.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Specified(type_name) => write!(f, "{}", type_name),
            Self::Unspecified => write!(f, "<Unspecified>"),
        }
    }
}
/// Unique identifier of an entity: its entity type plus its entity id.
///
/// The optional `loc` is not serialized. The hand-written `PartialEq`,
/// `Hash` and `Ord` impls for this type compare only `ty` and `eid`, so
/// `loc` never affects equality or ordering.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EntityUID {
    /// The entity's type.
    ty: EntityType,
    /// The entity's id within its type.
    eid: Eid,
    /// Parser source location, if one was recorded. Skipped by serde, so it
    /// deserializes as `None`.
    #[serde(skip)]
    loc: Option<Loc>,
}
// Equality, hashing and ordering of `EntityUID` all look only at the entity
// type and the entity id. The optional source location `loc` is ignored,
// presumably so that the same entity referenced from different source
// positions is treated as one value.
impl PartialEq for EntityUID {
    fn eq(&self, other: &Self) -> bool {
        self.ty == other.ty && self.eid == other.eid
    }
}

impl Eq for EntityUID {}

impl std::hash::Hash for EntityUID {
    /// Hashes exactly the fields compared by `PartialEq` (`ty`, then `eid`),
    /// which keeps `Hash` consistent with `Eq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.ty, state);
        std::hash::Hash::hash(&self.eid, state);
    }
}

impl PartialOrd for EntityUID {
    /// Delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for EntityUID {
    /// Orders by entity type first, then by entity id.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `then_with` is lazy: the string comparison of the ids only runs
        // when the types compare equal (`then` would always evaluate it).
        self.ty
            .cmp(&other.ty)
            .then_with(|| self.eid.cmp(&other.eid))
    }
}
impl StaticallyTyped for EntityUID {
    /// The static type of an entity UID is the entity type stored in it.
    fn type_of(&self) -> Type {
        let ty = self.entity_type().clone();
        Type::Entity { ty }
    }
}
impl EntityUID {
    /// Test-only: a UID of type `test_entity_type` with the given id and no
    /// source location.
    #[cfg(test)]
    pub(crate) fn with_eid(eid: &str) -> Self {
        let ty = Self::test_entity_type();
        Self {
            ty,
            eid: Eid(eid.into()),
            loc: None,
        }
    }

    /// Test-only: the entity type `test_entity_type`.
    #[cfg(test)]
    pub(crate) fn test_entity_type() -> EntityType {
        EntityType::Specified(
            Name::parse_unqualified_name("test_entity_type")
                .expect("test_entity_type should be a valid identifier"),
        )
    }

    /// Builds a UID from an unqualified type name and an id. No source
    /// location is recorded.
    ///
    /// # Errors
    /// Returns the `ParseErrors` from `Name::parse_unqualified_name` when
    /// `typename` is not a valid unqualified name.
    pub fn with_eid_and_type(typename: &str, eid: &str) -> Result<Self, ParseErrors> {
        let name = Name::parse_unqualified_name(typename)?;
        Ok(Self {
            ty: EntityType::Specified(name),
            eid: Eid(eid.into()),
            loc: None,
        })
    }

    /// Consumes the UID and returns its type and id. The source location is
    /// discarded.
    pub fn components(self) -> (EntityType, Eid) {
        let Self { ty, eid, .. } = self;
        (ty, eid)
    }

    /// The source location recorded for this UID, if any.
    pub fn loc(&self) -> Option<&Loc> {
        self.loc.as_ref()
    }

    /// Builds a UID with a specified type `name`, the given `eid`, and an
    /// optional source location.
    pub fn from_components(name: Name, eid: Eid, loc: Option<Loc>) -> Self {
        let ty = EntityType::Specified(name);
        Self { ty, eid, loc }
    }

    /// Builds a UID whose type is `EntityType::Unspecified`, with no source
    /// location.
    pub fn unspecified_from_eid(eid: Eid) -> Self {
        Self {
            eid,
            ty: EntityType::Unspecified,
            loc: None,
        }
    }

    /// The entity type of this UID.
    pub fn entity_type(&self) -> &EntityType {
        &self.ty
    }

    /// The entity id of this UID.
    pub fn eid(&self) -> &Eid {
        &self.eid
    }

    /// Whether this UID's type is an action type (see
    /// `EntityType::is_action`).
    pub fn is_action(&self) -> bool {
        self.ty.is_action()
    }
}
impl std::fmt::Display for EntityUID {
    /// Renders as `Type::"id"`, e.g. `test_entity_type::"eid"`. The id is
    /// written through `Eid`'s `Display`, which applies `escape_debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (ty, eid) = (self.entity_type(), &self.eid);
        write!(f, "{ty}::\"{eid}\"")
    }
}
impl std::str::FromStr for EntityUID {
    type Err = ParseErrors;

    /// Parses text such as `Action::"view"` with `crate::parser::parse_euid`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let euid = crate::parser::parse_euid(s)?;
        Ok(euid)
    }
}

impl FromNormalizedStr for EntityUID {
    /// Human-readable name of this type, used by `FromNormalizedStr`.
    fn describe_self() -> &'static str {
        let description = "Entity UID";
        description
    }
}
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for EntityUID {
    /// Generates an arbitrary entity type and id, with no source location.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Draw `ty` before `eid`: the order in which bytes are consumed from
        // `u` determines the generated value.
        let ty: EntityType = u.arbitrary()?;
        let eid: Eid = u.arbitrary()?;
        Ok(Self { ty, eid, loc: None })
    }
}
/// The id part of an entity UID, e.g. `eid` in `test_entity_type::"eid"`.
/// Stored unescaped.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Eid(SmolStr);

impl Eid {
    /// Wraps any string-like value as an entity id. No validation or escaping
    /// is done.
    pub fn new(eid: impl Into<SmolStr>) -> Self {
        Self(eid.into())
    }

    /// Returns the id with `str::escape_debug` applied, which is the same
    /// escaping that `Display` uses.
    pub fn escaped(&self) -> SmolStr {
        self.0.escape_debug().collect::<SmolStr>()
    }
}
impl AsRef<SmolStr> for Eid {
fn as_ref(&self) -> &SmolStr {
&self.0
}
}
impl AsRef<str> for Eid {
fn as_ref(&self) -> &str {
&self.0
}
}
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Eid {
    /// Generates an id from an arbitrary `String`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let raw = u.arbitrary::<String>()?;
        Ok(Self(SmolStr::from(raw)))
    }
}
impl std::fmt::Display for Eid {
    /// Writes the id with `escape_debug` applied. No surrounding quotes are
    /// added here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let escaped = self.0.escape_debug();
        write!(f, "{escaped}")
    }
}
/// An entity: its UID, its evaluated attributes, and the UIDs of its
/// ancestors.
///
/// Attributes live in a `BTreeMap`, so attribute iteration and serialization
/// follow attribute-name order.
#[derive(Clone, Debug, Serialize)]
pub struct Entity {
    /// This entity's identifier.
    uid: EntityUID,
    /// Attribute name -> (possibly partial) attribute value.
    attrs: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,
    /// UIDs recorded as ancestors of this entity.
    ancestors: HashSet<EntityUID>,
}
impl std::hash::Hash for Entity {
    /// Hashes only the UID. This matches `PartialEq for Entity`, which also
    /// compares only UIDs.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(self.uid(), state);
    }
}
impl Entity {
pub fn new(
uid: EntityUID,
attrs: HashMap<SmolStr, RestrictedExpr>,
ancestors: HashSet<EntityUID>,
extensions: &Extensions<'_>,
) -> Result<Self, EntityAttrEvaluationError> {
let evaluator = RestrictedEvaluator::new(extensions);
let evaluated_attrs = attrs
.into_iter()
.map(|(k, v)| {
let attr_val = evaluator
.partial_interpret(v.as_borrowed())
.map_err(|err| EntityAttrEvaluationError {
uid: uid.clone(),
attr: k.clone(),
err,
})?;
Ok((k, attr_val.into()))
})
.collect::<Result<_, EntityAttrEvaluationError>>()?;
Ok(Entity {
uid,
attrs: evaluated_attrs,
ancestors,
})
}
pub fn new_with_attr_partial_value(
uid: EntityUID,
attrs: HashMap<SmolStr, PartialValue>,
ancestors: HashSet<EntityUID>,
) -> Self {
Entity {
uid,
attrs: attrs.into_iter().map(|(k, v)| (k, v.into())).collect(), ancestors,
}
}
pub fn new_with_attr_partial_value_serialized_as_expr(
uid: EntityUID,
attrs: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,
ancestors: HashSet<EntityUID>,
) -> Self {
Entity {
uid,
attrs,
ancestors,
}
}
pub fn uid(&self) -> &EntityUID {
&self.uid
}
pub fn get(&self, attr: &str) -> Option<&PartialValue> {
self.attrs.get(attr).map(|v| v.as_ref())
}
pub fn is_descendant_of(&self, e: &EntityUID) -> bool {
self.ancestors.contains(e)
}
pub fn ancestors(&self) -> impl Iterator<Item = &EntityUID> {
self.ancestors.iter()
}
pub fn attrs_len(&self) -> usize {
self.attrs.len()
}
pub fn keys(&self) -> impl Iterator<Item = &SmolStr> {
self.attrs.keys()
}
pub fn attrs(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
self.attrs.iter().map(|(k, v)| (k, v.as_ref()))
}
pub fn with_uid(uid: EntityUID) -> Self {
Self {
uid,
attrs: BTreeMap::new(),
ancestors: HashSet::new(),
}
}
pub(crate) fn deep_eq(&self, other: &Self) -> bool {
self.uid == other.uid && self.attrs == other.attrs && self.ancestors == other.ancestors
}
#[cfg(any(test, fuzzing))]
pub fn set_attr(
&mut self,
attr: SmolStr,
val: RestrictedExpr,
extensions: &Extensions<'_>,
) -> Result<(), EvaluationError> {
let val = RestrictedEvaluator::new(extensions).partial_interpret(val.as_borrowed())?;
self.attrs.insert(attr, val.into());
Ok(())
}
#[cfg(not(fuzzing))]
pub(crate) fn add_ancestor(&mut self, uid: EntityUID) {
self.ancestors.insert(uid);
}
#[cfg(fuzzing)]
pub fn add_ancestor(&mut self, uid: EntityUID) {
self.ancestors.insert(uid);
}
pub fn into_inner(
self,
) -> (
EntityUID,
HashMap<SmolStr, PartialValue>,
HashSet<EntityUID>,
) {
let Self {
uid,
attrs,
ancestors,
} = self;
(
uid,
attrs.into_iter().map(|(k, v)| (k, v.0)).collect(),
ancestors,
)
}
pub fn write_to_json(&self, f: impl std::io::Write) -> Result<(), EntitiesError> {
let ejson = EntityJson::from_entity(self)?;
serde_json::to_writer_pretty(f, &ejson).map_err(JsonSerializationError::from)?;
Ok(())
}
pub fn to_json_value(&self) -> Result<serde_json::Value, EntitiesError> {
let ejson = EntityJson::from_entity(self)?;
let v = serde_json::to_value(ejson).map_err(JsonSerializationError::from)?;
Ok(v)
}
pub fn to_json_string(&self) -> Result<String, EntitiesError> {
let ejson = EntityJson::from_entity(self)?;
let string = serde_json::to_string(&ejson).map_err(JsonSerializationError::from)?;
Ok(string)
}
}
impl PartialEq for Entity {
    /// Two entities are equal when their UIDs are equal. Attributes and
    /// ancestors are not compared (`deep_eq` compares them).
    fn eq(&self, other: &Self) -> bool {
        self.uid == other.uid
    }
}

impl Eq for Entity {}
impl StaticallyTyped for Entity {
    /// An entity's static type is the static type of its UID.
    fn type_of(&self) -> Type {
        let uid = self.uid();
        uid.type_of()
    }
}
impl TCNode<EntityUID> for Entity {
    /// A node is keyed by the entity's UID.
    fn get_key(&self) -> EntityUID {
        self.uid.clone()
    }

    /// Records `k` as an ancestor of this entity.
    fn add_edge_to(&mut self, k: EntityUID) {
        self.add_ancestor(k);
    }

    /// The outgoing edges are the stored ancestor UIDs.
    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
        let edges = self.ancestors.iter();
        Box::new(edges)
    }

    /// Whether `e` is among the stored ancestor UIDs.
    fn has_edge_to(&self, e: &EntityUID) -> bool {
        self.ancestors.contains(e)
    }
}
impl std::fmt::Display for Entity {
    /// Writes the UID, then the `name: value` attribute pairs joined by `; `,
    /// then the ancestor UIDs joined by `, `.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered_attrs = self
            .attrs
            .iter()
            .map(|(name, value)| format!("{}: {}", name, value))
            .join("; ");
        let rendered_ancestors = self.ancestors.iter().join(", ");
        write!(
            f,
            "{}:\n attrs:{}\n ancestors:{}",
            self.uid, rendered_attrs, rendered_ancestors
        )
    }
}
/// Newtype around a `PartialValue`. It serializes the value by first
/// converting it to a `RestrictedExpr` (`serde_with`'s `TryFromInto`), so
/// serialization fails if that conversion fails.
#[serde_as]
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct PartialValueSerializedAsExpr(
    #[serde_as(as = "TryFromInto<RestrictedExpr>")] PartialValue,
);

impl AsRef<PartialValue> for PartialValueSerializedAsExpr {
    /// Borrows the wrapped value.
    fn as_ref(&self) -> &PartialValue {
        let Self(inner) = self;
        inner
    }
}

impl std::ops::Deref for PartialValueSerializedAsExpr {
    type Target = PartialValue;

    /// Borrows the wrapped value.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<PartialValue> for PartialValueSerializedAsExpr {
    /// Wraps `value` without changing it.
    fn from(value: PartialValue) -> PartialValueSerializedAsExpr {
        Self(value)
    }
}

impl From<PartialValueSerializedAsExpr> for PartialValue {
    /// Unwraps the inner value.
    fn from(value: PartialValueSerializedAsExpr) -> PartialValue {
        let PartialValueSerializedAsExpr(inner) = value;
        inner
    }
}

impl std::fmt::Display for PartialValueSerializedAsExpr {
    /// Displays exactly as the wrapped `PartialValue` does.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner: &PartialValue = self;
        write!(f, "{inner}")
    }
}
#[derive(Debug, Diagnostic, Error)]
#[error("failed to evaluate attribute `{attr}` of `{uid}`: {err}")]
pub struct EntityAttrEvaluationError {
pub uid: EntityUID,
pub attr: SmolStr,
#[diagnostic(transparent)]
pub err: EvaluationError,
}
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::*;

    /// `Display` renders a UID as `Type::"id"`.
    #[test]
    fn display() {
        let e = EntityUID::with_eid("eid");
        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
    }

    /// UID equality considers both type and id. An `Unspecified` type differs
    /// from a specified type literally named `Unspecified`, and the two also
    /// display differently.
    #[test]
    fn test_euid_equality() {
        let e1 = EntityUID::with_eid("foo");
        let e2 = EntityUID::from_components(
            Name::parse_unqualified_name("test_entity_type").expect("should be a valid identifier"),
            Eid("foo".into()),
            None,
        );
        let e3 = EntityUID::unspecified_from_eid(Eid("foo".into()));
        let e4 = EntityUID::unspecified_from_eid(Eid("bar".into()));
        let e5 = EntityUID::from_components(
            Name::parse_unqualified_name("Unspecified").expect("should be a valid identifier"),
            Eid("foo".into()),
            None,
        );

        // Reflexivity.
        assert_eq!(e1, e1);
        assert_eq!(e2, e2);
        assert_eq!(e3, e3);

        // Same type and id built two different ways.
        assert_eq!(e1, e2);

        // `assert_ne!` prints both operands on failure; `assert!(a != b)` does not.
        assert_ne!(e1, e3);
        assert_ne!(e1, e4);
        assert_ne!(e1, e5);
        assert_ne!(e3, e4);
        assert_ne!(e3, e5);
        assert_ne!(e4, e5);
        assert_ne!(format!("{e3}"), format!("{e5}"));
    }

    /// `is_action` looks only at the basename of the type name.
    #[test]
    fn action_checker() {
        let euid = EntityUID::from_str("Action::\"view\"").expect("should parse as an EUID");
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::Action::\"view\"").expect("should parse as an EUID");
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::\"view\"").expect("should parse as an EUID");
        assert!(!euid.is_action());
        let euid = EntityUID::from_str("Action::Foo::\"view\"").expect("should parse as an EUID");
        assert!(!euid.is_action());
    }
}