1use crate::ctx::Context;
2use ahash::{HashMap, HashMapExt, HashSet};
3use reblessive::tree::Stk;
4use revision::revisioned;
5use roaring::RoaringTreemap;
6use serde::{Deserialize, Serialize};
7use std::collections::hash_map::Entry;
8use std::collections::{BinaryHeap, VecDeque};
9use std::fmt::{Debug, Display, Formatter};
10use std::io::Cursor;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14use crate::err::Error;
15
16use crate::idx::docids::{DocId, DocIds};
17use crate::idx::planner::checker::MTreeConditionChecker;
18use crate::idx::planner::iterators::KnnIteratorResult;
19use crate::idx::trees::btree::BStatistics;
20use crate::idx::trees::knn::{Ids64, KnnResult, KnnResultBuilder, PriorityNode};
21use crate::idx::trees::store::{NodeId, StoredNode, TreeNode, TreeNodeProvider, TreeStore};
22use crate::idx::trees::vector::{SharedVector, Vector};
23use crate::idx::{IndexKeyBase, VersionedStore};
24use crate::kvs::{Key, Transaction, TransactionType, Val};
25use crate::sql::index::{Distance, MTreeParams, VectorType};
26use crate::sql::{Number, Object, Thing, Value};
27
/// Vector index backed by an M-Tree.
///
/// Bundles the key of the persisted tree state, the expected vector shape,
/// the node store, and the lock-protected document-id mapping and tree.
#[non_exhaustive]
pub struct MTreeIndex {
    /// Key under which the tree state is read in `new` and written in `finish`.
    state_key: Key,
    /// Dimension every indexed or queried vector is checked against.
    dim: usize,
    /// Element type used when converting values into vectors.
    vector_type: VectorType,
    /// Store used to load and persist the tree nodes.
    store: MTreeStore,
    /// Mapping between record ids and internal document ids.
    doc_ids: Arc<RwLock<DocIds>>,
    /// The tree itself, guarded by an async read/write lock.
    mtree: Arc<RwLock<MTree>>,
}
37
/// Parameters shared by every step of a single k-nearest-neighbour search.
struct MTreeSearchContext<'a> {
    /// Context handed to the store when nodes are loaded during the search.
    ctx: &'a Context,
    /// The query point.
    pt: SharedVector,
    /// Number of neighbours requested.
    k: usize,
    /// Store the tree nodes are read from.
    store: &'a MTreeStore,
}
44
45impl MTreeIndex {
46 pub async fn new(
47 txn: &Transaction,
48 ikb: IndexKeyBase,
49 p: &MTreeParams,
50 tt: TransactionType,
51 ) -> Result<Self, Error> {
52 let doc_ids = Arc::new(RwLock::new(
53 DocIds::new(txn, tt, ikb.clone(), p.doc_ids_order, p.doc_ids_cache).await?,
54 ));
55 let state_key = ikb.new_vm_key(None)?;
56 let state: MState = if let Some(val) = txn.get(state_key.clone(), None).await? {
57 VersionedStore::try_from(val)?
58 } else {
59 MState::new(p.capacity)
60 };
61 let store = txn
62 .index_caches()
63 .get_store_mtree(
64 TreeNodeProvider::Vector(ikb),
65 state.generation,
66 tt,
67 p.mtree_cache as usize,
68 )
69 .await?;
70 let mtree = Arc::new(RwLock::new(MTree::new(state, p.distance.clone())));
71 Ok(Self {
72 state_key,
73 dim: p.dimension as usize,
74 vector_type: p.vector_type,
75 doc_ids,
76 mtree,
77 store,
78 })
79 }
80
81 pub async fn index_document(
82 &mut self,
83 stk: &mut Stk,
84 txn: &Transaction,
85 rid: &Thing,
86 content: &[Value],
87 ) -> Result<(), Error> {
88 let mut doc_ids = self.doc_ids.write().await;
90 let resolved = doc_ids.resolve_doc_id(txn, revision::to_vec(rid)?).await?;
91 let doc_id = *resolved.doc_id();
92 drop(doc_ids);
93 let mut mtree = self.mtree.write().await;
95 for v in content.iter().filter(|v| v.is_some()) {
96 let vector = Vector::try_from_value(self.vector_type, self.dim, v)?;
98 vector.check_dimension(self.dim)?;
99 mtree.insert(stk, txn, &mut self.store, vector.into(), doc_id).await?;
101 }
102 drop(mtree);
103 Ok(())
104 }
105
106 pub async fn remove_document(
107 &mut self,
108 stk: &mut Stk,
109 txn: &Transaction,
110 rid: &Thing,
111 content: &[Value],
112 ) -> Result<(), Error> {
113 let mut doc_ids = self.doc_ids.write().await;
114 let doc_id = doc_ids.remove_doc(txn, revision::to_vec(rid)?).await?;
115 drop(doc_ids);
116 if let Some(doc_id) = doc_id {
117 let mut mtree = self.mtree.write().await;
119 for v in content.iter().filter(|v| v.is_some()) {
120 let vector = Vector::try_from_value(self.vector_type, self.dim, v)?;
122 vector.check_dimension(self.dim)?;
123 mtree.delete(stk, txn, &mut self.store, vector.into(), doc_id).await?;
125 }
126 drop(mtree);
127 }
128 Ok(())
129 }
130
131 pub async fn knn_search(
132 &self,
133 stk: &mut Stk,
134 ctx: &Context,
135 v: &[Number],
136 k: usize,
137 mut chk: MTreeConditionChecker<'_>,
138 ) -> Result<VecDeque<KnnIteratorResult>, Error> {
139 let vector = Vector::try_from_vector(self.vector_type, v)?;
141 vector.check_dimension(self.dim)?;
142 let search = MTreeSearchContext {
144 ctx,
145 pt: vector.into(),
146 k,
147 store: &self.store,
148 };
149 let mtree = self.mtree.read().await;
151 let doc_ids = self.doc_ids.read().await;
152 let res = mtree.knn_search(&search, &doc_ids, stk, &mut chk).await?;
154 drop(mtree);
155 let res = chk.convert_result(&doc_ids, res.docs).await;
157 drop(doc_ids);
158 res
159 }
160
161 pub(crate) async fn statistics(&self, tx: &Transaction) -> Result<MtStatistics, Error> {
162 Ok(MtStatistics {
163 doc_ids: self.doc_ids.read().await.statistics(tx).await?,
164 })
165 }
166
167 pub async fn finish(&mut self, tx: &Transaction) -> Result<(), Error> {
168 let mut doc_ids = self.doc_ids.write().await;
169 doc_ids.finish(tx).await?;
170 drop(doc_ids);
171 let mut mtree = self.mtree.write().await;
172 if let Some(new_cache) = self.store.finish(tx).await? {
173 mtree.state.generation += 1;
174 tx.set(self.state_key.clone(), VersionedStore::try_into(&mtree.state)?, None).await?;
175 tx.index_caches().advance_store_mtree(new_cache);
176 }
177 drop(mtree);
178 Ok(())
179 }
180}
181
/// In-memory handle on the M-Tree: its persisted state plus the parameters
/// needed to navigate and rebalance it.
#[non_exhaustive]
struct MTree {
    /// Persisted state (root id, next node id, capacity, generation).
    state: MState,
    /// Distance function used to compare vectors.
    distance: Distance,
    /// Entry count below which a node is reported as underflown on deletion.
    minimum: usize,
}
190
191impl MTree {
192 fn new(state: MState, distance: Distance) -> Self {
193 let minimum = (state.capacity + 1) as usize / 2;
194 Self {
195 state,
196 distance,
197 minimum,
198 }
199 }
200
201 async fn knn_search(
202 &self,
203 search: &MTreeSearchContext<'_>,
204 doc_ids: &DocIds,
205 stk: &mut Stk,
206 chk: &mut MTreeConditionChecker<'_>,
207 ) -> Result<KnnResult, Error> {
208 #[cfg(debug_assertions)]
209 debug!("knn_search - pt: {:?} - k: {}", search.pt, search.k);
210 let mut queue = BinaryHeap::new();
211 let mut res = KnnResultBuilder::new(search.k);
212 if let Some(root_id) = self.state.root {
213 queue.push(PriorityNode::new(0.0, root_id));
214 }
215 #[cfg(debug_assertions)]
216 let mut visited_nodes = HashMap::default();
217 while let Some(e) = queue.pop() {
218 let id = e.id();
219 let node = search.store.get_node_txn(search.ctx, id).await?;
220 #[cfg(debug_assertions)]
221 {
222 debug!("Visit node id: {}", id);
223 if visited_nodes.insert(id, node.n.len()).is_some() {
224 return Err(fail!("MTree::knn_search"));
225 }
226 }
227 match node.n {
228 MTreeNode::Leaf(ref n) => {
229 #[cfg(debug_assertions)]
230 debug!("Leaf found - id: {} - len: {}", node.id, n.len(),);
231 for (o, p) in n {
232 let d = self.calculate_distance(o, &search.pt)?;
233 if res.check_add(d) {
234 #[cfg(debug_assertions)]
235 debug!("Add: {d} - obj: {o:?} - docs: {:?}", p.docs);
236 let mut docs = Ids64::Empty;
237 for doc in &p.docs {
238 if chk.check_truthy(stk, doc_ids, doc).await? {
239 if let Some(new_docs) = docs.insert(doc) {
240 docs = new_docs;
241 }
242 }
243 }
244 if !docs.is_empty() {
245 let evicted_docs = res.add(d, docs);
246 chk.expires(evicted_docs);
247 }
248 }
249 }
250 }
251 MTreeNode::Internal(ref n) => {
252 #[cfg(debug_assertions)]
253 debug!("Internal found - id: {} - {:?}", node.id, n);
254 for (o, p) in n {
255 let d = self.calculate_distance(o, &search.pt)?;
256 let min_dist = (d - p.radius).max(0.0);
257 if res.check_add(min_dist) {
258 debug!("Queue add - dist: {} - node: {}", min_dist, p.node);
259 queue.push(PriorityNode::new(min_dist, p.node));
260 }
261 }
262 }
263 }
264 }
265 Ok(res.build(
266 #[cfg(debug_assertions)]
267 visited_nodes,
268 ))
269 }
270}
271
/// Outcome reported by a node to its parent after an insertion.
enum InsertionResult {
    /// The document was added to an object that already existed in a leaf.
    DocAdded,
    /// The object was stored without a split; carries the covering radius
    /// computed for the node.
    CoveringRadius(f64),
    /// The node was split; carries the two routing objects and their
    /// properties that the parent must store in place of the old entry.
    PromotedEntries(SharedVector, RoutingProperties, SharedVector, RoutingProperties),
}
277
/// Outcome reported by a node to its parent after a deletion.
enum DeletionResult {
    /// The object is not stored in the visited leaf.
    NotFound,
    /// The leaf entry of the object was kept (it was not removed from the leaf).
    DocRemoved,
    /// An entry was removed and the node is still large enough; carries the
    /// covering radius computed for the node.
    CoveringRadius(f64),
    /// The node fell below the minimum size. Carries the node, which has not
    /// been written back to the store, and whether it was modified.
    Underflown(MStoredNode, bool),
}
284
285impl MTree {
287 fn new_node_id(&mut self) -> NodeId {
288 let new_node_id = self.state.next_node_id;
289 self.state.next_node_id += 1;
290 new_node_id
291 }
292
293 async fn insert(
294 &mut self,
295 stk: &mut Stk,
296 tx: &Transaction,
297 store: &mut MTreeStore,
298 obj: SharedVector,
299 id: DocId,
300 ) -> Result<(), Error> {
301 #[cfg(debug_assertions)]
302 debug!("Insert - obj: {:?} - doc: {}", obj, id);
303 if self.append(tx, store, &obj, id).await? {
305 return Ok(());
306 }
307 if let Some(root_id) = self.state.root {
308 let node = store.get_node_mut(tx, root_id).await?;
309 if let InsertionResult::PromotedEntries(o1, p1, o2, p2) =
311 self.insert_at_node(stk, tx, store, node, &None, obj, id).await?
312 {
313 self.create_new_internal_root(store, o1, p1, o2, p2).await?;
314 }
315 } else {
316 self.create_new_leaf_root(store, obj, id).await?;
317 }
318 Ok(())
319 }
320
321 async fn create_new_leaf_root(
322 &mut self,
323 store: &mut MTreeStore,
324 obj: SharedVector,
325 id: DocId,
326 ) -> Result<(), Error> {
327 let new_root_id = self.new_node_id();
328 let p = ObjectProperties::new_root(id);
329 let mut objects = LeafMap::with_capacity(1);
330 objects.insert(obj, p);
331 let new_root_node = store.new_node(new_root_id, MTreeNode::Leaf(objects))?;
332 store.set_node(new_root_node, true).await?;
333 self.set_root(Some(new_root_id));
334 Ok(())
335 }
336
337 async fn create_new_internal_root(
338 &mut self,
339 store: &mut MTreeStore,
340 o1: SharedVector,
341 p1: RoutingProperties,
342 o2: SharedVector,
343 p2: RoutingProperties,
344 ) -> Result<(), Error> {
345 let new_root_id = self.new_node_id();
346 #[cfg(debug_assertions)]
347 debug!(
348 "New internal root - node: {} - e1.node: {} - e1.obj: {:?} - e1.radius: {} - e2.node: {} - e2.obj: {:?} - e2.radius: {}",
349 new_root_id,
350 p1.node,
351 o1,
352 p1.radius,
353 p2.node,
354 o2,
355 p2.radius
356 );
357 let mut entries = InternalMap::new();
358 entries.insert(o1, p1);
359 entries.insert(o2, p2);
360 let new_root_node = store.new_node(new_root_id, MTreeNode::Internal(entries))?;
361 store.set_node(new_root_node, true).await?;
362 self.set_root(Some(new_root_id));
363 Ok(())
364 }
365
366 async fn append(
367 &mut self,
368 tx: &Transaction,
369 store: &mut MTreeStore,
370 object: &SharedVector,
371 id: DocId,
372 ) -> Result<bool, Error> {
373 let mut queue = BinaryHeap::new();
374 if let Some(root_id) = self.state.root {
375 queue.push(root_id);
376 }
377 while let Some(current) = queue.pop() {
378 let mut node = store.get_node_mut(tx, current).await?;
379 match node.n {
380 MTreeNode::Leaf(ref mut n) => {
381 if let Some(p) = n.get_mut(object) {
382 p.docs.insert(id);
383 store.set_node(node, true).await?;
384 return Ok(true);
385 }
386 }
387 MTreeNode::Internal(ref n) => {
388 for (o, p) in n {
389 let d = self.calculate_distance(o, object)?;
390 if d <= p.radius {
391 queue.push(p.node);
392 }
393 }
394 }
395 }
396 store.set_node(node, false).await?;
397 }
398 Ok(false)
399 }
400
401 #[allow(clippy::too_many_arguments)]
403 async fn insert_at_node(
404 &mut self,
405 stk: &mut Stk,
406 tx: &Transaction,
407 store: &mut MTreeStore,
408 node: MStoredNode,
409 parent_center: &Option<SharedVector>,
410 object: SharedVector,
411 doc: DocId,
412 ) -> Result<InsertionResult, Error> {
413 #[cfg(debug_assertions)]
414 debug!("insert_at_node - node: {} - doc: {} - obj: {:?}", node.id, doc, object);
415 match node.n {
416 MTreeNode::Leaf(n) => {
418 self.insert_node_leaf(store, node.id, node.key, n, parent_center, object, doc).await
419 }
420 MTreeNode::Internal(n) => {
422 self.insert_node_internal(
423 stk,
424 tx,
425 store,
426 node.id,
427 node.key,
428 n,
429 parent_center,
430 object,
431 doc,
432 )
433 .await
434 }
435 }
436 }
437
    /// Inserts `object` below an internal node.
    ///
    /// The insertion continues in the child whose routing object is the
    /// closest to `object`, and this node is then adjusted according to the
    /// child's outcome:
    /// - `PromotedEntries`: the child was split. Its entry is replaced by the
    ///   two promoted entries; this node is split in turn if it would exceed
    ///   the capacity.
    /// - `DocAdded`: nothing changed here; the node is handed back to the
    ///   store as not updated.
    /// - `CoveringRadius`: the radius of the child's entry is enlarged if the
    ///   child now covers a larger radius.
    #[allow(clippy::too_many_arguments)]
    async fn insert_node_internal(
        &mut self,
        stk: &mut Stk,
        tx: &Transaction,
        store: &mut MTreeStore,
        node_id: NodeId,
        node_key: Key,
        mut node: InternalNode,
        parent_center: &Option<SharedVector>,
        object: SharedVector,
        doc_id: DocId,
    ) -> Result<InsertionResult, Error> {
        // Choose the child whose routing object is the closest to the new object.
        let (best_entry_obj, mut best_entry) = self.find_closest(&node, &object)?;
        let best_node = store.get_node_mut(tx, best_entry.node).await?;
        let best_entry_obj_op = Some(best_entry_obj.clone());
        // Reborrow `self` so that it can be used inside the closure below.
        let this = &mut *self;
        // The recursive call goes through `stk.run`
        // (presumably to keep the native call stack shallow - confirm against reblessive).
        match stk
            .run(|stk| async {
                this.insert_at_node(stk, tx, store, best_node, &best_entry_obj_op, object, doc_id)
                    .await
            })
            .await?
        {
            InsertionResult::PromotedEntries(o1, mut p1, o2, mut p2) => {
                #[cfg(debug_assertions)]
                debug!(
                    "Promote to Node ID: {} - e1.node: {} - e1.obj: {:?} - e1.radius: {} - e2.node: {} - e2.obj: {:?} - e2.radius: {} ",
                    node_id, p1.node, o1, p1.radius, p2.node, o2, p2.radius
                );
                // The entry of the child that was split is replaced by the promoted entries.
                node.remove(&best_entry_obj);
                // Count the distinct routing objects the node would hold after the replacement.
                let mut nup: HashSet<SharedVector> = HashSet::from_iter(node.keys().cloned());
                nup.insert(o1.clone());
                nup.insert(o2.clone());
                if nup.len() <= self.state.capacity as usize {
                    // Both entries fit: set their distance to this node's own routing object.
                    if let Some(pc) = parent_center {
                        p1.parent_dist = self.calculate_distance(&o1, pc)?;
                        p2.parent_dist = self.calculate_distance(&o2, pc)?;
                    } else {
                        p1.parent_dist = 0.0;
                        p2.parent_dist = 0.0;
                    }
                    node.insert(o1, p1);
                    node.insert(o2, p2);
                    let max_dist = self.compute_internal_max_distance(&node);
                    Self::set_stored_node(store, node_id, node_key, node.into_mtree_node(), true)
                        .await?;
                    Ok(InsertionResult::CoveringRadius(max_dist))
                } else {
                    // Too many entries: split this node and promote two entries to the parent.
                    node.insert(o1, p1);
                    node.insert(o2, p2);
                    let (o1, p1, o2, p2) = self.split_node(store, node_id, node_key, node).await?;
                    Ok(InsertionResult::PromotedEntries(o1, p1, o2, p2))
                }
            }
            InsertionResult::DocAdded => {
                // Nothing changed at this level; hand the node back as not updated.
                store
                    .set_node(StoredNode::new(node.into_mtree_node(), node_id, node_key, 0), false)
                    .await?;
                Ok(InsertionResult::DocAdded)
            }
            InsertionResult::CoveringRadius(covering_radius) => {
                let mut updated = false;
                // Enlarge the radius of the child's entry if the child grew.
                if covering_radius > best_entry.radius {
                    #[cfg(debug_assertions)]
                    debug!(
                        "NODE: {} - BE_OBJ: {:?} - BE_RADIUS: {} -> {}",
                        node_id, best_entry_obj, best_entry.radius, covering_radius
                    );
                    best_entry.radius = covering_radius;
                    node.insert(best_entry_obj, best_entry);
                    updated = true;
                }
                let max_dist = self.compute_internal_max_distance(&node);
                #[cfg(debug_assertions)]
                debug!("NODE INTERNAL: {} - MAX_DIST: {:?}", node_id, max_dist);
                store
                    .set_node(
                        StoredNode::new(node.into_mtree_node(), node_id, node_key, 0),
                        updated,
                    )
                    .await?;
                Ok(InsertionResult::CoveringRadius(max_dist))
            }
        }
    }
531
532 fn find_closest(
533 &self,
534 node: &InternalNode,
535 object: &SharedVector,
536 ) -> Result<(SharedVector, RoutingProperties), Error> {
537 let mut closest = None;
538 let mut dist = f64::MAX;
539 for (o, p) in node {
540 let d = self.calculate_distance(o, object)?;
541 if d < dist {
542 closest = Some((o.clone(), p.clone()));
543 dist = d;
544 }
545 }
546 #[cfg(debug_assertions)]
547 debug!("Find closest {:?} - Res: {:?}", object, closest);
548 if let Some((o, p)) = closest {
549 Ok((o, p))
550 } else {
551 Err(fail!("MTree::find_closest"))
552 }
553 }
554
555 #[allow(clippy::too_many_arguments)]
556 async fn insert_node_leaf(
557 &mut self,
558 store: &mut MTreeStore,
559 node_id: NodeId,
560 node_key: Key,
561 mut node: LeafNode,
562 parent_center: &Option<SharedVector>,
563 object: SharedVector,
564 doc_id: DocId,
565 ) -> Result<InsertionResult, Error> {
566 match node.entry(object) {
567 Entry::Occupied(mut e) => {
568 e.get_mut().docs.insert(doc_id);
569 store
570 .set_node(StoredNode::new(node.into_mtree_node(), node_id, node_key, 0), true)
571 .await?;
572 return Ok(InsertionResult::DocAdded);
573 }
574 Entry::Vacant(e) => {
576 let parent_dist = if let Some(pc) = parent_center {
578 self.calculate_distance(pc, e.key())?
579 } else {
580 0.0
581 };
582 e.insert(ObjectProperties::new(parent_dist, doc_id));
583 }
584 };
585 if node.len() <= self.state.capacity as usize {
587 let max_dist = self.compute_leaf_max_distance(&node, parent_center)?;
588 #[cfg(debug_assertions)]
589 debug!("NODE LEAF: {} - MAX_DIST: {:?}", node_id, max_dist);
590 store
591 .set_node(StoredNode::new(node.into_mtree_node(), node_id, node_key, 0), true)
592 .await?;
593 Ok(InsertionResult::CoveringRadius(max_dist))
594 } else {
595 let (o1, p1, o2, p2) = self.split_node(store, node_id, node_key, node).await?;
598 Ok(InsertionResult::PromotedEntries(o1, p1, o2, p2))
599 }
600 }
601
    /// Records `new_root` as the root of the tree in the tree state
    /// (`None` means the tree is empty).
    fn set_root(&mut self, new_root: Option<NodeId>) {
        #[cfg(debug_assertions)]
        debug!("SET_ROOT: {:?}", new_root);
        self.state.root = new_root;
    }
607
608 async fn split_node<N>(
609 &mut self,
610 store: &mut MTreeStore,
611 node_id: NodeId,
612 node_key: Key,
613 mut node: N,
614 ) -> Result<(SharedVector, RoutingProperties, SharedVector, RoutingProperties), Error>
615 where
616 N: NodeVectors + Debug,
617 {
618 #[cfg(debug_assertions)]
619 debug!("Split node: {:?}", node);
620 let mut a2 = node.get_objects();
621 let (distances, p1, p2) = self.compute_distances_and_promoted_objects(&a2)?;
622
623 a2.sort_by(|o1, o2| {
625 let d1 = *distances.0.get(&(p1.clone(), o1.clone())).unwrap_or(&0.0);
626 let d2 = *distances.0.get(&(p1.clone(), o2.clone())).unwrap_or(&0.0);
627 d1.total_cmp(&d2)
628 });
629 let a1_size = a2.len() / 2;
630 let a1: Vec<SharedVector> = a2.drain(0..a1_size).collect();
631
632 let (node1, r1, o1) = node.extract_node(&distances, p1, a1)?;
633 let (node2, r2, o2) = node.extract_node(&distances, p2, a2)?;
634
635 let new_node_id = self.new_node_id();
637
638 let n = StoredNode::new(node1.into_mtree_node(), node_id, node_key, 0);
640 store.set_node(n, true).await?;
641 let n = store.new_node(new_node_id, node2.into_mtree_node())?;
642 store.set_node(n, true).await?;
643
644 let p1 = RoutingProperties {
646 node: node_id,
647 radius: r1,
648 parent_dist: 0.0,
649 };
650 let p2 = RoutingProperties {
651 node: new_node_id,
652 radius: r2,
653 parent_dist: 0.0,
654 };
655
656 #[cfg(debug_assertions)]
657 if p1.node == p2.node {
658 return Err(fail!("MTree::split_node"));
659 }
660 Ok((o1, p1, o2, p2))
661 }
662
663 fn compute_distances_and_promoted_objects(
665 &self,
666 objects: &[SharedVector],
667 ) -> Result<(DistanceCache, SharedVector, SharedVector), Error> {
668 let mut promo = None;
669 let mut max_dist = 0f64;
670 let n = objects.len();
671 let mut dist_cache = HashMap::new();
672 for (i, o1) in objects.iter().enumerate() {
673 for o2 in objects.iter().take(n).skip(i + 1) {
674 let distance = self.calculate_distance(o1, o2)?;
675 dist_cache.insert((o1.clone(), o2.clone()), distance);
676 dist_cache.insert((o2.clone(), o1.clone()), distance); #[cfg(debug_assertions)]
678 {
679 assert_eq!(self.calculate_distance(o2, o1)?, distance);
681 debug!(
682 "dist_cache - len: {} - dist: {} - o1: {:?} - o2: {:?})",
683 dist_cache.len(),
684 distance,
685 o1,
686 o2
687 );
688 }
689 if distance > max_dist {
690 promo = Some((o1.clone(), o2.clone()));
691 max_dist = distance;
692 }
693 }
694 }
695 #[cfg(debug_assertions)]
696 {
697 debug!("Promo: {:?}", promo);
698 assert_eq!(dist_cache.len(), n * n - n);
699 }
700 match promo {
701 None => Err(fail!("MTree::compute_distances_and_promoted_objects")),
702 Some((p1, p2)) => Ok((DistanceCache(dist_cache), p1, p2)),
703 }
704 }
705
706 fn compute_internal_max_distance(&self, node: &InternalNode) -> f64 {
707 let mut max_dist = 0f64;
708 for p in node.values() {
709 max_dist = max_dist.max(p.parent_dist + p.radius);
710 }
711 max_dist
712 }
713
714 fn compute_leaf_max_distance(
715 &self,
716 node: &LeafNode,
717 parent: &Option<SharedVector>,
718 ) -> Result<f64, Error> {
719 Ok(if let Some(p) = parent {
720 let mut max_dist = 0f64;
721 for o in node.keys() {
722 max_dist = max_dist.max(self.calculate_distance(p, o)?);
723 }
724 max_dist
725 } else {
726 0.0
727 })
728 }
729
730 fn calculate_distance(&self, v1: &SharedVector, v2: &SharedVector) -> Result<f64, Error> {
731 if v1.eq(v2) {
732 return Ok(0.0);
733 }
734 let dist = self.distance.calculate(v1, v2);
735 if dist.is_finite() {
736 Ok(dist)
737 } else {
738 Err(Error::InvalidVectorDistance {
739 left: v1.clone(),
740 right: v2.clone(),
741 dist,
742 })
743 }
744 }
745
746 async fn delete(
747 &mut self,
748 stk: &mut Stk,
749 tx: &Transaction,
750 store: &mut MTreeStore,
751 object: SharedVector,
752 doc_id: DocId,
753 ) -> Result<bool, Error> {
754 let mut deleted = false;
755 if let Some(root_id) = self.state.root {
756 let root_node = store.get_node_mut(tx, root_id).await?;
757 if let DeletionResult::Underflown(sn, n_updated) = self
758 .delete_at_node(stk, tx, store, root_node, &None, object, doc_id, &mut deleted)
759 .await?
760 {
761 match &sn.n {
762 MTreeNode::Internal(n) => match n.len() {
763 0 => {
764 store.remove_node(sn.id, sn.key).await?;
765 self.set_root(None);
766 return Ok(deleted);
767 }
768 1 => {
769 store.remove_node(sn.id, sn.key).await?;
770 let e = n.values().next().ok_or_else(|| fail!("MTree::delete"))?;
771 self.set_root(Some(e.node));
772 return Ok(deleted);
773 }
774 _ => {}
775 },
776 MTreeNode::Leaf(n) => {
777 if n.is_empty() {
778 store.remove_node(sn.id, sn.key).await?;
779 self.set_root(None);
780 return Ok(deleted);
781 }
782 }
783 }
784 store.set_node(sn, n_updated).await?;
785 }
786 }
787 Ok(deleted)
788 }
789
790 #[allow(clippy::too_many_arguments)]
792 async fn delete_at_node(
793 &mut self,
794 stk: &mut Stk,
795 tx: &Transaction,
796 store: &mut MTreeStore,
797 node: MStoredNode,
798 parent_center: &Option<SharedVector>,
799 object: SharedVector,
800 id: DocId,
801 deleted: &mut bool,
802 ) -> Result<DeletionResult, Error> {
803 #[cfg(debug_assertions)]
804 debug!("delete_at_node ID: {} - obj: {:?}", node.id, object);
805 match node.n {
807 MTreeNode::Leaf(n) => {
809 self.delete_node_leaf(
810 store,
811 node.id,
812 node.key,
813 n,
814 parent_center,
815 object,
816 id,
817 deleted,
818 )
819 .await
820 }
821 MTreeNode::Internal(n) => {
823 self.delete_node_internal(
824 stk,
825 tx,
826 store,
827 node.id,
828 node.key,
829 n,
830 parent_center,
831 object,
832 id,
833 deleted,
834 )
835 .await
836 }
837 }
838 }
839
    /// Deletes the document `id` of the vector `od` below an internal node.
    ///
    /// Every child whose covering radius contains `od` is visited. Depending
    /// on the outcome reported by a child:
    /// - `NotFound` / `DocRemoved`: nothing to do at this level.
    /// - `CoveringRadius`: the radius of the child's entry is updated when
    ///   the reported radius is larger.
    /// - `Underflown`: the child is merged with, or rebalanced against, a
    ///   sibling; when that happened the visit of the remaining children
    ///   stops.
    ///
    /// Finally the node itself is checked for underflow.
    #[allow(clippy::too_many_arguments)]
    async fn delete_node_internal(
        &mut self,
        stk: &mut Stk,
        tx: &Transaction,
        store: &mut MTreeStore,
        node_id: NodeId,
        node_key: Key,
        mut n_node: InternalNode,
        parent_center: &Option<SharedVector>,
        od: SharedVector,
        id: DocId,
        deleted: &mut bool,
    ) -> Result<DeletionResult, Error> {
        #[cfg(debug_assertions)]
        debug!("delete_node_internal ID: {} - DocID: {} - obj: {:?}", node_id, id, od);
        let mut on_objs = Vec::new();
        let mut n_updated = false;
        // Collect (as copies) the entries whose covering radius contains the object,
        // so that the node can be modified while they are visited.
        for (on_obj, on_entry) in &n_node {
            let on_od_dist = self.calculate_distance(on_obj, &od)?;
            #[cfg(debug_assertions)]
            debug!("on_od_dist: {:?} / {} / {}", on_obj, on_od_dist, on_entry.radius);
            if on_od_dist <= on_entry.radius {
                on_objs.push((on_obj.clone(), on_entry.clone()));
            }
        }
        #[cfg(debug_assertions)]
        debug!("on_objs: {:?}", on_objs);
        for (on_obj, mut on_entry) in on_objs {
            #[cfg(debug_assertions)]
            debug!("on_obj: {:?}", on_obj);
            let on_node = store.get_node_mut(tx, on_entry.node).await?;
            // Only needed by the debug logs below.
            #[cfg(debug_assertions)]
            let d_id = on_node.id;
            let on_obj_op = Some(on_obj.clone());
            // The recursive call goes through `stk.run`
            // (presumably to keep the native call stack shallow - confirm against reblessive).
            match stk
                .run(|stk| {
                    self.delete_at_node(
                        stk,
                        tx,
                        store,
                        on_node,
                        &on_obj_op,
                        od.clone(),
                        id,
                        deleted,
                    )
                })
                .await?
            {
                DeletionResult::NotFound => {
                    #[cfg(debug_assertions)]
                    debug!("delete_at_node ID {} => NotFound", d_id);
                }
                DeletionResult::DocRemoved => {
                    #[cfg(debug_assertions)]
                    debug!("delete_at_node ID {} => DocRemoved", d_id);
                }
                DeletionResult::CoveringRadius(r) => {
                    #[cfg(debug_assertions)]
                    debug!("delete_at_node ID {} => CoveringRadius", d_id);
                    // The radius of the entry is only ever enlarged here.
                    if r > on_entry.radius {
                        on_entry.radius = r;
                        n_node.insert(on_obj, on_entry);
                        n_updated = true;
                    }
                }
                DeletionResult::Underflown(sn, sn_updated) => {
                    #[cfg(debug_assertions)]
                    debug!("delete_at_node {} => Underflown", d_id);
                    // `true` means the child was merged or rebalanced with a sibling:
                    // the entries of this node changed, so the visit stops.
                    if self
                        .deletion_underflown(
                            tx,
                            store,
                            parent_center,
                            &mut n_node,
                            on_obj,
                            sn,
                            sn_updated,
                        )
                        .await?
                    {
                        n_updated = true;
                        break;
                    }
                }
            }
        }
        self.delete_node_internal_check_underflown(store, node_id, node_key, n_node, n_updated)
            .await
    }
937
938 async fn delete_node_internal_check_underflown(
939 &mut self,
940 store: &mut MTreeStore,
941 node_id: NodeId,
942 node_key: Key,
943 n_node: InternalNode,
944 n_updated: bool,
945 ) -> Result<DeletionResult, Error> {
946 if n_node.len() < self.minimum {
948 return Ok(DeletionResult::Underflown(
950 StoredNode::new(MTreeNode::Internal(n_node), node_id, node_key, 0),
951 n_updated,
952 ));
953 }
954 let max_dist = self.compute_internal_max_distance(&n_node);
956 Self::set_stored_node(store, node_id, node_key, n_node.into_mtree_node(), n_updated)
957 .await?;
958 Ok(DeletionResult::CoveringRadius(max_dist))
959 }
960
961 async fn set_stored_node(
962 store: &mut MTreeStore,
963 node_id: NodeId,
964 node_key: Key,
965 node: MTreeNode,
966 updated: bool,
967 ) -> Result<(), Error> {
968 store.set_node(StoredNode::new(node, node_id, node_key, 0), updated).await?;
969 Ok(())
970 }
971
972 #[allow(clippy::too_many_arguments)]
973 async fn deletion_underflown(
974 &mut self,
975 tx: &Transaction,
976 store: &mut MTreeStore,
977 parent_center: &Option<SharedVector>,
978 n_node: &mut InternalNode,
979 on_obj: SharedVector,
980 p: MStoredNode,
981 p_updated: bool,
982 ) -> Result<bool, Error> {
983 #[cfg(debug_assertions)]
984 debug!("deletion_underflown Node ID: {}", p.id);
985 let min = f64::NAN;
986 let mut onn = None;
987 for (onn_obj, onn_entry) in n_node.iter() {
989 if onn_entry.node != p.id {
990 let d = self.calculate_distance(&on_obj, onn_obj)?;
991 if min.is_nan() || d < min {
992 onn = Some((onn_obj.clone(), onn_entry.clone()));
993 }
994 }
995 }
996 #[cfg(debug_assertions)]
997 debug!("deletion_underflown - p_id: {} - onn: {:?} - n_len: {}", p.id, onn, n_node.len());
998 if let Some((onn_obj, onn_entry)) = onn {
999 #[cfg(debug_assertions)]
1000 debug!("deletion_underflown: onn_entry {}", onn_entry.node);
1001 let onn_child = store.get_node_mut(tx, onn_entry.node).await?;
1003 if onn_child.n.len() + p.n.len() <= self.state.capacity as usize {
1005 self.delete_underflown_fit_into_child(
1006 store, n_node, on_obj, p, onn_obj, onn_entry, onn_child,
1007 )
1008 .await?;
1009 } else {
1010 self.delete_underflown_redistribute(
1011 store,
1012 parent_center,
1013 n_node,
1014 on_obj,
1015 onn_obj,
1016 p,
1017 onn_child,
1018 )
1019 .await?;
1020 }
1021 return Ok(true);
1022 }
1023 store.set_node(p, p_updated).await?;
1024 Ok(false)
1025 }
1026
1027 #[allow(clippy::too_many_arguments)]
1028 async fn delete_underflown_fit_into_child(
1029 &mut self,
1030 store: &mut MTreeStore,
1031 n_node: &mut InternalNode,
1032 on_obj: SharedVector,
1033 p: MStoredNode,
1034 onn_obj: SharedVector,
1035 mut onn_entry: RoutingProperties,
1036 mut onn_child: MStoredNode,
1037 ) -> Result<(), Error> {
1038 #[cfg(debug_assertions)]
1039 debug!("deletion_underflown - fit into Node ID: {}", onn_child.id);
1040 n_node.remove(&on_obj);
1042 match &mut onn_child.n {
1043 MTreeNode::Internal(s) => {
1044 let p_node = p.n.internal()?;
1045 for (p_obj, mut p_entry) in p_node {
1047 p_entry.parent_dist = self.calculate_distance(&p_obj, &onn_obj)?;
1049 s.insert(p_obj, p_entry);
1051 }
1052 let mut radius = 0.0;
1054 for s_entry in s.values() {
1055 let d = s_entry.parent_dist + s_entry.radius;
1056 if d > radius {
1057 radius = d;
1058 }
1059 }
1060 if onn_entry.radius != radius {
1061 onn_entry.radius = radius;
1062 }
1063 n_node.insert(onn_obj, onn_entry);
1064 }
1065 MTreeNode::Leaf(s) => {
1066 let p_node = p.n.leaf()?;
1067 for (p_obj, mut p_entry) in p_node {
1069 p_entry.parent_dist = self.calculate_distance(&p_obj, &onn_obj)?;
1071 s.insert(p_obj, p_entry);
1073 }
1074 let mut radius = 0.0;
1076 for s_entry in s.values() {
1077 if s_entry.parent_dist > radius {
1078 radius = s_entry.parent_dist;
1079 }
1080 }
1081 if onn_entry.radius != radius {
1082 onn_entry.radius = radius;
1083 }
1084 n_node.insert(onn_obj, onn_entry);
1085 }
1086 }
1087 store.remove_node(p.id, p.key).await?;
1088 store.set_node(onn_child, true).await?;
1089 Ok(())
1090 }
1091
1092 #[allow(clippy::too_many_arguments)]
1093 async fn delete_underflown_redistribute(
1094 &mut self,
1095 store: &mut MTreeStore,
1096 parent_center: &Option<SharedVector>,
1097 n_node: &mut InternalNode,
1098 on_obj: SharedVector,
1099 onn_obj: SharedVector,
1100 mut p: MStoredNode,
1101 onn_child: MStoredNode,
1102 ) -> Result<(), Error> {
1103 #[cfg(debug_assertions)]
1104 debug!("deletion_underflown - delete_underflown_redistribute Node ID: {}", p.id);
1105 n_node.remove(&on_obj);
1107 n_node.remove(&onn_obj);
1108 p.n.merge(onn_child.n)?;
1110 let (o1, mut e1, o2, mut e2) = match p.n {
1112 MTreeNode::Internal(n) => self.split_node(store, p.id, p.key, n).await?,
1113 MTreeNode::Leaf(n) => self.split_node(store, p.id, p.key, n).await?,
1114 };
1115 if let Some(pc) = parent_center {
1116 e1.parent_dist = self.calculate_distance(&o1, pc)?;
1117 e2.parent_dist = self.calculate_distance(&o2, pc)?;
1118 } else {
1119 e1.parent_dist = 0.0;
1120 e2.parent_dist = 0.0;
1121 }
1122 n_node.insert(o1, e1);
1124 n_node.insert(o2, e2);
1125 store.remove_node(onn_child.id, onn_child.key).await?;
1126 Ok(())
1127 }
1128
1129 #[allow(clippy::too_many_arguments)]
1130 async fn delete_node_leaf(
1131 &mut self,
1132 store: &mut MTreeStore,
1133 node_id: NodeId,
1134 node_key: Key,
1135 mut leaf_node: LeafNode,
1136 parent_center: &Option<SharedVector>,
1137 od: SharedVector,
1138 id: DocId,
1139 deleted: &mut bool,
1140 ) -> Result<DeletionResult, Error> {
1141 #[cfg(debug_assertions)]
1142 debug!("delete_node_leaf - n_id: {} - obj: {:?} - doc: {}", node_id, od, id);
1143 let mut entry_removed = false;
1144 if let Entry::Occupied(mut e) = leaf_node.entry(od) {
1146 let p = e.get_mut();
1147 if p.docs.remove(id) {
1149 *deleted = true;
1150 #[cfg(debug_assertions)]
1151 debug!("deleted - n_id: {} - doc: {}", node_id, id);
1152 if p.docs.is_empty() {
1153 e.remove();
1154 entry_removed = true;
1155 }
1156 }
1157 } else {
1158 let sn = StoredNode::new(MTreeNode::Leaf(leaf_node), node_id, node_key, 0);
1159 store.set_node(sn, false).await?;
1160 return Ok(DeletionResult::NotFound);
1161 }
1162 if entry_removed {
1163 if leaf_node.len() < self.minimum {
1165 return Ok(DeletionResult::Underflown(
1166 StoredNode::new(MTreeNode::Leaf(leaf_node), node_id, node_key, 0),
1167 true,
1168 ));
1169 }
1170 let max_dist = self.compute_leaf_max_distance(&leaf_node, parent_center)?;
1172 let sn = StoredNode::new(MTreeNode::Leaf(leaf_node), node_id, node_key, 0);
1173 store.set_node(sn, true).await?;
1174 Ok(DeletionResult::CoveringRadius(max_dist))
1175 } else {
1176 let sn = StoredNode::new(MTreeNode::Leaf(leaf_node), node_id, node_key, 0);
1177 store.set_node(sn, true).await?;
1178 Ok(DeletionResult::DocRemoved)
1179 }
1180 }
1181}
1182
/// Cache of the distances computed between pairs of vectors while a node is
/// being split. `compute_distances_and_promoted_objects` stores every
/// distance under both orderings of its pair.
struct DistanceCache(HashMap<(SharedVector, SharedVector), f64>);
1184
/// Node store specialised for M-Tree nodes.
pub(in crate::idx) type MTreeStore = TreeStore<MTreeNode>;
/// An M-Tree node as handled by the store.
type MStoredNode = StoredNode<MTreeNode>;

