use std::collections::BTreeSet;
use fnv::{FnvHashMap, FnvHashSet};
use tracing::error;
use crate::id::{AttrId, Eid, Id128, PolicyId, PropId};
use super::code::{Bytecode, PolicyValue};
/// Failure modes of policy bytecode evaluation.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum EvalError {
/// The bytecode is malformed: an unknown opcode, an undecodable id operand, or a program
/// that ends without executing `Return`.
Program,
/// An operand is missing or of the wrong kind, or a subject/resource id property that the
/// program loads is absent from the request parameters.
Type,
}
/// Input to [`PolicyEngine::eval`]: the attributes and id-valued properties of the subject
/// and of the resource.
#[derive(Default, Debug)]
pub struct AccessControlParams {
/// Entity ids of the subject, keyed by property; read by the `LoadSubjectId` opcode.
pub subject_eids: FnvHashMap<PropId, Eid>,
/// Attributes of the subject; used to find triggers and pushed by `LoadSubjectAttrs`.
pub subject_attrs: FnvHashSet<AttrId>,
/// Entity ids of the resource, keyed by property; read by the `LoadResourceId` opcode.
pub resource_eids: FnvHashMap<PropId, Eid>,
/// Attributes of the resource; used to find triggers and pushed by `LoadResourceAttrs`.
pub resource_attrs: FnvHashSet<AttrId>,
}
/// Attribute-triggered policy engine: holds bytecode policies and the triggers that decide
/// which of them are applicable to a request.
#[derive(Default, Debug)]
pub struct PolicyEngine {
/// Registered policies by id.
policies: FnvHashMap<PolicyId, Policy>,
/// Triggers grouped under the smallest attribute of their matcher (see `add_trigger`).
trigger_groups: FnvHashMap<AttrId, Vec<PolicyTrigger>>,
}
/// Associates a set of attributes with the policies they make applicable.
#[derive(Debug)]
struct PolicyTrigger {
/// Attributes that must all be present among the subject's and resource's attributes.
pub attr_matcher: BTreeSet<AttrId>,
/// Policies that become applicable when the matcher is satisfied.
pub policy_ids: BTreeSet<PolicyId>,
}
/// Observer hooks invoked by [`PolicyEngine::eval`] to expose how a decision was reached.
///
/// Every method has an empty default body, so implementors only override the events they
/// care about. Parameters of the defaults are `_`-prefixed instead of silencing the
/// `unused` lint for the whole trait.
pub trait PolicyTracer {
    /// Receives the ids of the applicable policies of one `class`.
    ///
    /// `eval` calls this once for the deny group and once for the allow group, before any
    /// policy is evaluated, even when a group is empty.
    fn report_applicable(
        &mut self,
        _class: PolicyValue,
        _policies: impl Iterator<Item = PolicyId>,
    ) {
    }

    /// Called just before the bytecode of `policy_id` is evaluated.
    fn report_policy_eval_start(&mut self, _policy_id: PolicyId) {}

