use std::collections::HashMap;
use std::fmt::Debug;
use std::{any::Any, ops::Bound, sync::Arc};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow_array::{ListArray, RecordBatch};
use arrow_schema::{Field, Schema};
use async_trait::async_trait;
use datafusion::functions_array::array_has;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion_common::{scalar::ScalarValue, Column};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::Expr;
use deepsize::DeepSizeOf;
use inverted::TokenizerConfig;
use lance_core::utils::mask::RowIdTreeMap;
use lance_core::{Error, Result};
use snafu::{location, Location};
use crate::{Index, IndexParams, IndexType};
pub mod bitmap;
pub mod btree;
pub mod expression;
pub mod flat;
pub mod inverted;
pub mod label_list;
pub mod lance_format;
/// Index name reported by `ScalarIndexParams::index_name`, independent of the
/// forced scalar index kind.
pub const LANCE_SCALAR_INDEX: &str = "__lance_scalar_index";
/// The concrete kind of scalar index to build.
///
/// Obtained from a general [`IndexType`] via `TryFrom` and mapped back by
/// `ScalarIndexParams::index_type`. The enum is fieldless, so it is `Copy` and
/// can be compared or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalarIndexType {
    BTree,
    Bitmap,
    LabelList,
    Inverted,
}
impl TryFrom<IndexType> for ScalarIndexType {
type Error = Error;
fn try_from(value: IndexType) -> Result<Self> {
match value {
IndexType::BTree | IndexType::Scalar => Ok(Self::BTree),
IndexType::Bitmap => Ok(Self::Bitmap),
IndexType::LabelList => Ok(Self::LabelList),
IndexType::Inverted => Ok(Self::Inverted),
_ => Err(Error::InvalidInput {
source: format!("Index type {:?} is not a scalar index", value).into(),
location: location!(),
}),
}
}
}
/// Parameters for building a scalar index.
///
/// The default value forces no particular kind, in which case
/// `IndexParams::index_type` reports a B-tree index.
#[derive(Debug, Default)]
pub struct ScalarIndexParams {
    /// Explicit index kind to build; `None` falls back to a B-tree.
    pub force_index_type: Option<ScalarIndexType>,
}

impl ScalarIndexParams {
    /// Creates parameters that force the given scalar index kind.
    pub fn new(index_type: ScalarIndexType) -> Self {
        Self {
            force_index_type: Some(index_type),
        }
    }
}
impl IndexParams for ScalarIndexParams {
    fn as_any(&self) -> &dyn std::any::Any {
        self as &dyn std::any::Any
    }

    /// Maps the forced scalar index kind onto the general [`IndexType`].
    ///
    /// When no kind is forced the index is reported as a B-tree.
    fn index_type(&self) -> IndexType {
        self.force_index_type
            .as_ref()
            .map_or(IndexType::BTree, |forced| match forced {
                ScalarIndexType::BTree => IndexType::BTree,
                ScalarIndexType::Bitmap => IndexType::Bitmap,
                ScalarIndexType::LabelList => IndexType::LabelList,
                ScalarIndexType::Inverted => IndexType::Inverted,
            })
    }

    /// Returns [`LANCE_SCALAR_INDEX`] regardless of the forced kind.
    fn index_name(&self) -> &str {
        LANCE_SCALAR_INDEX
    }
}
/// Parameters for building an inverted index.
///
/// The `Default` impl enables `with_position` and uses
/// `TokenizerConfig::default()`.
#[derive(Clone)]
pub struct InvertedIndexParams {
    /// Whether positions are recorded; `true` by default.
    /// NOTE(review): presumably token positions within each document — confirm in `inverted`.
    pub with_position: bool,
    /// Tokenizer configuration for the index.
    pub tokenizer_config: TokenizerConfig,
}
impl Debug for InvertedIndexParams {
    /// Formats `with_position` only.
    ///
    /// `tokenizer_config` is intentionally left out. `finish_non_exhaustive`
    /// appends `..` so the output does not masquerade as the complete struct.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InvertedIndexParams")
            .field("with_position", &self.with_position)
            .finish_non_exhaustive()
    }
}
impl DeepSizeOf for InvertedIndexParams {
    /// Reports no heap memory owned by the params' fields.
    ///
    /// NOTE(review): `tokenizer_config` is not counted here. If it owns heap
    /// data, any size estimate built on this will undercount it.
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        // `with_position` is a plain `bool` and owns nothing on the heap.
        0usize
    }
}
impl Default for InvertedIndexParams {
    /// Position recording enabled, with the default tokenizer configuration.
    fn default() -> Self {
        let tokenizer_config = TokenizerConfig::default();
        Self {
            tokenizer_config,
            with_position: true,
        }
    }
}
impl InvertedIndexParams {
    /// Returns these params with `with_position` replaced, keeping every other field.
    pub fn with_position(self, with_position: bool) -> Self {
        Self {
            with_position,
            ..self
        }
    }
}
impl IndexParams for InvertedIndexParams {
    fn as_any(&self) -> &dyn std::any::Any {
        self as &dyn std::any::Any
    }