/// Entries of an internal node: routing object -> routing properties.
type InternalMap = HashMap<SharedVector, RoutingProperties>;

/// Entries of a leaf: stored vector -> object properties.
type LeafMap = HashMap<SharedVector, ObjectProperties>;
1191
/// A node of the M-Tree.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum MTreeNode {
    /// Routing node: each entry refers to a child node.
    Internal(InternalNode),
    /// Leaf node: each entry is an indexed vector with its documents.
    Leaf(LeafNode),
}
1204
1205impl MTreeNode {
1206 fn len(&self) -> usize {
1207 match self {
1208 MTreeNode::Internal(e) => e.len(),
1209 MTreeNode::Leaf(m) => m.len(),
1210 }
1211 }
1212
1213 fn internal(self) -> Result<InternalNode, Error> {
1214 match self {
1215 MTreeNode::Internal(n) => Ok(n),
1216 MTreeNode::Leaf(_) => Err(fail!("MTreeNode::internal")),
1217 }
1218 }
1219
1220 fn leaf(self) -> Result<LeafNode, Error> {
1221 match self {
1222 MTreeNode::Internal(_) => Err(fail!("MTreeNode::lead")),
1223 MTreeNode::Leaf(n) => Ok(n),
1224 }
1225 }
1226
1227 fn merge(&mut self, other: MTreeNode) -> Result<(), Error> {
1228 match (self, other) {
1229 (MTreeNode::Internal(s), MTreeNode::Internal(o)) => {
1230 Self::merge_internal(s, o);
1231 Ok(())
1232 }
1233 (MTreeNode::Leaf(s), MTreeNode::Leaf(o)) => {
1234 Self::merge_leaf(s, o);
1235 Ok(())
1236 }
1237 (_, _) => Err(fail!("MTreeNode::merge")),
1238 }
1239 }
1240
1241 fn merge_internal(n: &mut InternalNode, other: InternalNode) {
1242 for (o, p) in other {
1243 n.insert(o, p);
1244 }
1245 }
1246
1247 fn merge_leaf(s: &mut LeafNode, o: LeafNode) {
1248 for (v, p) in o {
1249 match s.entry(v) {
1250 Entry::Occupied(mut e) => {
1251 let props = e.get_mut();
1252 for doc in p.docs {
1253 props.docs.insert(doc);
1254 }
1255 }
1256 Entry::Vacant(e) => {
1257 e.insert(p);
1258 }
1259 }
1260 }
1261 }
1262}
1263
1264impl Display for MTreeNode {
1265 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1266 match self {
1267 MTreeNode::Internal(i) => write!(f, "Internal: {i:?}"),
1268 MTreeNode::Leaf(l) => write!(f, "Leaf: {l:?}"),
1269 }
1270 }
1271}
1272
/// Operations shared by the two node payload types (`LeafNode` and `InternalNode`).
trait NodeVectors: Sized {
    /// Number of entries held by the node.
    #[allow(dead_code)]
    fn len(&self) -> usize;