    /// Called with the boolean outcome of the policy most recently announced through
    /// [`Self::report_policy_eval_start`]; not called when that evaluation fails.
    fn report_policy_eval_end(&mut self, _value: bool) {}
}
/// A [`PolicyTracer`] that ignores every event by relying on the trait's empty default
/// methods.
pub struct NoOpPolicyTracer;
impl PolicyTracer for NoOpPolicyTracer {}
/// A registered policy.
#[derive(Debug)]
struct Policy {
/// Group the policy joins when applicable: `Allow` or `Deny`.
class: PolicyValue,
/// Program run by `eval_policy`; evaluation fails with `EvalError::Program` unless it
/// reaches a `Return` opcode.
bytecode: Vec<u8>,
}
/// Value on the operand stack of `eval_policy`.
#[derive(PartialEq, Eq, Debug)]
enum StackItem<'a> {
/// Boolean-like result of a comparison or logical opcode (non-zero means true).
Uint(u64),
/// Attribute set borrowed from the request parameters.
AttrIdSet(&'a FnvHashSet<AttrId>),
/// Entity id loaded from the request parameters or from a constant operand.
EntityId(Eid),
/// Attribute id loaded from a constant operand.
AttrId(AttrId),
}
/// Applicable policies collected during one `eval` call, split by class and borrowed from
/// the engine.
#[derive(Debug)]
struct EvalCtx<'e> {
/// Applicable policies whose class is `Allow`.
applicable_allow: FnvHashMap<PolicyId, &'e Policy>,
/// Applicable policies whose class is `Deny`.
applicable_deny: FnvHashMap<PolicyId, &'e Policy>,
}
impl PolicyEngine {
pub fn add_policy(&mut self, id: PolicyId, class: PolicyValue, bytecode: Vec<u8>) {
self.policies.insert(id, Policy { class, bytecode });
}
pub fn add_trigger(
&mut self,
attr_matcher: impl Into<BTreeSet<AttrId>>,
policy_ids: impl Into<BTreeSet<PolicyId>>,
) {
let attr_matcher = attr_matcher.into();
let policy_ids = policy_ids.into();
if let Some(first_attr) = attr_matcher.iter().next() {
self.trigger_groups
.entry(*first_attr)
.or_default()
.push(PolicyTrigger {
attr_matcher,
policy_ids,
});
}
}
pub fn get_policy_count(&self) -> usize {
self.policies.len()
}
pub fn get_trigger_count(&self) -> usize {
self.trigger_groups.values().map(Vec::len).sum()
}
pub fn eval(
&self,
params: &AccessControlParams,
tracer: &mut impl PolicyTracer,
) -> Result<PolicyValue, EvalError> {
let mut eval_ctx = EvalCtx {
applicable_allow: Default::default(),
applicable_deny: Default::default(),
};
for attr in ¶ms.subject_attrs {
self.collect_applicable(*attr, params, &mut eval_ctx)?;
}
for attr in ¶ms.resource_attrs {
self.collect_applicable(*attr, params, &mut eval_ctx)?;
}
{
tracer.report_applicable(PolicyValue::Deny, eval_ctx.applicable_deny.keys().copied());
tracer.report_applicable(
PolicyValue::Allow,
eval_ctx.applicable_allow.keys().copied(),
);
}
let has_allow = !eval_ctx.applicable_allow.is_empty();
let has_deny = !eval_ctx.applicable_deny.is_empty();
match (has_allow, has_deny) {
(false, false) => {
for subj_attr in ¶ms.subject_attrs {
if params.resource_attrs.contains(subj_attr) {
return Ok(PolicyValue::Allow);
}
}
Ok(PolicyValue::Deny)
}
(true, false) => {
let is_allow =
eval_policies_disjunctive(eval_ctx.applicable_allow, params, tracer)?;
Ok(PolicyValue::from(is_allow))
}
(false, true) => {
let is_deny = eval_policies_disjunctive(eval_ctx.applicable_deny, params, tracer)?;
Ok(PolicyValue::from(!is_deny))
}
(true, true) => {
let is_allow =
eval_policies_disjunctive(eval_ctx.applicable_allow, params, tracer)?;
if !is_allow {
return Ok(PolicyValue::Deny);
}
let is_deny = eval_policies_disjunctive(eval_ctx.applicable_deny, params, tracer)?;
Ok(PolicyValue::from(!is_deny))
}
}
}
fn collect_applicable<'e>(
&'e self,
attr: AttrId,
params: &AccessControlParams,
eval_ctx: &mut EvalCtx<'e>,
) -> Result<(), EvalError> {
let Some(policy_triggers) = self.trigger_groups.get(&attr) else {
return Ok(());
};
for policy_trigger in policy_triggers {
if policy_trigger.attr_matcher.len() > 1 {
let mut matches: BTreeSet<AttrId> = Default::default();
for attrs in [¶ms.subject_attrs, ¶ms.resource_attrs] {
for attr in attrs {
if policy_trigger.attr_matcher.contains(attr) {
matches.insert(*attr);
}
}
}
if matches != policy_trigger.attr_matcher {
continue;
}
}
for policy_id in policy_trigger.policy_ids.iter().copied() {
let Some(policy) = self.policies.get(&policy_id) else {
error!(?policy_id, "policy is missing");
continue;
};
match policy.class {
PolicyValue::Deny => {
eval_ctx.applicable_deny.insert(policy_id, policy);
}
PolicyValue::Allow => {
eval_ctx.applicable_allow.insert(policy_id, policy);
}
}
}
}
Ok(())
}
}
/// Evaluates `map`'s policies until one of them holds (logical OR over the group).
///
/// Each evaluation is bracketed by `report_policy_eval_start` / `report_policy_eval_end`
/// on `tracer`. The end hook is skipped for a policy whose evaluation fails. Policies are
/// visited in the map's iteration order.
///
/// # Errors
/// Returns the first [`EvalError`] produced before any policy evaluates to true.
fn eval_policies_disjunctive(
    map: FnvHashMap<PolicyId, &Policy>,
    params: &AccessControlParams,
    tracer: &mut impl PolicyTracer,
) -> Result<bool, EvalError> {
    map.iter()
        .map(|(&policy_id, policy)| {
            tracer.report_policy_eval_start(policy_id);
            let outcome = eval_policy(&policy.bytecode, params);
            if let Ok(holds) = outcome {
                tracer.report_policy_eval_end(holds);
            }
            outcome
        })
        // The lazy chain stops at the first policy that holds or fails; exhausting the map
        // means none held.
        .find(|outcome| !matches!(outcome, Ok(false)))
        .unwrap_or(Ok(false))
}
fn eval_policy(mut pc: &[u8], params: &AccessControlParams) -> Result<bool, EvalError> {
let mut stack: Vec<StackItem> = Vec::with_capacity(16);
while let Some(code) = pc.first() {
pc = &pc[1..];
let Ok(code) = Bytecode::try_from(*code) else {
return Err(EvalError::Program);
};
match code {
Bytecode::LoadSubjectId => {
let (key, next) = decode_id(pc)?;
let Some(id) = params.subject_eids.get(&key) else {
return Err(EvalError::Type);
};
stack.push(StackItem::EntityId(*id));
pc = next;
}
Bytecode::LoadSubjectAttrs => {
stack.push(StackItem::AttrIdSet(¶ms.subject_attrs));
}
Bytecode::LoadResourceId => {
let (key, next) = decode_id(pc)?;
let Some(id) = params.resource_eids.get(&key) else {
return Err(EvalError::Type);
};
stack.push(StackItem::EntityId(*id));
pc = next;
}
Bytecode::LoadResourceAttrs => {
stack.push(StackItem::AttrIdSet(¶ms.resource_attrs));
}
Bytecode::LoadConstEntityId => {
let (id, next) = decode_id(pc)?;
stack.push(StackItem::EntityId(id));
pc = next;
}
Bytecode::LoadConstAttrId => {
let (id, next) = decode_id(pc)?;
stack.push(StackItem::AttrId(id));
pc = next;
}
Bytecode::IsEq => {
let Some(a) = stack.pop() else {
return Err(EvalError::Type);
};
let Some(b) = stack.pop() else {
return Err(EvalError::Type);
};
let is_eq = match (a, b) {
(StackItem::AttrId(a), StackItem::AttrId(b)) => a == b,
(StackItem::EntityId(a), StackItem::EntityId(b)) => a == b,
(StackItem::AttrIdSet(set), StackItem::AttrId(id)) => set.contains(&id),
(StackItem::AttrId(id), StackItem::AttrIdSet(set)) => set.contains(&id),
_ => false,
};
stack.push(StackItem::Uint(if is_eq { 1 } else { 0 }));
}
Bytecode::SupersetOf => {
let Some(StackItem::AttrIdSet(a)) = stack.pop() else {
return Err(EvalError::Type);
};
let Some(StackItem::AttrIdSet(b)) = stack.pop() else {
return Err(EvalError::Type);
};
stack.push(StackItem::Uint(if a.is_superset(b) { 1 } else { 0 }));
}
Bytecode::IdSetContains => {
let Some(a) = stack.pop() else {
return Err(EvalError::Type);
};
let Some(b) = stack.pop() else {
return Err(EvalError::Type);
};
match (a, b) {
(StackItem::AttrIdSet(a), StackItem::AttrId(b)) => {
stack.push(StackItem::Uint(if a.contains(&b) { 1 } else { 0 }));
}
_ => {
return Err(EvalError::Type);
}
}
}
Bytecode::And => {
let Some(StackItem::Uint(rhs)) = stack.pop() else {
return Err(EvalError::Type);
};
let Some(StackItem::Uint(lhs)) = stack.pop() else {
return Err(EvalError::Type);
};
stack.push(StackItem::Uint(if rhs > 0 && lhs > 0 { 1 } else { 0 }));
}
Bytecode::Or => {
let Some(StackItem::Uint(rhs)) = stack.pop() else {
return Err(EvalError::Type);
};
let Some(StackItem::Uint(lhs)) = stack.pop() else {
return Err(EvalError::Type);
};
stack.push(StackItem::Uint(if rhs > 0 || lhs > 0 { 1 } else { 0 }));
}
Bytecode::Not => {
let Some(StackItem::Uint(val)) = stack.pop() else {
return Err(EvalError::Type);
};
stack.push(StackItem::Uint(if val > 0 { 0 } else { 1 }));
}
Bytecode::Return => {
let Some(StackItem::Uint(u)) = stack.pop() else {
return Err(EvalError::Type);
};
return Ok(u > 0);
}
}
}
Err(EvalError::Program)
}
/// Decodes one unsigned-varint `u128` id from the front of `buf`.
///
/// Returns the id together with the bytes that follow its encoding.
///
/// # Errors
/// Returns [`EvalError::Program`] whenever `unsigned_varint` rejects the input.
#[inline]
fn decode_id<K>(buf: &[u8]) -> Result<(Id128<K>, &[u8]), EvalError> {
    match unsigned_varint::decode::u128(buf) {
        Ok((raw, rest)) => Ok((Id128::from_uint(raw), rest)),
        Err(_) => Err(EvalError::Program),
    }
}