#[cfg(feature = "ffi")]
mod ffi;
use opendp_derive::bootstrap;
use crate::core::{Domain, Function, Metric, MetricSpace, StabilityMap, Transformation};
use crate::domains::{AtomDomain, VectorDomain};
use crate::error::Fallible;
use crate::metrics::{InsertDeleteDistance, IntDistance, SymmetricDistance};
use crate::traits::samplers::Shuffle;
use crate::traits::CheckAtom;
use std::cmp::Ordering;
/// Compile-time flag recording whether a dataset metric treats its data as ordered.
///
/// `make_resize` reads `ORDERED` on its input and output metrics to decide when
/// the data must be shuffled.
// NOTE(review): presumably `ORDERED == true` means the metric distinguishes datasets
// that differ only in record order — confirm against the metric definitions.
#[doc(hidden)]
pub trait IsMetricOrdered: Metric {
    /// `true` if the metric is order-sensitive, `false` otherwise.
    const ORDERED: bool;
}
// `SymmetricDistance` is treated as unordered.
impl IsMetricOrdered for SymmetricDistance {
    const ORDERED: bool = false;
}
// `InsertDeleteDistance` is treated as ordered.
impl IsMetricOrdered for InsertDeleteDistance {
    const ORDERED: bool = true;
}
#[bootstrap(
    features("contrib"),
    arguments(constant(rust_type = "$get_atom(get_type(input_domain))")),
    generics(TA(suppress), MI(suppress), MO(default = "SymmetricDistance"))
)]
/// Make a Transformation that either truncates or imputes records
/// with `constant` to match a provided `size`.
///
/// Inputs shorter than `size` are padded with copies of `constant`; when the output
/// metric is ordered, the padded data is shuffled. Inputs longer than `size` are cut
/// down to `size` records; when the input metric is unordered, the data is shuffled
/// first so that the kept records are a random subset rather than a prefix.
/// The stability map uses a constant of 2 (`d_out = 2 * d_in`).
///
/// Fails if `constant` is not a member of the input element domain, or if `size` is zero.
///
/// # Arguments
/// * `input_domain` - Domain of input data.
/// * `input_metric` - Metric of input data.
/// * `size` - Number of records in output data.
/// * `constant` - Value to impute with.
///
/// # Generics
/// * `TA` - Atomic type.
/// * `MI` - Input Metric. Must implement `IsMetricOrdered`, e.g. `InsertDeleteDistance` or `SymmetricDistance`.
/// * `MO` - Output Metric. Must implement `IsMetricOrdered`, e.g. `InsertDeleteDistance` or `SymmetricDistance`.
///
/// # Returns
/// A vector of the same type `TA`, but with the provided `size`.
pub fn make_resize<TA, MI, MO>(
    input_domain: VectorDomain<AtomDomain<TA>>,
    input_metric: MI,
    size: usize,
    constant: TA,
) -> Fallible<Transformation<VectorDomain<AtomDomain<TA>>, VectorDomain<AtomDomain<TA>>, MI, MO>>
where
    TA: 'static + Clone + CheckAtom,
    MI: IsMetricOrdered<Distance = IntDistance>,
    MO: IsMetricOrdered<Distance = IntDistance>,
    (VectorDomain<AtomDomain<TA>>, MI): MetricSpace,
    (VectorDomain<AtomDomain<TA>>, MO): MetricSpace,
{
    if !input_domain.element_domain.member(&constant)? {
        return fallible!(MakeTransformation, "constant must be a member of DA");
    }
    if size == 0 {
        return fallible!(MakeTransformation, "row size must be greater than zero");
    }

    Transformation::new(
        input_domain.clone(),
        input_domain.with_size(size),
        Function::new_fallible(move |arg: &Vec<TA>| {
            Ok(match arg.len().cmp(&size) {
                Ordering::Less | Ordering::Equal => {
                    // Pad up to `size` with copies of `constant`. The chained iterator's
                    // size hint is exact, so `collect` allocates the output once.
                    let mut data = arg
                        .iter()
                        .chain(std::iter::repeat(&constant).take(size - arg.len()))
                        .cloned()
                        .collect::<Vec<TA>>();
                    // If the output metric is ordered, shuffle the imputed values into the data.
                    if MO::ORDERED {
                        data.shuffle()?;
                    }
                    data
                }
                // An ordered input is truncated to its prefix; only those records need copying.
                Ordering::Greater if MI::ORDERED => arg[..size].to_vec(),
                Ordering::Greater => {
                    // The input metric is unordered, so shuffle the whole dataset before
                    // truncating; the kept records are then a random subset, not a prefix.
                    let mut data = arg.clone();
                    data.shuffle()?;
                    data[..size].to_vec()
                }
            })
        }),
        input_metric,
        MO::default(),
        StabilityMap::new_from_constant(2),
    )
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::domains::AtomDomain;

    /// Unordered input and output: padding appends `constant`; truncation keeps `size` records.
    #[test]
    fn test() -> Fallible<()> {
        let (input_domain, input_metric) = (
            VectorDomain::new(AtomDomain::default()),
            SymmetricDistance::default(),
        );
        let trans = make_resize::<_, SymmetricDistance, SymmetricDistance>(
            input_domain,
            input_metric,
            3,
            "x",
        )?;
        assert_eq!(trans.invoke(&vec!["A"; 2])?, vec!["A", "A", "x"]);
        assert_eq!(trans.invoke(&vec!["A"; 3])?, vec!["A"; 3]);
        assert_eq!(trans.invoke(&vec!["A"; 4])?, vec!["A", "A", "A"]);
        assert!(trans.check(&1, &2)?);
        assert!(!trans.check(&1, &1)?);
        Ok(())
    }

    /// Ordered input, unordered output: neither branch shuffles, so results are deterministic.
    #[test]
    fn test_ordered_input() -> Fallible<()> {
        let trans = make_resize::<_, InsertDeleteDistance, SymmetricDistance>(
            VectorDomain::new(AtomDomain::default()),
            InsertDeleteDistance::default(),
            3,
            0i32,
        )?;
        // truncation keeps the prefix
        assert_eq!(trans.invoke(&vec![1, 2, 3, 4, 5])?, vec![1, 2, 3]);
        // padding appends the constant at the end
        assert_eq!(trans.invoke(&vec![7])?, vec![7, 0, 0]);
        assert_eq!(trans.invoke(&vec![4, 5, 6])?, vec![4, 5, 6]);
        assert!(trans.check(&1, &2)?);
        Ok(())
    }

    /// Ordered output: padded data is shuffled, so only the multiset of values is checked.
    #[test]
    fn test_ordered_output() -> Fallible<()> {
        let trans = make_resize::<_, SymmetricDistance, InsertDeleteDistance>(
            VectorDomain::new(AtomDomain::default()),
            SymmetricDistance::default(),
            4,
            0i32,
        )?;
        let mut padded = trans.invoke(&vec![1, 2])?;
        padded.sort();
        assert_eq!(padded, vec![0, 0, 1, 2]);

        let truncated = trans.invoke(&vec![1, 2, 3, 4, 5, 6])?;
        assert_eq!(truncated.len(), 4);
        assert!(truncated.iter().all(|x| (1..=6).contains(x)));
        Ok(())
    }

    /// A target size of zero is rejected at construction time.
    #[test]
    fn test_zero_size_rejected() {
        let result = make_resize::<_, SymmetricDistance, SymmetricDistance>(
            VectorDomain::new(AtomDomain::default()),
            SymmetricDistance::default(),
            0,
            0i32,
        );
        assert!(result.is_err());
    }
}