    /// Returns clones of all the vectors used as keys in this node.
    fn get_objects(&self) -> Vec<SharedVector>;

    /// Moves the entries listed in `a` out of `self` into a new node centred on `p`.
    ///
    /// Each moved entry gets its `parent_dist` set to the cached distance to `p`
    /// (`0.0` when the pair is absent from `distances`).
    ///
    /// Returns the new node, its covering radius relative to `p`, and `p` itself.
    ///
    /// # Errors
    /// Fails if an element of `a` is not present in `self`.
    fn extract_node(
        &mut self,
        distances: &DistanceCache,
        p: SharedVector,
        a: Vec<SharedVector>,
    ) -> Result<(Self, f64, SharedVector), Error>;

    /// Wraps the payload in the matching `MTreeNode` variant.
    fn into_mtree_node(self) -> MTreeNode;
}
1288
1289impl NodeVectors for LeafNode {
1290 fn len(&self) -> usize {
1291 self.len()
1292 }
1293
1294 fn get_objects(&self) -> Vec<SharedVector> {
1295 self.keys().cloned().collect()
1296 }
1297
1298 fn extract_node(
1299 &mut self,
1300 distances: &DistanceCache,
1301 p: SharedVector,
1302 a: Vec<SharedVector>,
1303 ) -> Result<(Self, f64, SharedVector), Error> {
1304 let mut n = LeafNode::new();
1305 let mut r = 0f64;
1306 for o in a {
1307 let mut props =
1308 self.remove(&o).ok_or_else(|| fail!("NodeVectors/LeafNode::extract_node)"))?;
1309 let dist = *distances.0.get(&(o.clone(), p.clone())).unwrap_or(&0f64);
1310 if dist > r {
1311 r = dist;
1312 }
1313 props.parent_dist = dist;
1314 n.insert(o, props);
1315 }
1316 Ok((n, r, p))
1317 }
1318
1319 fn into_mtree_node(self) -> MTreeNode {
1320 MTreeNode::Leaf(self)
1321 }
1322}
1323
1324impl NodeVectors for InternalNode {
1325 fn len(&self) -> usize {
1326 self.len()
1327 }
1328
1329 fn get_objects(&self) -> Vec<SharedVector> {
1330 self.keys().cloned().collect()
1331 }
1332
1333 fn extract_node(
1334 &mut self,
1335 distances: &DistanceCache,
1336 p: SharedVector,
1337 a: Vec<SharedVector>,
1338 ) -> Result<(Self, f64, SharedVector), Error> {
1339 let mut n = InternalNode::new();
1340 let mut max_r = 0f64;
1341 for o in a {
1342 let mut props =
1343 self.remove(&o).ok_or_else(|| fail!("NodeVectors/InternalNode::extract_node"))?;
1344 let dist = *distances.0.get(&(o.clone(), p.clone())).unwrap_or(&0f64);
1345 let r = dist + props.radius;
1346 if r > max_r {
1347 max_r = r;
1348 }
1349 props.parent_dist = dist;
1350 n.insert(o, props);
1351 }
1352 Ok((n, max_r, p))
1353 }
1354
1355 fn into_mtree_node(self) -> MTreeNode {
1356 MTreeNode::Internal(self)
1357 }
1358}
1359
/// Payload of an internal node (routing object -> routing properties).
pub type InternalNode = InternalMap;
/// Payload of a leaf node (object -> object properties).
pub type LeafNode = LeafMap;
1362
1363impl TreeNode for MTreeNode {
1364 fn try_from_val(val: Val) -> Result<Self, Error> {
1365 let mut c: Cursor<Vec<u8>> = Cursor::new(val);
1366 let node_type: u8 = bincode::deserialize_from(&mut c)?;
1367 match node_type {
1368 1u8 => {
1369 let objects: LeafNode = bincode::deserialize_from(c)?;
1370 Ok(MTreeNode::Leaf(objects))
1371 }
1372 2u8 => {
1373 let entries: InternalNode = bincode::deserialize_from(c)?;
1374 Ok(MTreeNode::Internal(entries))
1375 }
1376 _ => Err(Error::CorruptedIndex("MTreeNode::try_from_val")),
1377 }
1378 }
1379
1380 fn try_into_val(&self) -> Result<Val, Error> {
1381 let mut c: Cursor<Vec<u8>> = Cursor::new(Vec::new());
1382 match self {
1383 MTreeNode::Leaf(objects) => {
1384 bincode::serialize_into(&mut c, &1u8)?;
1385 bincode::serialize_into(&mut c, objects)?;
1386 }
1387 MTreeNode::Internal(entries) => {
1388 bincode::serialize_into(&mut c, &2u8)?;
1389 bincode::serialize_into(&mut c, entries)?;
1390 }
1391 };
1392 Ok(c.into_inner())
1393 }
1394}
1395
1396pub(crate) struct MtStatistics {
1397 doc_ids: BStatistics,
1398}
1399
1400impl From<MtStatistics> for Value {
1401 fn from(stats: MtStatistics) -> Self {
1402 let mut res = Object::default();
1403 res.insert("doc_ids".to_owned(), Value::from(stats.doc_ids));
1404 Value::from(res)
1405 }
1406}
1407
1408#[revisioned(revision = 2)]
1409#[derive(Clone, Serialize, Deserialize)]
1410#[non_exhaustive]
1411pub struct MState {
1412 capacity: u16,
1413 root: Option<NodeId>,
1414 next_node_id: NodeId,
1415 #[revision(start = 2)]
1416 generation: u64,
1417}
1418
1419impl MState {
1420 pub fn new(capacity: u16) -> Self {
1421 assert!(capacity >= 2, "Capacity should be >= 2");
1422 Self {
1423 capacity,
1424 root: None,
1425 next_node_id: 0,
1426 generation: 0,
1427 }
1428 }
1429}
1430
/// Properties attached to a routing object of an internal node.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[non_exhaustive]
pub struct RoutingProperties {
    /// Identifier of the child node this entry points to.
    node: NodeId,
    /// Distance between the routing object and the center of its parent node.
    parent_dist: f64,
    /// Covering radius of the child node around the routing object.
    radius: f64,
}
1441
1442#[derive(Serialize, Deserialize, Debug, Clone)]
1443#[non_exhaustive]
1444pub struct ObjectProperties {
1445 parent_dist: f64,
1447 docs: RoaringTreemap,
1449}
1450
1451impl ObjectProperties {
1452 fn new(parent_dist: f64, id: DocId) -> Self {
1453 let mut docs = RoaringTreemap::new();
1454 docs.insert(id);
1455 Self {
1456 parent_dist,
1457 docs,
1458 }
1459 }
1460
1461 fn new_root(id: DocId) -> Self {
1462 Self::new(0.0, id)
1463 }
1464}
1465
// `MState` overrides nothing: it relies on whatever `VersionedStore` provides by default.
impl VersionedStore for MState {}
1467
1468#[cfg(test)]
1469mod tests {
1470
1471 use crate::ctx::{Context, MutableContext};
1472 use crate::err::Error;
1473 use crate::idx::docids::{DocId, DocIds};
1474 use crate::idx::planner::checker::MTreeConditionChecker;
1475 use crate::idx::trees::knn::tests::TestCollection;
1476 use crate::idx::trees::mtree::{MState, MTree, MTreeNode, MTreeSearchContext, MTreeStore};
1477 use crate::idx::trees::store::{NodeId, TreeNodeProvider, TreeStore};
1478 use crate::idx::trees::vector::SharedVector;
1479 use crate::idx::IndexKeyBase;
1480 use crate::kvs::LockType::*;
1481 use crate::kvs::Transaction;
1482 use crate::kvs::{Datastore, TransactionType};
1483 use crate::sql::index::{Distance, VectorType};
1484 use ahash::{HashMap, HashMapExt, HashSet};
1485 use reblessive::tree::Stk;
1486 use std::collections::VecDeque;
1487 use test_log::test;
1488
1489 async fn new_operation(
1490 ds: &Datastore,
1491 t: &MTree,
1492 tt: TransactionType,
1493 cache_size: usize,
1494 ) -> (Context, TreeStore<MTreeNode>) {
1495 let tx = ds.transaction(tt, Optimistic).await.unwrap().enclose();
1496 let st = tx
1497 .index_caches()
1498 .get_store_mtree(TreeNodeProvider::Debug, t.state.generation, tt, cache_size)
1499 .await
1500 .unwrap();
1501 let mut ctx = MutableContext::default();
1502 ctx.set_transaction(tx);
1503 (ctx.freeze(), st)
1504 }
1505
1506 async fn finish_operation(
1507 t: &mut MTree,
1508 tx: &Transaction,
1509 mut st: TreeStore<MTreeNode>,
1510 commit: bool,
1511 ) -> Result<(), Error> {
1512 if let Some(new_cache) = st.finish(tx).await? {
1513 assert!(new_cache.len() > 0, "new_cache.len() = {}", new_cache.len());
1514 t.state.generation += 1;
1515 tx.index_caches().advance_store_mtree(new_cache);
1516 }
1517 if commit {
1518 tx.commit().await?;
1519 Ok(())
1520 } else {
1521 tx.cancel().await
1522 }
1523 }
1524
1525 async fn insert_collection_one_by_one(
1526 stk: &mut Stk,
1527 ds: &Datastore,
1528 t: &mut MTree,
1529 collection: &TestCollection,
1530 cache_size: usize,
1531 ) -> Result<HashMap<DocId, SharedVector>, Error> {
1532 let mut map = HashMap::with_capacity(collection.len());
1533 let mut c = 0;
1534 for (doc_id, obj) in collection.to_vec_ref() {
1535 {
1536 let (ctx, mut st) = new_operation(ds, t, TransactionType::Write, cache_size).await;
1537 let tx = ctx.tx();
1538 t.insert(stk, &tx, &mut st, obj.clone(), *doc_id).await?;
1539 finish_operation(t, &tx, st, true).await?;
1540 map.insert(*doc_id, obj.clone());
1541 }
1542 c += 1;
1543 {
1544 let (ctx, mut st) = new_operation(ds, t, TransactionType::Read, cache_size).await;
1545 let tx = ctx.tx();
1546 let p = check_tree_properties(&tx, &mut st, t).await?;
1547 assert_eq!(p.doc_count, c);
1548 }
1549 }
1550 Ok(map)
1551 }
1552
1553 async fn insert_collection_batch(
1554 stk: &mut Stk,
1555 ds: &Datastore,
1556 t: &mut MTree,
1557 collection: &TestCollection,
1558 cache_size: usize,
1559 ) -> Result<HashMap<DocId, SharedVector>, Error> {
1560 let mut map = HashMap::with_capacity(collection.len());
1561 {
1562 let (ctx, mut st) = new_operation(ds, t, TransactionType::Write, cache_size).await;
1563 let tx = ctx.tx();
1564 for (doc_id, obj) in collection.to_vec_ref() {
1565 t.insert(stk, &tx, &mut st, obj.clone(), *doc_id).await?;
1566 map.insert(*doc_id, obj.clone());
1567 }
1568 finish_operation(t, &tx, st, true).await?;
1569 }
1570 {
1571 let (ctx, mut st) = new_operation(ds, t, TransactionType::Read, cache_size).await;
1572 let tx = ctx.tx();
1573 check_tree_properties(&tx, &mut st, t).await?;
1574 }
1575 Ok(map)
1576 }
1577
1578 async fn delete_collection(
1579 stk: &mut Stk,
1580 ds: &Datastore,
1581 doc_ids: &DocIds,
1582 t: &mut MTree,
1583 collection: &TestCollection,
1584 cache_size: usize,
1585 ) -> Result<(), Error> {
1586 let mut all_deleted = true;
1587 for (doc_id, obj) in collection.to_vec_ref() {
1588 let deleted = {
1589 debug!("### Remove {} {:?}", doc_id, obj);
1590 let (ctx, mut st) = new_operation(ds, t, TransactionType::Write, cache_size).await;
1591 let tx = ctx.tx();
1592 let deleted = t.delete(stk, &tx, &mut st, obj.clone(), *doc_id).await?;
1593 finish_operation(t, &tx, st, true).await?;
1594 drop(tx);
1595 deleted
1596 };
1597 all_deleted = all_deleted && deleted;
1598 if deleted {
1599 let (ctx, st) = new_operation(ds, t, TransactionType::Read, cache_size).await;
1600 let mut chk = MTreeConditionChecker::new(&ctx);
1601 let search = MTreeSearchContext {
1602 ctx: &ctx,
1603 pt: obj.clone(),
1604 k: 1,
1605 store: &st,
1606 };
1607 let res = t.knn_search(&search, doc_ids, stk, &mut chk).await?;
1608 assert!(
1609 !res.docs.iter().any(|(id, _)| id == doc_id),
1610 "Found: {} {:?}",
1611 doc_id,
1612 obj
1613 );
1614 } else {
1615 warn!("Delete failed: {} {:?}", doc_id, obj);
1617 }
1618 {
1619 let (ctx, mut st) = new_operation(ds, t, TransactionType::Read, cache_size).await;
1620 let tx = ctx.tx();
1621 check_tree_properties(&tx, &mut st, t).await?;
1622 drop(tx);
1623 }
1624 }
1625
1626 if all_deleted {
1627 let (ctx, mut st) = new_operation(ds, t, TransactionType::Read, cache_size).await;
1628 let tx = ctx.tx();
1629 check_tree_properties(&tx, &mut st, t).await?.check(0, 0, None, None, 0, 0);
1630 drop(tx);
1631 }
1632 Ok(())
1633 }
1634
1635 async fn find_collection(
1636 stk: &mut Stk,
1637 ds: &Datastore,
1638 doc_ids: &DocIds,
1639 t: &mut MTree,
1640 collection: &TestCollection,
1641 cache_size: usize,
1642 ) -> Result<(), Error> {
1643 let (ctx, mut st) = new_operation(ds, t, TransactionType::Read, cache_size).await;
1644 let max_knn = 20.max(collection.len());
1645 for (doc_id, obj) in collection.to_vec_ref() {
1646 for knn in 1..max_knn {
1647 let mut chk = MTreeConditionChecker::new(&ctx);
1648 let search = MTreeSearchContext {
1649 ctx: &ctx,
1650 pt: obj.clone(),
1651 k: knn,
1652 store: &st,
1653 };
1654 let res = t.knn_search(&search, doc_ids, stk, &mut chk).await?;
1655 let docs: Vec<DocId> = res.docs.iter().map(|(d, _)| *d).collect();
1656 if collection.is_unique() {
1657 assert!(
1658 docs.contains(doc_id),
1659 "Search: {:?} - Knn: {} - Wrong Doc - Expected: {} - Got: {:?}",
1660 obj,
1661 knn,
1662 doc_id,
1663 res.docs
1664 );
1665 }
1666 let expected_len = collection.len().min(knn);
1667 if expected_len != res.docs.len() {
1668 #[cfg(debug_assertions)]
1669 debug!("{:?}", res.visited_nodes);
1670 let tx = ctx.tx();
1671 check_tree_properties(&tx, &mut st, t).await?;
1672 }
1673 assert_eq!(
1674 expected_len,
1675 res.docs.len(),
1676 "Wrong knn count - Expected: {} - Got: {} - Collection: {}",
1677 expected_len,
1678 res.docs.len(),
1679 collection.len(),
1680 )
1681 }
1682 }
1683 Ok(())
1684 }
1685
1686 async fn check_full_knn(
1687 stk: &mut Stk,
1688 ds: &Datastore,
1689 doc_ids: &DocIds,
1690 t: &mut MTree,
1691 map: &HashMap<DocId, SharedVector>,
1692 cache_size: usize,
1693 ) -> Result<(), Error> {
1694 let (ctx, st) = new_operation(ds, t, TransactionType::Read, cache_size).await;
1695 for obj in map.values() {
1696 let mut chk = MTreeConditionChecker::new(&ctx);
1697 let search = MTreeSearchContext {
1698 ctx: &ctx,
1699 pt: obj.clone(),
1700 k: map.len(),
1701 store: &st,
1702 };
1703 let res = t.knn_search(&search, doc_ids, stk, &mut chk).await?;
1704 assert_eq!(
1705 map.len(),
1706 res.docs.len(),
1707 "Wrong knn count - Expected: {} - Got: {} - Collection: {}",
1708 map.len(),
1709 res.docs.len(),
1710 map.len(),
1711 );
1712 let mut dist = 0.0;
1714 for (doc, d) in res.docs {
1715 let o = map.get(&doc).unwrap();
1716 debug!("doc: {doc} - d: {d} - {obj:?} - {o:?}");
1717 assert!(d >= dist, "d: {d} - dist: {dist}");
1718 dist = d;
1719 }
1720 }
1721 Ok(())
1722 }
1723
1724 #[allow(clippy::too_many_arguments)]
1725 async fn test_mtree_collection(
1726 stk: &mut Stk,
1727 capacities: &[u16],
1728 vector_type: VectorType,
1729 collection: TestCollection,
1730 check_find: bool,
1731 check_full: bool,
1732 check_delete: bool,
1733 cache_size: usize,
1734 ) -> Result<(), Error> {
1735 for distance in [Distance::Euclidean, Distance::Cosine, Distance::Manhattan] {
1736 if distance == Distance::Cosine && vector_type == VectorType::F64 {
1737 continue;
1739 }
1740 for capacity in capacities {
1741 info!(
1742 "test_mtree_collection - Distance: {:?} - Capacity: {} - Collection: {} - Vector type: {}",
1743 distance,
1744 capacity,
1745 collection.len(),
1746 vector_type,
1747 );
1748 let ds = Datastore::new("memory").await?;
1749
1750 let mut t = MTree::new(MState::new(*capacity), distance.clone());
1751
1752 let (ctx, _st) = new_operation(&ds, &t, TransactionType::Read, cache_size).await;
1753 let tx = ctx.tx();
1754 let doc_ids =
1755 DocIds::new(&tx, TransactionType::Read, IndexKeyBase::default(), 7, 100)
1756 .await
1757 .unwrap();
1758
1759 let map = if collection.len() < 1000 {
1760 insert_collection_one_by_one(stk, &ds, &mut t, &collection, cache_size).await?
1761 } else {
1762 insert_collection_batch(stk, &ds, &mut t, &collection, cache_size).await?
1763 };
1764 if check_find {
1765 find_collection(stk, &ds, &doc_ids, &mut t, &collection, cache_size).await?;
1766 }
1767 if check_full {
1768 check_full_knn(stk, &ds, &doc_ids, &mut t, &map, cache_size).await?;
1769 }
1770 if check_delete {
1771 delete_collection(stk, &ds, &doc_ids, &mut t, &collection, cache_size).await?;
1772 }
1773 }
1774 }
1775 Ok(())
1776 }
1777
1778 #[test(tokio::test)]
1779 #[ignore]
1780 async fn test_mtree_unique_xs() -> Result<(), Error> {
1781 let mut stack = reblessive::tree::TreeStack::new();
1782 stack
1783 .enter(|stk| async {
1784 for vt in [
1785 VectorType::F64,
1786 VectorType::F32,
1787 VectorType::I64,
1788 VectorType::I32,
1789 VectorType::I16,
1790 ] {
1791 for i in 0..30 {
1792 test_mtree_collection(
1793 stk,
1794 &[3, 40],
1795 vt,
1796 TestCollection::new(true, i, vt, 2, &Distance::Euclidean),
1797 true,
1798 true,
1799 true,
1800 100,
1801 )
1802 .await?;
1803 }
1804 }
1805 Ok(())
1806 })
1807 .finish()
1808 .await
1809 }
1810
1811 #[test(tokio::test)]
1812 #[ignore]
1813 async fn test_mtree_unique_xs_full_cache() -> Result<(), Error> {
1814 let mut stack = reblessive::tree::TreeStack::new();
1815 stack
1816 .enter(|stk| async {
1817 for vt in [
1818 VectorType::F64,
1819 VectorType::F32,
1820 VectorType::I64,
1821 VectorType::I32,
1822 VectorType::I16,
1823 ] {
1824 for i in 0..30 {
1825 test_mtree_collection(
1826 stk,
1827 &[3, 40],
1828 vt,
1829 TestCollection::new(true, i, vt, 2, &Distance::Euclidean),
1830 true,
1831 true,
1832 true,
1833 0,
1834 )
1835 .await?;
1836 }
1837 }
1838 Ok(())
1839 })
1840 .finish()
1841 .await
1842 }
1843
1844 #[test(tokio::test(flavor = "multi_thread"))]
1845 #[ignore]
1846 async fn test_mtree_unique_small() -> Result<(), Error> {
1847 let mut stack = reblessive::tree::TreeStack::new();
1848 stack
1849 .enter(|stk| async {
1850 for vt in [VectorType::F64, VectorType::I64] {
1851 test_mtree_collection(
1852 stk,
1853 &[10, 20],
1854 vt,
1855 TestCollection::new(true, 150, vt, 3, &Distance::Euclidean),
1856 true,
1857 true,
1858 false,
1859 0,
1860 )
1861 .await?;
1862 }
1863 Ok(())
1864 })
1865 .finish()
1866 .await
1867 }
1868
1869 #[test(tokio::test(flavor = "multi_thread"))]
1870 async fn test_mtree_unique_normal() -> Result<(), Error> {
1871 let mut stack = reblessive::tree::TreeStack::new();
1872 stack
1873 .enter(|stk| async {
1874 for vt in [VectorType::F32, VectorType::I32] {
1875 test_mtree_collection(
1876 stk,
1877 &[40],
1878 vt,
1879 TestCollection::new(true, 500, vt, 5, &Distance::Euclidean),
1880 false,
1881 true,
1882 false,
1883 100,
1884 )
1885 .await?;
1886 }
1887 Ok(())
1888 })
1889 .finish()
1890 .await
1891 }
1892
1893 #[test(tokio::test(flavor = "multi_thread"))]
1894 async fn test_mtree_unique_normal_full_cache() -> Result<(), Error> {
1895 let mut stack = reblessive::tree::TreeStack::new();
1896 stack
1897 .enter(|stk| async {
1898 for vt in [VectorType::F32, VectorType::I32] {
1899 test_mtree_collection(
1900 stk,
1901 &[40],
1902 vt,
1903 TestCollection::new(true, 500, vt, 5, &Distance::Euclidean),
1904 false,
1905 true,
1906 false,
1907 0,
1908 )
1909 .await?;
1910 }
1911 Ok(())
1912 })
1913 .finish()
1914 .await
1915 }
1916
1917 #[test(tokio::test(flavor = "multi_thread"))]
1918 async fn test_mtree_unique_normal_small_cache() -> Result<(), Error> {
1919 let mut stack = reblessive::tree::TreeStack::new();
1920 stack
1921 .enter(|stk| async {
1922 for vt in [VectorType::F32, VectorType::I32] {
1923 test_mtree_collection(
1924 stk,
1925 &[40],
1926 vt,
1927 TestCollection::new(true, 500, vt, 5, &Distance::Euclidean),
1928 false,
1929 true,
1930 false,
1931 10,
1932 )
1933 .await?;
1934 }
1935 Ok(())
1936 })
1937 .finish()
1938 .await
1939 }
1940
1941 #[test(tokio::test)]
1942 #[ignore]
1943 async fn test_mtree_random_xs() -> Result<(), Error> {
1944 let mut stack = reblessive::tree::TreeStack::new();
1945 stack
1946 .enter(|stk| async {
1947 for vt in [
1948 VectorType::F64,
1949 VectorType::F32,
1950 VectorType::I64,
1951 VectorType::I32,
1952 VectorType::I16,
1953 ] {
1954 for collection_size in [0, 1, 5, 10, 15, 20, 30, 40] {
1955 test_mtree_collection(
1956 stk,
1957 &[3, 10, 40],
1958 vt,
1959 TestCollection::new(
1960 false,
1961 collection_size,
1962 vt,
1963 1,
1964 &Distance::Euclidean,
1965 ),
1966 true,
1967 true,
1968 true,
1969 0,
1970 )
1971 .await?;
1972 }
1973 }
1974 Ok(())
1975 })
1976 .finish()
1977 .await
1978 }
1979
1980 #[test(tokio::test(flavor = "multi_thread"))]
1981 #[ignore]
1982 async fn test_mtree_random_small() -> Result<(), Error> {
1983 let mut stack = reblessive::tree::TreeStack::new();
1984 stack
1985 .enter(|stk| async {
1986 for vt in [VectorType::F64, VectorType::I64] {
1987 test_mtree_collection(
1988 stk,
1989 &[10, 20],
1990 vt,
1991 TestCollection::new(false, 150, vt, 3, &Distance::Euclidean),
1992 true,
1993 true,
1994 false,
1995 0,
1996 )
1997 .await?;
1998 }
1999 Ok(())
2000 })
2001 .finish()
2002 .await
2003 }
2004
2005 #[test(tokio::test(flavor = "multi_thread"))]
2006 async fn test_mtree_random_normal() -> Result<(), Error> {
2007 let mut stack = reblessive::tree::TreeStack::new();
2008 stack
2009 .enter(|stk| async {
2010 for vt in [VectorType::F32, VectorType::I32] {
2011 test_mtree_collection(
2012 stk,
2013 &[40],
2014 vt,
2015 TestCollection::new(false, 500, vt, 5, &Distance::Euclidean),
2016 false,
2017 true,
2018 false,
2019 0,
2020 )
2021 .await?;
2022 }
2023 Ok(())
2024 })
2025 .finish()
2026 .await
2027 }
2028
2029 #[derive(Default, Debug)]
2030 struct CheckedProperties {
2031 node_count: usize,
2032 max_depth: usize,
2033 min_leaf_depth: Option<usize>,
2034 max_leaf_depth: Option<usize>,
2035 min_objects: Option<usize>,
2036 max_objects: Option<usize>,
2037 object_count: usize,
2038 doc_count: usize,
2039 }
2040
2041 impl CheckedProperties {
2042 fn check(
2043 &self,
2044 expected_node_count: usize,
2045 expected_depth: usize,
2046 expected_min_objects: Option<usize>,
2047 expected_max_objects: Option<usize>,
2048 expected_object_count: usize,
2049 expected_doc_count: usize,
2050 ) {
2051 assert_eq!(self.node_count, expected_node_count, "Node count - {:?}", self);
2052 assert_eq!(self.max_depth, expected_depth, "Max depth - {:?}", self);
2053 let expected_leaf_depth = if expected_depth == 0 {
2054 None
2055 } else {
2056 Some(expected_depth)
2057 };
2058 assert_eq!(self.min_leaf_depth, expected_leaf_depth, "Min leaf depth - {:?}", self);
2059 assert_eq!(self.max_leaf_depth, expected_leaf_depth, "Max leaf depth - {:?}", self);
2060 assert_eq!(self.min_objects, expected_min_objects, "Min objects - {:?}", self);
2061 assert_eq!(self.max_objects, expected_max_objects, "Max objects - {:?}", self);
2062 assert_eq!(self.object_count, expected_object_count, "Object count- {:?}", self);
2063 assert_eq!(self.doc_count, expected_doc_count, "Doc count - {:?}", self);
2064 }
2065 }
2066
2067 async fn check_tree_properties(
2068 tx: &Transaction,
2069 st: &mut MTreeStore,
2070 t: &MTree,
2071 ) -> Result<CheckedProperties, Error> {
2072 debug!("CheckTreeProperties");
2073 let mut node_ids = HashSet::default();
2074 let mut checks = CheckedProperties::default();
2075 let mut nodes: VecDeque<(NodeId, f64, Option<SharedVector>, usize)> = VecDeque::new();
2076 if let Some(root_id) = t.state.root {
2077 nodes.push_back((root_id, 0.0, None, 1));
2078 }
2079 let mut leaf_objects = HashSet::default();
2080 while let Some((node_id, radius, center, depth)) = nodes.pop_front() {
2081 assert!(node_ids.insert(node_id), "Node already exist: {}", node_id);
2082 checks.node_count += 1;
2083 if depth > checks.max_depth {
2084 checks.max_depth = depth;
2085 }
2086 let node = st.get_node(tx, node_id).await?;
2087 debug!(
2088 "Node id: {} - depth: {} - len: {} - {:?}",
2089 node.id,
2090 depth,
2091 node.n.len(),
2092 node.n
2093 );
2094 assert_ne!(node.n.len(), 0, "Empty node! {}", node.id);
2095 if Some(node_id) != t.state.root {
2096 assert!(
2097 node.n.len() >= t.minimum && node.n.len() <= t.state.capacity as usize,
2098 "Wrong node size - Node: {} - Size: {}",
2099 node_id,
2100 node.n.len()
2101 );
2102 }
2103 match &node.n {
2104 MTreeNode::Internal(entries) => {
2105 let next_depth = depth + 1;
2106 for (o, p) in entries {
2107 if let Some(center) = center.as_ref() {
2108 let pd = t.calculate_distance(center, o)?;
2109 assert_eq!(pd, p.parent_dist, "Incorrect parent distance");
2110 assert!(pd + p.radius <= radius);
2111 }
2112 nodes.push_back((p.node, p.radius, Some(o.clone()), next_depth))
2113 }
2114 }
2115 MTreeNode::Leaf(m) => {
2116 checks.object_count += m.len();
2117 update_min(&mut checks.min_objects, m.len());
2118 update_max(&mut checks.max_objects, m.len());
2119 update_min(&mut checks.min_leaf_depth, depth);
2120 update_max(&mut checks.max_leaf_depth, depth);
2121 for (o, p) in m {
2122 if !leaf_objects.insert(o.clone()) {
2123 panic!("Leaf object already exists: {:?}", o);
2124 }
2125 if let Some(center) = center.as_ref() {
2126 let pd = t.calculate_distance(center, o)?;
2127 debug!("calc_dist: {:?} {:?} = {}", center, &o, pd);
2128 assert_eq!(pd, p.parent_dist, "Invalid parent distance ({}): {} - Expected: {} - Node Id: {} - Obj: {:?} - Center: {:?}", p.parent_dist, t.distance, pd, node_id, o, center);
2129 }
2130 checks.doc_count += p.docs.len() as usize;
2131 }
2132 }
2133 }
2134 }
2135 Ok(checks)
2136 }
2137
2138 fn update_min(min: &mut Option<usize>, val: usize) {
2139 if let Some(m) = *min {
2140 if val < m {
2141 *min = Some(val);
2142 }
2143 } else {
2144 *min = Some(val);
2145 }
2146 }
2147
2148 fn update_max(max: &mut Option<usize>, val: usize) {
2149 if let Some(m) = *max {
2150 if val > m {
2151 *max = Some(val);
2152 }
2153 } else {
2154 *max = Some(val);
2155 }
2156 }
2157}