1use arrow_array::builder::Int64Builder;
5use arrow_array::{Array, Int64Array};
6use arrow_schema::DataType;
7use deepsize::DeepSizeOf;
8use lance_io::encodings::plain::PlainDecoder;
9use lance_io::encodings::Decoder;
10use snafu::location;
11use std::collections::BTreeMap;
12use tokio::io::AsyncWriteExt;
13
14use lance_core::{Error, Result};
15use lance_io::traits::{Reader, Writer};
16
17#[derive(Clone, Debug, PartialEq, DeepSizeOf)]
18pub struct PageInfo {
19 pub position: usize,
20 pub length: usize,
21}
22
23impl PageInfo {
24 pub fn new(position: usize, length: usize) -> Self {
25 Self { position, length }
26 }
27}
28
29#[derive(Debug, Default, Clone, PartialEq, DeepSizeOf)]
32pub struct PageTable {
33 pages: BTreeMap<i32, BTreeMap<i32, PageInfo>>,
35}
36
37impl PageTable {
38 pub async fn load(
55 reader: &dyn Reader,
56 position: usize,
57 min_field_id: i32,
58 max_field_id: i32,
59 num_batches: i32,
60 ) -> Result<Self> {
61 if max_field_id < min_field_id {
62 return Err(Error::Internal {
63 message: format!(
64 "max_field_id {} is less than min_field_id {}",
65 max_field_id, min_field_id
66 ),
67 location: location!(),
68 });
69 }
70
71 let field_ids = min_field_id..=max_field_id;
72 let num_columns = field_ids.clone().count();
73 let length = num_columns * num_batches as usize * 2;
74 let decoder = PlainDecoder::new(reader, &DataType::Int64, position, length)?;
75 let raw_arr = decoder.decode().await?;
76 let arr = raw_arr.as_any().downcast_ref::<Int64Array>().unwrap();
77
78 let mut pages = BTreeMap::default();
79 for (field_pos, field_id) in field_ids.enumerate() {
80 pages.insert(field_id, BTreeMap::default());
81 for batch in 0..num_batches {
82 let idx = field_pos as i32 * num_batches + batch;
83 let batch_position = &arr.value((idx * 2) as usize);
84 let batch_length = &arr.value((idx * 2 + 1) as usize);
85 pages.get_mut(&field_id).unwrap().insert(
86 batch,
87 PageInfo {
88 position: *batch_position as usize,
89 length: *batch_length as usize,
90 },
91 );
92 }
93 }
94
95 Ok(Self { pages })
96 }
97
98 pub async fn write(&self, writer: &mut dyn Writer, min_field_id: i32) -> Result<usize> {
108 if self.pages.is_empty() {
109 return Err(Error::InvalidInput {
110 source: "empty page table".into(),
111 location: location!(),
112 });
113 }
114
115 let observed_min = *self.pages.keys().min().unwrap();
116 if min_field_id > *self.pages.keys().min().unwrap() {
117 return Err(Error::invalid_input(
118 format!(
119 "field_id_offset {} is greater than the minimum field_id {}",
120 min_field_id, observed_min
121 ),
122 location!(),
123 ));
124 }
125 let max_field_id = *self.pages.keys().max().unwrap();
126 let field_ids = min_field_id..=max_field_id;
127
128 let pos = writer.tell().await?;
129 let num_batches = self
130 .pages
131 .values()
132 .flat_map(|c_map| c_map.keys().max())
133 .max()
134 .unwrap()
135 + 1;
136
137 let mut builder =
138 Int64Builder::with_capacity(field_ids.clone().count() * num_batches as usize);
139 for field_id in field_ids {
140 for batch in 0..num_batches {
141 if let Some(page_info) = self.get(field_id, batch) {
142 builder.append_value(page_info.position as i64);
143 builder.append_value(page_info.length as i64);
144 } else {
145 builder.append_slice(&[0, 0]);
146 }
147 }
148 }
149 let arr = builder.finish();
150 writer
151 .write_all(arr.into_data().buffers()[0].as_slice())
152 .await?;
153
154 Ok(pos)
155 }
156
157 pub fn set(&mut self, field_id: i32, batch: i32, page_info: PageInfo) {
159 self.pages
160 .entry(field_id)
161 .or_default()
162 .insert(batch, page_info);
163 }
164
165 pub fn get(&self, field_id: i32, batch: i32) -> Option<&PageInfo> {
166 self.pages
167 .get(&field_id)
168 .and_then(|c_map| c_map.get(&batch))
169 }
170}
171
#[cfg(test)]
mod tests {

