use std::sync::Arc;
use arrow_array::{Array, FixedSizeListArray, Float32Array};
use arrow_schema::ArrowError;
pub mod cosine;
pub mod dot;
pub mod hamming;
pub mod l2;
pub mod norm_l2;
pub use cosine::*;
use deepsize::DeepSizeOf;
pub use dot::*;
pub use l2::*;
pub use norm_l2::*;
use crate::Result;
#[derive(Debug, Copy, Clone, PartialEq, DeepSizeOf)]
pub enum DistanceType {
L2,
Cosine,
Dot,
Hamming,
}
pub type MetricType = DistanceType;
pub type DistanceFunc<T> = fn(&[T], &[T]) -> f32;
pub type BatchDistanceFunc = fn(&[f32], &[f32], usize) -> Arc<Float32Array>;
pub type ArrowBatchDistanceFunc = fn(&dyn Array, &FixedSizeListArray) -> Result<Arc<Float32Array>>;
impl DistanceType {
pub fn arrow_batch_func(&self) -> ArrowBatchDistanceFunc {
match self {
Self::L2 => l2_distance_arrow_batch,
Self::Cosine => cosine_distance_arrow_batch,
Self::Dot => dot_distance_arrow_batch,
Self::Hamming => todo!(),
}
}
pub fn func<T: L2 + Cosine + Dot>(&self) -> DistanceFunc<T> {
match self {
Self::L2 => l2,
Self::Cosine => cosine_distance,
Self::Dot => dot_distance,
Self::Hamming => todo!(),
}
}
}
impl std::fmt::Display for DistanceType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
Self::L2 => "l2",
Self::Cosine => "cosine",
Self::Dot => "dot",
Self::Hamming => "hamming",
}
)
}
}
impl TryFrom<&str> for DistanceType {
type Error = ArrowError;
fn try_from(s: &str) -> std::result::Result<Self, Self::Error> {
match s.to_lowercase().as_str() {
"l2" | "euclidean" => Ok(Self::L2),
"cosine" => Ok(Self::Cosine),
"dot" => Ok(Self::Dot),
"hamming" => Ok(Self::Hamming),
_ => Err(ArrowError::InvalidArgumentError(format!(
"Metric type '{s}' is not supported"
))),
}
}
}