    /// These params always describe an [`IndexType::Inverted`] index.
    fn index_type(&self) -> IndexType {
        IndexType::Inverted
    }

    /// Fixed name reported for inverted index params.
    fn index_name(&self) -> &str {
        "INVERTED"
    }
}
/// Writer for a single index file, as returned by `IndexStore::new_index_file`.
#[async_trait]
pub trait IndexWriter: Send {
    /// Appends `batch` to the file.
    ///
    /// NOTE(review): the returned `u64` is presumably the batch's position or
    /// offset within the file — confirm against the implementations.
    async fn write_record_batch(&mut self, batch: RecordBatch) -> Result<u64>;
    /// Finishes writing the file.
    async fn finish(&mut self) -> Result<()>;
    /// Finishes writing the file, attaching the given key/value `metadata`.
    /// NOTE(review): where the metadata is stored (schema vs. footer) is not visible here.
    async fn finish_with_metadata(&mut self, metadata: HashMap<String, String>) -> Result<()>;
}
/// Reader for a single index file, as returned by `IndexStore::open_index_file`.
#[async_trait]
pub trait IndexReader: Send + Sync {
    /// Reads the record batch numbered `n`.
    async fn read_record_batch(&self, n: u32) -> Result<RecordBatch>;
    /// Reads the rows in `range`, optionally limited to the columns named in
    /// `projection` (presumably `None` means all columns — confirm).
    async fn read_range(
        &self,
        range: std::ops::Range<usize>,
        projection: Option<&[&str]>,
    ) -> Result<RecordBatch>;
    /// Number of record batches in the file.
    async fn num_batches(&self) -> u32;
    /// Number of rows in the file.
    fn num_rows(&self) -> usize;
    /// Lance schema of the data stored in the file.
    fn schema(&self) -> &lance_core::datatypes::Schema;
}
/// Storage for the named files that make up an index.
#[async_trait]
pub trait IndexStore: std::fmt::Debug + Send + Sync + DeepSizeOf {
    /// Returns `self` as [`Any`] so callers can downcast to the concrete store.
    fn as_any(&self) -> &dyn Any;
    /// Degree of I/O parallelism advertised by the store.
    /// NOTE(review): exact meaning (max concurrent requests?) is not visible here.
    fn io_parallelism(&self) -> usize;
    /// Creates a new file called `name` with the given Arrow `schema` and
    /// returns a writer for it.
    async fn new_index_file(&self, name: &str, schema: Arc<Schema>)
        -> Result<Box<dyn IndexWriter>>;
    /// Opens the file called `name` for reading.
    async fn open_index_file(&self, name: &str) -> Result<Arc<dyn IndexReader>>;
    /// Copies the file called `name` from this store into `dest_store`.
    async fn copy_index_file(&self, name: &str, dest_store: &dyn IndexStore) -> Result<()>;
}
/// A query that can be passed to `ScalarIndex::search` as a trait object.
pub trait AnyQuery: std::fmt::Debug + Any + Send + Sync {
    /// Returns `self` as [`Any`] so implementations can downcast (used by `dyn_eq`).
    fn as_any(&self) -> &dyn Any;
    /// Renders the query as a human-readable predicate on column `col`.
    fn format(&self, col: &str) -> String;
    /// Builds the equivalent DataFusion expression on the unqualified column `col`.
    fn to_expr(&self, col: String) -> Expr;
    /// Dynamic equality between queries of possibly different concrete types.
    /// The implementations in this module return `false` when `other` is a different type.
    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool;
}
impl PartialEq for dyn AnyQuery {
fn eq(&self, other: &Self) -> bool {
self.dyn_eq(other)
}
}
/// A full-text search request.
#[derive(Debug, Clone, PartialEq)]
pub struct FullTextSearchQuery {
    /// Columns to search; empty unless set via [`FullTextSearchQuery::columns`].
    pub columns: Vec<String>,
    /// The query text.
    pub query: String,
    /// Optional limit, presumably on the number of results — confirm with the searcher.
    pub limit: Option<i64>,
    /// Optional WAND factor. NOTE(review): its semantics are defined by the search
    /// implementation, which is not visible here.
    pub wand_factor: Option<f32>,
}

impl FullTextSearchQuery {
    /// Creates a query for `query` with no columns, no limit and no WAND factor.
    pub fn new(query: String) -> Self {
        Self {
            columns: Vec::new(),
            query,
            limit: None,
            wand_factor: None,
        }
    }