    use super::*;
    use pretty_assertions::assert_eq;

    use lance_io::local::LocalObjectReader;

    /// `set` followed by `get` returns the stored entry.
    #[test]
    fn test_set_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);
        page_table.set(10, 20, page_info.clone());

        let actual = page_table.get(10, 20).unwrap();
        assert_eq!(actual, &page_info);
    }

    /// Writing a sparse table and loading it back yields the same entries,
    /// with every missing cell filled in as `PageInfo::new(0, 0)`.
    #[tokio::test]
    async fn test_roundtrip_page_info() {
        let mut page_table = PageTable::default();
        let page_info = PageInfo::new(1, 2);

        // Fields 10, 11 and 13 with some batches missing; field 12 is a hole.
        page_table.set(10, 2, page_info.clone());
        page_table.set(11, 1, page_info.clone());
        page_table.set(13, 0, page_info.clone());
        page_table.set(13, 1, page_info.clone());
        page_table.set(13, 2, page_info.clone());
        page_table.set(13, 3, page_info.clone());

        let test_dir = tempfile::tempdir().unwrap();
        let path = test_dir.path().join("test");

        // Start below the smallest populated field id so that a fully empty
        // leading field (9) is serialized as well.
        let starting_field_id = 9;

        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let pos = page_table
            .write(&mut writer, starting_field_id)
            .await
            .unwrap();
        writer.shutdown().await.unwrap();

        let reader = LocalObjectReader::open_local_path(&path, 1024, None)
            .await
            .unwrap();
        // Field ids 9..=13, 4 batches per field.
        let actual = PageTable::load(reader.as_ref(), pos, starting_field_id, 13, 4)
            .await
            .unwrap();

        // The expectation must be built from the table that was written, not
        // from the loaded one; otherwise the comparison cannot detect a page
        // that was lost or corrupted during the round trip.
        let mut expected = page_table.clone();
        let default_page_info = PageInfo::new(0, 0);
        let expected_default_pages = [
            (9, 0),
            (9, 1),
            (9, 2),
            (9, 3),
            (10, 0),
            (10, 1),
            (10, 3),
            (11, 0),
            (11, 2),
            (11, 3),
            (12, 0),
            (12, 1),
            (12, 2),
            (12, 3),
        ];
        for (field_id, batch) in expected_default_pages.iter() {
            expected.set(*field_id, *batch, default_page_info.clone());
        }

        assert_eq!(expected, actual);
    }

    /// Invalid arguments to `write` and `load` are reported as errors.
    #[tokio::test]
    async fn test_error_handling() {
        let mut page_table = PageTable::default();

        let test_dir = tempfile::tempdir().unwrap();
        let path = test_dir.path().join("test");

        // An empty table cannot be written.
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 1).await;
        assert!(res.is_err());
        assert!(
            matches!(res.unwrap_err(), Error::InvalidInput { source, .. } if source.to_string().contains("empty page table"))
        );

        let page_info = PageInfo::new(1, 2);
        page_table.set(0, 0, page_info.clone());

        // The starting field id must not exceed the smallest stored field id.
        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 1).await;
        assert!(res.is_err());
        assert!(
            matches!(res.unwrap_err(), Error::InvalidInput { source, .. }
                if source.to_string().contains("field_id_offset 1 is greater than the minimum field_id 0"))
        );

        let mut writer = tokio::fs::File::create(&path).await.unwrap();
        let res = page_table.write(&mut writer, 0).await.unwrap();
        // Finish the write before the file is opened for reading.
        writer.shutdown().await.unwrap();

        let reader = LocalObjectReader::open_local_path(&path, 1024, None)
            .await
            .unwrap();

        // An inverted field id range is rejected on load.
        let res = PageTable::load(reader.as_ref(), res, 1, 0, 1).await;
        assert!(res.is_err());
        assert!(matches!(res.unwrap_err(), Error::Internal { message, .. }
            if message.contains("max_field_id 0 is less than min_field_id 1")));
    }
}