use std::sync::Arc;
use arrow_array::cast::AsArray;
use arrow_array::types::{Float16Type, Float32Type, Float64Type, UInt8Type};
use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray, Float32Array, ListArray};
use arrow_schema::{ArrowError, DataType};
pub mod cosine;
pub mod dot;
pub mod hamming;
pub mod l2;
pub mod norm_l2;
pub use cosine::*;
use deepsize::DeepSizeOf;
pub use dot::*;
use hamming::hamming_distance_arrow_batch;
pub use l2::*;
pub use norm_l2::*;
use crate::Result;
/// Distance metric selector.
///
/// Values are parsed from a name via `TryFrom<&str>` and rendered as a
/// lowercase name via `Display`.
#[derive(Debug, Copy, Clone, PartialEq, DeepSizeOf)]
pub enum DistanceType {
/// L2 distance; parsed from `"l2"` or `"euclidean"`, displayed as `"l2"`.
L2,
/// Cosine distance; parsed from and displayed as `"cosine"`.
Cosine,
/// Dot-product distance; parsed from and displayed as `"dot"`.
Dot,
/// Hamming distance; parsed from and displayed as `"hamming"`.
Hamming,
}
/// Alias of [`DistanceType`]; both names refer to the same enum.
pub type MetricType = DistanceType;
/// Function pointer that takes two slices of `T` and returns an `f32`.
///
/// [`DistanceType::func`] returns a value of this type.
pub type DistanceFunc<T> = fn(&[T], &[T]) -> f32;
/// Function pointer that takes two `f32` slices and a `usize` and returns a
/// shared [`Float32Array`].
///
/// NOTE(review): the `usize` is presumably the vector dimension — confirm
/// against the batch kernels.
pub type BatchDistanceFunc = fn(&[f32], &[f32], usize) -> Arc<Float32Array>;
/// Function pointer that takes a `&dyn Array` and a [`FixedSizeListArray`] and
/// returns a shared [`Float32Array`], or an error.
///
/// [`DistanceType::arrow_batch_func`] returns a value of this type.
pub type ArrowBatchDistanceFunc = fn(&dyn Array, &FixedSizeListArray) -> Result<Arc<Float32Array>>;
impl DistanceType {
pub fn arrow_batch_func(&self) -> ArrowBatchDistanceFunc {
match self {
Self::L2 => l2_distance_arrow_batch,
Self::Cosine => cosine_distance_arrow_batch,
Self::Dot => dot_distance_arrow_batch,
Self::Hamming => hamming_distance_arrow_batch,
}
}
pub fn func<T: L2 + Cosine + Dot>(&self) -> DistanceFunc<T> {
match self {
Self::L2 => l2,
Self::Cosine => cosine_distance,
Self::Dot => dot_distance,
Self::Hamming => todo!(),
}
}
}
impl std::fmt::Display for DistanceType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
Self::L2 => "l2",
Self::Cosine => "cosine",
Self::Dot => "dot",
Self::Hamming => "hamming",
}
)
}
}
impl TryFrom<&str> for DistanceType {
    type Error = ArrowError;