    /// Replaces the searched columns when `columns` is `Some`; `None` leaves the
    /// current columns untouched.
    pub fn columns(self, columns: Option<Vec<String>>) -> Self {
        match columns {
            Some(columns) => Self { columns, ..self },
            None => self,
        }
    }

    /// Sets the limit; unlike `columns`, passing `None` clears it.
    pub fn limit(self, limit: Option<i64>) -> Self {
        Self { limit, ..self }
    }

    /// Sets the WAND factor; passing `None` clears it.
    pub fn wand_factor(self, wand_factor: Option<f32>) -> Self {
        Self {
            wand_factor,
            ..self
        }
    }
}
/// A query on a single column; `format`/`to_expr` in its [`AnyQuery`] impl
/// define the predicate each variant stands for.
#[derive(Debug, Clone, PartialEq)]
pub enum SargableQuery {
    /// Value lies between the lower and upper bound; either end may be unbounded.
    Range(Bound<ScalarValue>, Bound<ScalarValue>),
    /// Value is one of the listed values (`col IN [...]`).
    IsIn(Vec<ScalarValue>),
    /// Value equals the given value (`col = value`).
    Equals(ScalarValue),
    /// Full-text search described by a [`FullTextSearchQuery`].
    FullTextSearch(FullTextSearchQuery),
    /// Value is null (`col IS NULL`).
    IsNull(),
}
impl AnyQuery for SargableQuery {
fn as_any(&self) -> &dyn Any {
self
}
fn format(&self, col: &str) -> String {
match self {
Self::Range(lower, upper) => match (lower, upper) {
(Bound::Unbounded, Bound::Unbounded) => "true".to_string(),
(Bound::Unbounded, Bound::Included(rhs)) => format!("{} <= {}", col, rhs),
(Bound::Unbounded, Bound::Excluded(rhs)) => format!("{} < {}", col, rhs),
(Bound::Included(lhs), Bound::Unbounded) => format!("{} >= {}", col, lhs),
(Bound::Included(lhs), Bound::Included(rhs)) => {
format!("{} >= {} && {} <= {}", col, lhs, col, rhs)
}
(Bound::Included(lhs), Bound::Excluded(rhs)) => {
format!("{} >= {} && {} < {}", col, lhs, col, rhs)
}
(Bound::Excluded(lhs), Bound::Unbounded) => format!("{} > {}", col, lhs),
(Bound::Excluded(lhs), Bound::Included(rhs)) => {
format!("{} > {} && {} <= {}", col, lhs, col, rhs)
}
(Bound::Excluded(lhs), Bound::Excluded(rhs)) => {
format!("{} > {} && {} < {}", col, lhs, col, rhs)
}
},
Self::IsIn(values) => {
format!(
"{} IN [{}]",
col,
values
.iter()
.map(|val| val.to_string())
.collect::<Vec<_>>()
.join(",")
)
}
Self::FullTextSearch(query) => {
format!("fts({})", query.query)
}
Self::IsNull() => {
format!("{} IS NULL", col)
}
Self::Equals(val) => {
format!("{} = {}", col, val)
}
}
}
fn to_expr(&self, col: String) -> Expr {
let col_expr = Expr::Column(Column::new_unqualified(col));
match self {
Self::Range(lower, upper) => match (lower, upper) {
(Bound::Unbounded, Bound::Unbounded) => {
Expr::Literal(ScalarValue::Boolean(Some(true)))
}
(Bound::Unbounded, Bound::Included(rhs)) => {
col_expr.lt_eq(Expr::Literal(rhs.clone()))
}
(Bound::Unbounded, Bound::Excluded(rhs)) => col_expr.lt(Expr::Literal(rhs.clone())),
(Bound::Included(lhs), Bound::Unbounded) => {
col_expr.gt_eq(Expr::Literal(lhs.clone()))
}
(Bound::Included(lhs), Bound::Included(rhs)) => {
col_expr.between(Expr::Literal(lhs.clone()), Expr::Literal(rhs.clone()))
}
(Bound::Included(lhs), Bound::Excluded(rhs)) => col_expr
.clone()
.gt_eq(Expr::Literal(lhs.clone()))
.and(col_expr.lt(Expr::Literal(rhs.clone()))),
(Bound::Excluded(lhs), Bound::Unbounded) => col_expr.gt(Expr::Literal(lhs.clone())),
(Bound::Excluded(lhs), Bound::Included(rhs)) => col_expr
.clone()
.gt(Expr::Literal(lhs.clone()))
.and(col_expr.lt_eq(Expr::Literal(rhs.clone()))),
(Bound::Excluded(lhs), Bound::Excluded(rhs)) => col_expr
.clone()
.gt(Expr::Literal(lhs.clone()))
.and(col_expr.lt(Expr::Literal(rhs.clone()))),
},
Self::IsIn(values) => col_expr.in_list(
values
.iter()
.map(|val| Expr::Literal(val.clone()))
.collect::<Vec<_>>(),
false,
),
Self::FullTextSearch(query) => {
col_expr.like(Expr::Literal(ScalarValue::Utf8(Some(query.query.clone()))))
}
Self::IsNull() => col_expr.is_null(),
Self::Equals(value) => col_expr.eq(Expr::Literal(value.clone())),
}
}
fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
match other.as_any().downcast_ref::<Self>() {
Some(o) => self == o,
None => false,
}
}
}
/// A query against a column whose values are lists of labels.
#[derive(Debug, Clone, PartialEq)]
pub enum LabelListQuery {
    /// Matches lists containing every one of the given labels (DataFusion `array_has_all`).
    HasAllLabels(Vec<ScalarValue>),
    /// Matches lists containing at least one of the given labels (DataFusion `array_has_any`).
    HasAnyLabel(Vec<ScalarValue>),
}

