use arrow_array::builder::Int64Builder;
use arrow_array::{Array, Int64Array};
use arrow_schema::DataType;
use deepsize::DeepSizeOf;
use lance_io::encodings::plain::PlainDecoder;
use lance_io::encodings::Decoder;
use snafu::{location, Location};
use std::collections::BTreeMap;
use tokio::io::AsyncWriteExt;
use lance_core::{Error, Result};
use lance_io::traits::{Reader, Writer};
/// Location of a single page, stored as a `position` / `length` pair.
#[derive(Clone, Debug, PartialEq, DeepSizeOf)]
pub struct PageInfo {
/// Position of the page (presumably an offset into the file; confirm against the writer).
pub position: usize,
/// Length of the page. NOTE(review): the unit is not visible in this file; confirm.
pub length: usize,
}
impl PageInfo {
pub fn new(position: usize, length: usize) -> Self {
Self { position, length }
}
}
/// Lookup table of [`PageInfo`] entries, addressed by field id and then by batch id.
#[derive(Debug, Default, Clone, PartialEq, DeepSizeOf)]
pub struct PageTable {
/// Outer key: field id. Inner key: batch id.
pages: BTreeMap<i32, BTreeMap<i32, PageInfo>>,
}
impl PageTable {
pub async fn load<'a>(
reader: &dyn Reader,
position: usize,
min_field_id: i32,
max_field_id: i32,
num_batches: i32,
) -> Result<Self> {
if max_field_id < min_field_id {
return Err(Error::Internal {
message: format!(
"max_field_id {} is less than min_field_id {}",
max_field_id, min_field_id
),
location: location!(),
});
}
let field_ids = min_field_id..=max_field_id;
let num_columns = field_ids.clone().count();
let length = num_columns * num_batches as usize * 2;
let decoder = PlainDecoder::new(reader, &DataType::Int64, position, length)?;
let raw_arr = decoder.decode().await?;
let arr = raw_arr.as_any().downcast_ref::<Int64Array>().unwrap();
let mut pages = BTreeMap::default();
for (field_pos, field_id) in field_ids.enumerate() {
pages.insert(field_id, BTreeMap::default());
for batch in 0..num_batches {
let idx = field_pos as i32 * num_batches + batch;
let batch_position = &arr.value((idx * 2) as usize);
let batch_length = &arr.value((idx * 2 + 1) as usize);
pages.get_mut(&field_id).unwrap().insert(
batch,
PageInfo {
position: *batch_position as usize,
length: *batch_length as usize,
},
);
}
}
Ok(Self { pages })
}
pub async fn write(&self, writer: &mut dyn Writer, min_field_id: i32) -> Result<usize> {
if self.pages.is_empty() {
return Err(Error::InvalidInput {
source: "empty page table".into(),
location: location!(),
});
}
let observed_min = *self.pages.keys().min().unwrap();
if min_field_id > *self.pages.keys().min().unwrap() {
return Err(Error::invalid_input(
format!(
"field_id_offset {} is greater than the minimum field_id {}",
min_field_id, observed_min
),
location!(),
));
}
let max_field_id = *self.pages.keys().max().unwrap();
let field_ids = min_field_id..=max_field_id;
let pos = writer.tell().await?;
let num_batches = self
.pages
.values()
.flat_map(|c_map| c_map.keys().max())
.max()
.unwrap()
+ 1;
let mut builder =
Int64Builder::with_capacity(field_ids.clone().count() * num_batches as usize);
for field_id in field_ids {
for batch in 0..num_batches {
if let Some(page_info) = self.get(field_id, batch) {
builder.append_value(page_info.position as i64);
builder.append_value(page_info.length as i64);
} else {
builder.append_slice(&[0, 0]);
}
}
}
let arr = builder.finish();
writer
.write_all(arr.into_data().buffers()[0].as_slice())
.await?;
Ok(pos)
}
pub fn set(&mut self, field_id: i32, batch: i32, page_info: PageInfo) {
self.pages
.entry(field_id)
.or_default()
.insert(batch, page_info);
}
pub fn get(&self, field_id: i32, batch: i32) -> Option<&PageInfo> {
self.pages
.get(&field_id)
.and_then(|c_map| c_map.get(&batch))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    use lance_io::local::LocalObjectReader;

    /// `set` followed by `get` returns the stored entry.
    #[test]
    fn test_set_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);
        page_table.set(10, 20, page_info.clone());
        let actual = page_table.get(10, 20).unwrap();
        assert_eq!(actual, &page_info);
    }

    /// Writing a sparse table and loading it back yields the written entries
    /// plus zero-filled entries for every page that was not set.
    #[tokio::test]
    async fn test_roundtrip_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);

        // Add fields 10..14, skipping field 12; batch coverage varies per field.
        page_table.set(10, 2, page_info.clone());
        page_table.set(11, 1, page_info.clone());
        page_table.set(13, 0, page_info.clone());
        page_table.set(13, 1, page_info.clone());
        page_table.set(13, 2, page_info.clone());
        page_table.set(13, 3, page_info.clone());

        let test_dir = tempfile::tempdir().unwrap();
        let path = test_dir.path().join("test");

        // The field id offset is deliberately below the smallest field id in use.
        let starting_field_id = 9;

        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let pos = page_table
            .write(&mut writer, starting_field_id)
            .await
            .unwrap();
        writer.shutdown().await.unwrap();

        let reader = LocalObjectReader::open_local_path(&path, 1024, None)
            .await
            .unwrap();
        let actual = PageTable::load(
            reader.as_ref(),
            pos,
            starting_field_id,
            /* max_field_id= */ 13,
            /* num_batches= */ 4,
        )
        .await
        .unwrap();

        // Start from the table that was written (not from `actual`), so that a
        // corrupted real entry is detected as well as a missing default one.
        let mut expected = page_table.clone();
        let default_page_info = PageInfo::new(0, 0);
        let expected_default_pages = [
            (9, 0),
            (9, 1),
            (9, 2),
            (9, 3),
            (10, 0),
            (10, 1),
            (10, 3),
            (11, 0),
            (11, 2),
            (11, 3),
            (12, 0),
            (12, 1),
            (12, 2),
            (12, 3),
        ];
        for (field_id, batch) in expected_default_pages.iter() {
            expected.set(*field_id, *batch, default_page_info.clone());
        }

        assert_eq!(expected, actual);
    }

    /// Invalid inputs to `write` and `load` are reported as errors.
    #[tokio::test]
    async fn test_error_handling() {
        let mut page_table = PageTable::default();

        let test_dir = tempfile::tempdir().unwrap();
        let path = test_dir.path().join("test");

        // Writing an empty table is rejected.
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 1).await;
        assert!(res.is_err());
        assert!(
            matches!(res.unwrap_err(), Error::InvalidInput { source, .. } if source.to_string().contains("empty page table"))
        );

        // A field id offset above the smallest stored field id is rejected.
        let page_info = PageInfo::new(1, 2);
        page_table.set(0, 0, page_info.clone());
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 1).await;
        assert!(res.is_err());
        assert!(
            matches!(res.unwrap_err(), Error::InvalidInput { source, .. }
                if source.to_string().contains("field_id_offset 1 is greater than the minimum field_id 0"))
        );

        // Loading with an inverted field id range is rejected.
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let pos = page_table.write(&mut writer, 0).await.unwrap();

        let reader = LocalObjectReader::open_local_path(&path, 1024, None)
            .await
            .unwrap();
        let res = PageTable::load(reader.as_ref(), pos, 1, 0, 1).await;
        assert!(res.is_err());
        assert!(matches!(res.unwrap_err(), Error::Internal { message, .. }
            if message.contains("max_field_id 0 is less than min_field_id 1")));
    }
}