    /// Parse a metric name, ignoring letter case.
    ///
    /// Accepted names are `l2` (or its synonym `euclidean`), `cosine`, `dot`
    /// and `hamming`.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`], quoting the input exactly
    /// as given, for any other name.
    fn try_from(s: &str) -> std::result::Result<Self, Self::Error> {
        let normalized = s.to_lowercase();
        let distance_type = match normalized.as_str() {
            "l2" | "euclidean" => Self::L2,
            "cosine" => Self::Cosine,
            "dot" => Self::Dot,
            "hamming" => Self::Hamming,
            _ => {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "Metric type '{s}' is not supported"
                )));
            }
        };
        Ok(distance_type)
    }
}
/// Compute the distance between a multi-vector query and every row of `vectors`.
///
/// `query` is read as a flat primitive array holding one or more query vectors
/// of `dim` values each, where `dim` is the fixed list size of the vectors
/// stored in `vectors`. Trailing values that do not fill a whole vector are
/// ignored. For each row, the result is the sum, over all query vectors, of the
/// smallest distance from that query vector to any vector stored in the row.
///
/// # Returns
///
/// One `f32` per row of `vectors`. A row yields `NaN` when it is null, or when
/// some query vector has no non-NaN distance to any vector of the row (for
/// example because the row holds zero vectors).
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] when:
/// * `vectors` is not a list of fixed size lists, or its list size is not positive;
/// * `query` is not a `Float16`, `Float32`, `Float64` or `UInt8` array;
/// * `distance_type` and the query type do not fit together (Hamming needs
///   `UInt8`, every other metric needs a float type);
/// * the element type stored in `vectors` differs from the query type.
pub fn multivec_distance(
    query: &dyn Array,
    vectors: &ListArray,
    distance_type: DistanceType,
) -> Result<Vec<f32>> {
    let (value_type, dim) = match vectors.value_type() {
        DataType::FixedSizeList(field, dim) => (field.data_type().clone(), dim),
        _ => {
            return Err(ArrowError::InvalidArgumentError(
                "vectors must be a list of fixed size list".to_string(),
            ));
        }
    };
    // `chunks_exact` panics on a zero chunk size and a negative size would wrap
    // when cast, so reject both before touching any data.
    if dim <= 0 {
        return Err(ArrowError::InvalidArgumentError(format!(
            "vectors must have a positive dimension, got {dim}"
        )));
    }
    let dim = dim as usize;

    let query_type = query.data_type();
    match query_type {
        DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8 => {}
        _ => {
            return Err(ArrowError::InvalidArgumentError(
                "query must be a float array or binary array".to_string(),
            ));
        }
    }

    // Hamming only works on bytes and the other metrics only on floats; the
    // downcasts below would panic on any other combination.
    let binary_query = matches!(query_type, DataType::UInt8);
    let binary_metric = matches!(distance_type, DistanceType::Hamming);
    if binary_query != binary_metric {
        return Err(ArrowError::InvalidArgumentError(format!(
            "{distance_type} distance does not support {query_type:?} query vectors"
        )));
    }

    // Both sides are downcast to the same primitive type, so they must agree.
    if value_type != *query_type {
        return Err(ArrowError::InvalidArgumentError(format!(
            "query type {query_type:?} does not match vector type {value_type:?}"
        )));
    }

    let dists = vectors
        .iter()
        .map(|row| match row {
            // A null row has no vectors to compare against.
            None => f32::NAN,
            Some(row) => {
                let multivector = row.as_fixed_size_list();
                match query_type {
                    DataType::UInt8 => {
                        let query: &[u8] = query.as_primitive::<UInt8Type>().values();
                        let candidates: &[u8] =
                            multivector.values().as_primitive::<UInt8Type>().values();
                        sum_of_min_distances(query, candidates, dim, |q, c| {
                            hamming::hamming(q, c)
                        })
                    }
                    DataType::Float16 => multivec_distance_impl::<Float16Type>(
                        query,
                        multivector,
                        dim,
                        distance_type,
                    ),
                    DataType::Float32 => multivec_distance_impl::<Float32Type>(
                        query,
                        multivector,
                        dim,
                        distance_type,
                    ),
                    DataType::Float64 => multivec_distance_impl::<Float64Type>(
                        query,
                        multivector,
                        dim,
                        distance_type,
                    ),
                    _ => unreachable!("query type is validated above"),
                }
            }
        })
        .collect();
    Ok(dists)
}

/// Float implementation of the per-row multi-vector distance.
///
/// Downcasts `query` and the values of `multivector` to primitive arrays of
/// `T` and sums, over every `dim`-sized query vector, the smallest distance to
/// any vector of `multivector`.
///
/// # Panics
///
/// Panics if either array is not a primitive array of `T`, if `dim` is zero,
/// or if `distance_type` is [`DistanceType::Hamming`].
fn multivec_distance_impl<T: ArrowPrimitiveType>(
    query: &dyn Array,
    multivector: &FixedSizeListArray,
    dim: usize,
    distance_type: DistanceType,
) -> f32
where
    T::Native: L2 + Cosine + Dot,
{
    // Resolve the kernel and both value buffers once instead of per vector pair.
    let distance = distance_type.func::<T::Native>();
    let query: &[T::Native] = query.as_primitive::<T>().values();
    let candidates: &[T::Native] = multivector.values().as_primitive::<T>().values();
    sum_of_min_distances(query, candidates, dim, distance)
}

/// Sum, over every `dim`-sized vector in `query`, of the smallest `distance`
/// to any `dim`-sized vector in `candidates`.
///
/// NaN distances are skipped when taking the minimum. A query vector with no
/// non-NaN distance (including the case of empty `candidates`) contributes
/// NaN, which makes the whole sum NaN.
fn sum_of_min_distances<T>(
    query: &[T],
    candidates: &[T],
    dim: usize,
    distance: impl Fn(&[T], &[T]) -> f32,
) -> f32 {
    query
        .chunks_exact(dim)
        .map(|q| {
            candidates
                .chunks_exact(dim)
                .map(|c| distance(q, c))
                // `f32::min` returns the other operand when one is NaN, so
                // seeding with NaN yields the smallest real distance, or NaN
                // when there is none.
                .fold(f32::NAN, f32::min)
        })
        .sum()
}