impl LabelListQuery {
    /// Packs `labels` into a one-row `List` literal, suitable as the second
    /// argument of `array_has_all` / `array_has_any`.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `ScalarValue::iter_to_array` rejects `labels`, e.g. an empty list or
    ///   labels of mixed types.
    /// - Any label is null, because the list's item field is declared non-nullable.
    /// - There are more labels than fit in `i32` list offsets.
    fn labels_literal(labels: &[ScalarValue]) -> Expr {
        let values = ScalarValue::iter_to_array(labels.iter().cloned())
            .expect("label list query needs a non-empty list of same-typed labels");
        let end = i32::try_from(values.len()).expect("label count exceeds i32 list offsets");
        // A single list row spanning all label values: offsets [0, len].
        let offsets = OffsetBuffer::new(ScalarBuffer::<i32>::from(vec![0, end]));
        // Build the item field before `values` is moved into the list array.
        let item_field = Arc::new(Field::new("item", values.data_type().clone(), false));
        let list = ListArray::try_new(item_field, offsets, values, None)
            .expect("labels must be non-null to fit the non-nullable list item field");
        Expr::Literal(ScalarValue::List(Arc::new(list)))
    }
}

impl AnyQuery for LabelListQuery {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Renders the query as the display form of its DataFusion expression.
    fn format(&self, col: &str) -> String {
        self.to_expr(col.to_string()).to_string()
    }

    /// Builds `array_has_all(col, [labels])` or `array_has_any(col, [labels])`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as the label-literal construction:
    /// - empty labels
    /// - labels of mixed types
    /// - null labels
    fn to_expr(&self, col: String) -> Expr {
        let column = Expr::Column(Column::new_unqualified(col));
        match self {
            Self::HasAllLabels(labels) => Expr::ScalarFunction(ScalarFunction {
                func: Arc::new(array_has::ArrayHasAll::new().into()),
                args: vec![column, Self::labels_literal(labels)],
            }),
            Self::HasAnyLabel(labels) => Expr::ScalarFunction(ScalarFunction {
                func: Arc::new(array_has::ArrayHasAny::new().into()),
                args: vec![column, Self::labels_literal(labels)],
            }),
        }
    }

    /// True only when `other` is also a `LabelListQuery` and compares equal.
    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
        match other.as_any().downcast_ref::<Self>() {
            Some(o) => self == o,
            None => false,
        }
    }
}
/// A scalar index that answers [`AnyQuery`] searches with sets of row ids.
#[async_trait]
pub trait ScalarIndex: Send + Sync + std::fmt::Debug + Index + DeepSizeOf {
    /// Searches the index for rows matching `query`.
    /// NOTE(review): whether results can contain false positives that callers must
    /// recheck is not visible here — confirm per implementation.
    async fn search(&self, query: &dyn AnyQuery) -> Result<RowIdTreeMap>;

    /// Loads the index from `store`.
    async fn load(store: Arc<dyn IndexStore>) -> Result<Arc<Self>>
    where
        Self: Sized;

    /// Remaps row ids using `mapping` and writes the result into `dest_store`.
    /// NOTE(review): a `None` value presumably means the row was removed — confirm.
    async fn remap(
        &self,
        mapping: &HashMap<u64, Option<u64>>,
        dest_store: &dyn IndexStore,
    ) -> Result<()>;

    /// Writes an index that incorporates `new_data` into `dest_store`.
    async fn update(
        &self,
        new_data: SendableRecordBatchStream,
        dest_store: &dyn IndexStore,
    ) -> Result<()>;
}