use std::ops::Range;
use std::sync::Arc;
use arrow_array::{Array, FixedSizeListArray, RecordBatch, UInt32Array};
pub use builder::IvfBuildParams;
use lance_core::Result;
use lance_linalg::{
distance::{DistanceType, MetricType},
kmeans::{compute_partitions_arrow_array, kmeans_find_partitions_arrow_array},
};
use tracing::instrument;
use crate::vector::ivf::transform::PartitionTransformer;
use crate::vector::{pq::ProductQuantizer, residual::ResidualTransform, transform::Transformer};
use super::pq::transform::PQTransformer;
use super::quantizer::Quantization;
use super::{quantizer::Quantizer, residual::compute_residual};
use super::{PART_ID_COLUMN, PQ_CODE_COLUMN};
pub mod builder;
pub mod shuffler;
pub mod storage;
mod transform;
pub fn new_ivf_transformer(
centroids: FixedSizeListArray,
metric_type: DistanceType,
transforms: Vec<Arc<dyn Transformer>>,
) -> IvfTransformer {
IvfTransformer::new(centroids, metric_type, transforms)
}
/// Build an [`IvfTransformer`] whose pipeline matches the kind of `quantizer`.
///
/// - Flat / binary-flat quantizers get the plain IVF pipeline (`IvfTransformer::new_flat`).
/// - Product quantizers get `IvfTransformer::with_pq`. This constructor never asks it to
///   append PQ codes (`with_pq_code = false`). Presumably encoding happens in a later
///   stage; TODO confirm with callers.
/// - Scalar quantizers get `IvfTransformer::with_sq`. The scalar quantizer value itself
///   is not consulted.
///
/// `range`, when given, is forwarded so the pipeline includes a partition filter over it.
/// This function currently always returns `Ok`.
pub fn new_ivf_transformer_with_quantizer(
    centroids: FixedSizeListArray,
    metric_type: MetricType,
    vector_column: &str,
    quantizer: Quantizer,
    range: Option<Range<u32>>,
) -> Result<IvfTransformer> {
    let transformer = match quantizer {
        Quantizer::Flat(_) | Quantizer::FlatBin(_) => {
            IvfTransformer::new_flat(centroids, metric_type, vector_column, range)
        }
        Quantizer::Product(pq) => {
            IvfTransformer::with_pq(centroids, metric_type, vector_column, pq, range, false)
        }
        Quantizer::Scalar(_) => {
            IvfTransformer::with_sq(centroids, metric_type, vector_column, range)
        }
    };
    Ok(transformer)
}
/// IVF transformer: runs a pipeline of [`Transformer`]s over record batches and exposes
/// centroid-based helpers (partition assignment, partition search, residuals).
#[derive(Debug)]
pub struct IvfTransformer {
    /// Partition centroids. Read by `compute_residual`, `compute_partitions` and
    /// `find_partitions`.
    centroids: FixedSizeListArray,
    /// Pipeline stages. `Transformer::transform` runs them in order, each stage
    /// receiving the previous stage's output batch.
    transforms: Vec<Arc<dyn Transformer>>,
    /// Distance type exactly as supplied by the caller.
    ///
    /// For cosine, the built-in constructors assign partitions with L2 after a normalize
    /// step, but this field still records cosine. The helper methods use this value.
    distance_type: DistanceType,
}
impl IvfTransformer {
pub fn new(
centroids: FixedSizeListArray,
metric_type: MetricType,
transforms: Vec<Arc<dyn Transformer>>,
) -> Self {
Self {
centroids,
distance_type: metric_type,
transforms,
}
}
pub fn new_flat(
centroids: FixedSizeListArray,
distance_type: DistanceType,
vector_column: &str,
range: Option<Range<u32>>,
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];
let dt = if distance_type == DistanceType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
vector_column,
)));
MetricType::L2
} else {
distance_type
};
let ivf_transform = Arc::new(PartitionTransformer::new(
centroids.clone(),
dt,
vector_column,
));
transforms.push(ivf_transform);
if let Some(range) = range {
transforms.push(Arc::new(transform::PartitionFilter::new(
PART_ID_COLUMN,
range,
)));
}
Self {
centroids,
distance_type,
transforms,
}
}
pub fn with_pq(
centroids: FixedSizeListArray,
distance_type: DistanceType,
vector_column: &str,
pq: ProductQuantizer,
range: Option<Range<u32>>,
with_pq_code: bool, ) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];
let mt = if distance_type == MetricType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
vector_column,
)));
MetricType::L2
} else {
distance_type
};
let partition_transform = Arc::new(PartitionTransformer::new(
centroids.clone(),
mt,
vector_column,
));
transforms.push(partition_transform);
if let Some(range) = range {
transforms.push(Arc::new(transform::PartitionFilter::new(
PART_ID_COLUMN,
range,
)));
}
if ProductQuantizer::use_residual(distance_type) {
transforms.push(Arc::new(ResidualTransform::new(
centroids.clone(),
PART_ID_COLUMN,
vector_column,
)));
}
if with_pq_code {
transforms.push(Arc::new(PQTransformer::new(
pq,
vector_column,
PQ_CODE_COLUMN,
)));
}
Self {
centroids,
distance_type,
transforms,
}
}
fn with_sq(
centroids: FixedSizeListArray,
metric_type: MetricType,
vector_column: &str,
range: Option<Range<u32>>,
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];
let mt = if metric_type == MetricType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
vector_column,
)));
MetricType::L2
} else {
metric_type
};
let partition_transformer = Arc::new(PartitionTransformer::new(
centroids.clone(),
mt,
vector_column,
));
transforms.push(partition_transformer);
if let Some(range) = range {
transforms.push(Arc::new(transform::PartitionFilter::new(
PART_ID_COLUMN,
range,
)));
}
Self {
centroids,
distance_type: metric_type,
transforms,
}
}
#[inline]
pub fn compute_residual(&self, data: &FixedSizeListArray) -> Result<FixedSizeListArray> {
compute_residual(&self.centroids, data, Some(self.distance_type), None)
}
#[inline]
pub fn compute_partitions(&self, data: &FixedSizeListArray) -> Result<UInt32Array> {
Ok(compute_partitions_arrow_array(&self.centroids, data, self.distance_type)?.into())
}
pub fn find_partitions(&self, query: &dyn Array, nprobes: usize) -> Result<UInt32Array> {
Ok(kmeans_find_partitions_arrow_array(
&self.centroids,
query,
nprobes,
self.distance_type,
)?)
}
}
impl Transformer for IvfTransformer {
    /// Run every pipeline stage over `batch`, feeding each stage the previous stage's
    /// output, and return the final batch.
    ///
    /// The first stage error is returned immediately and later stages are skipped. With an
    /// empty pipeline, the result is a clone of the input.
    #[instrument(name = "IvfTransformer::transform", level = "debug", skip_all)]
    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        self.transforms
            .iter()
            .try_fold(batch.clone(), |current, stage| stage.transform(&current))
    }
}