1pub(in crate::idx) mod docs;
2mod elements;
3mod flavor;
4mod heuristic;
5pub mod index;
6mod layer;
7
8use crate::err::Error;
9use crate::idx::planner::checker::HnswConditionChecker;
10use crate::idx::trees::dynamicset::DynamicSet;
11use crate::idx::trees::hnsw::docs::HnswDocs;
12use crate::idx::trees::hnsw::docs::VecDocs;
13use crate::idx::trees::hnsw::elements::HnswElements;
14use crate::idx::trees::hnsw::heuristic::Heuristic;
15use crate::idx::trees::hnsw::index::HnswCheckedSearchContext;
16
17use crate::idx::trees::hnsw::layer::{HnswLayer, LayerState};
18use crate::idx::trees::knn::DoublePriorityQueue;
19use crate::idx::trees::vector::{SerializedVector, SharedVector, Vector};
20use crate::idx::{IndexKeyBase, VersionedStore};
21use crate::kvs::{Key, Transaction, Val};
22use crate::sql::index::HnswParams;
23use rand::prelude::SmallRng;
24use rand::{Rng, SeedableRng};
25use reblessive::tree::Stk;
26use revision::revisioned;
27use serde::{Deserialize, Serialize};
28
/// Parameters for a single KNN search over the HNSW graph.
struct HnswSearch {
	// The query vector
	pt: SharedVector,
	// Number of nearest neighbours requested
	k: usize,
	// Size of the dynamic candidate list (ef): larger values trade speed for accuracy
	ef: usize,
}
34
35impl HnswSearch {
36 pub(super) fn new(pt: SharedVector, k: usize, ef: usize) -> Self {
37 Self {
38 pt,
39 k,
40 ef,
41 }
42 }
43}
44
/// Persisted (versioned) state of the HNSW graph.
#[revisioned(revision = 1)]
#[derive(Default, Serialize, Deserialize)]
pub(super) struct HnswState {
	// Entry point of the graph, if at least one element has been inserted
	enter_point: Option<ElementId>,
	// Next element id to be assigned
	next_element_id: ElementId,
	// State of the base layer (layer 0)
	layer0: LayerState,
	// State of the upper layers (index 0 corresponds to layer 1)
	layers: Vec<LayerState>,
}
53
// Enables (de)serialization of `HnswState` through the versioned key-value store.
impl VersionedStore for HnswState {}
55
/// In-memory HNSW graph backed by the key-value store.
///
/// `L0` is the dynamic-set implementation used for neighbourhoods on the
/// base layer, `L` the one used on the upper layers.
struct Hnsw<L0, L>
where
	L0: DynamicSet,
	L: DynamicSet,
{
	// Key builder for this index
	ikb: IndexKeyBase,
	// Key under which `state` is persisted
	state_key: Key,
	// Last known persisted state
	state: HnswState,
	// Maximum number of connections on upper layers (HNSW parameter M)
	m: usize,
	// Candidate-list size used during construction (efConstruction)
	efc: usize,
	// Level normalization factor (mL)
	ml: f64,
	// Base layer (layer 0)
	layer0: HnswLayer<L0>,
	// Upper layers (index 0 corresponds to layer 1)
	layers: Vec<HnswLayer<L>>,
	// Vector storage and distance computation
	elements: HnswElements,
	// RNG used to draw the level of newly inserted elements
	rng: SmallRng,
	// Neighbour-selection heuristic derived from the index parameters
	heuristic: Heuristic,
}
73
/// Identifier of an element (vector) stored in the HNSW graph.
pub(crate) type ElementId = u64;
75
76impl<L0, L> Hnsw<L0, L>
77where
78 L0: DynamicSet,
79 L: DynamicSet,
80{
81 fn new(ikb: IndexKeyBase, p: &HnswParams) -> Result<Self, Error> {
82 let m0 = p.m0 as usize;
83 let state_key = ikb.new_hs_key()?;
84 Ok(Self {
85 state_key,
86 state: Default::default(),
87 m: p.m as usize,
88 efc: p.ef_construction as usize,
89 ml: p.ml.to_float(),
90 layer0: HnswLayer::new(ikb.clone(), 0, m0),
91 layers: Vec::default(),
92 elements: HnswElements::new(ikb.clone(), p.distance.clone()),
93 rng: SmallRng::from_entropy(),
94 heuristic: p.into(),
95 ikb,
96 })
97 }
98
99 async fn check_state(&mut self, tx: &Transaction) -> Result<(), Error> {
100 let st: HnswState = if let Some(val) = tx.get(self.state_key.clone(), None).await? {
102 VersionedStore::try_from(val)?
103 } else {
104 Default::default()
105 };
106 if st.layer0.version != self.state.layer0.version {
108 self.layer0.load(tx, &st.layer0).await?;
109 }
110 for ((new_stl, stl), layer) in
111 st.layers.iter().zip(self.state.layers.iter_mut()).zip(self.layers.iter_mut())
112 {
113 if new_stl.version != stl.version {
114 layer.load(tx, new_stl).await?;
115 }
116 }
117 for i in self.layers.len()..st.layers.len() {
119 let mut l = HnswLayer::new(self.ikb.clone(), i + 1, self.m);
120 l.load(tx, &st.layers[i]).await?;
121 self.layers.push(l);
122 }
123 for _ in self.layers.len()..st.layers.len() {
125 self.layers.pop();
126 }
127 self.elements.set_next_element_id(st.next_element_id);
129 self.state = st;
130 Ok(())
131 }
132
	/// Insert a vector into the graph at the given level.
	///
	/// Creates any missing upper layers, stores the vector, then links the
	/// new element starting from the current enter point — or registers it
	/// as the first element of the graph when there is no enter point yet.
	async fn insert_level(
		&mut self,
		tx: &Transaction,
		q_pt: Vector,
		q_level: usize,
	) -> Result<ElementId, Error> {
		// Attribute an id to the new element
		let q_id = self.elements.next_element_id();
		let top_up_layers = self.layers.len();

		// Make sure the upper layers up to `q_level` exist
		for i in top_up_layers..q_level {
			self.layers.push(HnswLayer::new(self.ikb.clone(), i + 1, self.m));
			self.state.layers.push(LayerState::default());
		}

		// Store the vector
		let pt_ser = SerializedVector::from(&q_pt);
		let q_pt = self.elements.insert(tx, q_id, q_pt, &pt_ser).await?;

		if let Some(ep_id) = self.state.enter_point {
			// An enter point exists: insert the element into the layers
			self.insert_element(tx, q_id, &q_pt, q_level, ep_id, top_up_layers).await?;
		} else {
			// Otherwise this is the first element of the graph
			self.insert_first_element(tx, q_id, q_level).await?;
		}

		self.state.next_element_id = self.elements.inc_next_element_id();
		Ok(q_id)
	}
164
165 fn get_random_level(&mut self) -> usize {
166 let unif: f64 = self.rng.gen(); (-unif.ln() * self.ml).floor() as usize }
169
170 async fn insert_first_element(
171 &mut self,
172 tx: &Transaction,
173 id: ElementId,
174 level: usize,
175 ) -> Result<(), Error> {
176 if level > 0 {
177 for (layer, state) in
179 self.layers.iter_mut().zip(self.state.layers.iter_mut()).take(level)
180 {
181 layer.add_empty_node(tx, id, state).await?;
182 }
183 }
184 self.layer0.add_empty_node(tx, id, &mut self.state.layer0).await?;
186 self.state.enter_point = Some(id);
188 Ok(())
190 }
191
	/// Insert an element into the graph starting from the given enter point.
	///
	/// Follows the HNSW insertion algorithm: greedily descend the layers
	/// above `q_level` (ef = 1) to refine the enter point, then insert the
	/// element into every layer from `min(q_level, top_up_layers)` down to
	/// layer 0. When `q_level` exceeds the previous top layer, the element is
	/// registered on the newly created layers and becomes the new enter point.
	async fn insert_element(
		&mut self,
		tx: &Transaction,
		q_id: ElementId,
		q_pt: &SharedVector,
		q_level: usize,
		mut ep_id: ElementId,
		top_up_layers: usize,
	) -> Result<(), Error> {
		if let Some(mut ep_dist) = self.elements.get_distance(tx, q_pt, &ep_id).await? {
			// Greedy descent on the layers above the insertion level:
			// each layer narrows the enter point for the next layer down.
			if q_level < top_up_layers {
				for layer in self.layers[q_level..top_up_layers].iter_mut().rev() {
					if let Some(ep_dist_id) = layer
						.search_single(tx, &self.elements, q_pt, ep_dist, ep_id, 1)
						.await?
						.peek_first()
					{
						(ep_dist, ep_id) = ep_dist_id;
					} else {
						#[cfg(debug_assertions)]
						unreachable!()
					}
				}
			}

			let mut eps = DoublePriorityQueue::from(ep_dist, ep_id);

			// Insert into the upper layers (top first), carrying the candidates
			// found on each layer down as enter points for the next one.
			let insert_to_up_layers = q_level.min(top_up_layers);
			if insert_to_up_layers > 0 {
				for (layer, st) in self
					.layers
					.iter_mut()
					.zip(self.state.layers.iter_mut())
					.take(insert_to_up_layers)
					.rev()
				{
					eps = layer
						.insert(
							(tx, st),
							&self.elements,
							&self.heuristic,
							self.efc,
							(q_id, q_pt),
							eps,
						)
						.await?;
				}
			}

			// Every element is always inserted into the base layer
			self.layer0
				.insert(
					(tx, &mut self.state.layer0),
					&self.elements,
					&self.heuristic,
					self.efc,
					(q_id, q_pt),
					eps,
				)
				.await?;

			// Register the element as an empty node on the layers that were
			// created for it (above the previous top layer).
			if top_up_layers < q_level {
				for (layer, st) in self.layers[top_up_layers..q_level]
					.iter_mut()
					.zip(self.state.layers[top_up_layers..q_level].iter_mut())
				{
					if !layer.add_empty_node(tx, q_id, st).await? {
						#[cfg(debug_assertions)]
						unreachable!("Already there {}", q_id);
					}
				}
			}

			// The element reaching the highest level becomes the enter point
			if q_level > top_up_layers {
				self.state.enter_point = Some(q_id);
			}
		} else {
			// The enter point is expected to always have a stored vector
			#[cfg(debug_assertions)]
			unreachable!()
		}
		Ok(())
	}
273
274 async fn save_state(&self, tx: &Transaction) -> Result<(), Error> {
275 let val: Val = VersionedStore::try_into(&self.state)?;
276 tx.set(self.state_key.clone(), val, None).await?;
277 Ok(())
278 }
279
280 async fn insert(&mut self, tx: &Transaction, q_pt: Vector) -> Result<ElementId, Error> {
281 let q_level = self.get_random_level();
282 let res = self.insert_level(tx, q_pt, q_level).await?;
283 self.save_state(tx).await?;
284 Ok(res)
285 }
286
	/// Remove an element from every layer of the graph.
	///
	/// If the removed element is the current enter point, a replacement is
	/// searched layer by layer on the way down (ignoring the removed
	/// element). Returns `true` when the element was removed from at least
	/// one layer. The state is persisted in all cases.
	async fn remove(&mut self, tx: &Transaction, e_id: ElementId) -> Result<bool, Error> {
		let mut removed = false;

		if let Some(e_pt) = self.elements.get_vector(tx, &e_id).await? {
			// If we are deleting the enter point, a new one must be found
			let mut new_enter_point = if Some(e_id) == self.state.enter_point {
				None
			} else {
				self.state.enter_point
			};

			// Remove from the upper layers, top first
			for (layer, st) in self.layers.iter_mut().zip(self.state.layers.iter_mut()).rev() {
				if new_enter_point.is_none() {
					new_enter_point = layer
						.search_single_with_ignore(tx, &self.elements, &e_pt, e_id, self.efc)
						.await?;
				}
				if layer.remove(tx, st, &self.elements, &self.heuristic, e_id, self.efc).await? {
					removed = true;
				}
			}

			// Fall back to the base layer to find a replacement enter point
			if new_enter_point.is_none() {
				new_enter_point = self
					.layer0
					.search_single_with_ignore(tx, &self.elements, &e_pt, e_id, self.efc)
					.await?;
			}

			// Remove from the base layer
			if self
				.layer0
				.remove(tx, &mut self.state.layer0, &self.elements, &self.heuristic, e_id, self.efc)
				.await?
			{
				removed = true;
			}

			// Remove the stored vector itself
			self.elements.remove(tx, e_id).await?;

			self.state.enter_point = new_enter_point;
		}

		self.save_state(tx).await?;
		Ok(removed)
	}
336
337 async fn knn_search(
338 &self,
339 tx: &Transaction,
340 search: &HnswSearch,
341 ) -> Result<Vec<(f64, ElementId)>, Error> {
342 if let Some((ep_dist, ep_id)) = self.search_ep(tx, &search.pt).await? {
343 let w = self
344 .layer0
345 .search_single(tx, &self.elements, &search.pt, ep_dist, ep_id, search.ef)
346 .await?;
347 Ok(w.to_vec_limit(search.k))
348 } else {
349 Ok(vec![])
350 }
351 }
352
353 async fn knn_search_checked(
354 &self,
355 tx: &Transaction,
356 stk: &mut Stk,
357 search: &HnswSearch,
358 hnsw_docs: &HnswDocs,
359 vec_docs: &VecDocs,
360 chk: &mut HnswConditionChecker<'_>,
361 ) -> Result<Vec<(f64, ElementId)>, Error> {
362 if let Some((ep_dist, ep_id)) = self.search_ep(tx, &search.pt).await? {
363 if let Some(ep_pt) = self.elements.get_vector(tx, &ep_id).await? {
364 let search_ctx = HnswCheckedSearchContext::new(
365 &self.elements,
366 hnsw_docs,
367 vec_docs,
368 &search.pt,
369 search.ef,
370 );
371 let w = self
372 .layer0
373 .search_single_checked(tx, stk, &search_ctx, &ep_pt, ep_dist, ep_id, chk)
374 .await?;
375 return Ok(w.to_vec_limit(search.k));
376 }
377 }
378 Ok(vec![])
379 }
380
	/// Locate the search entry point for the base layer.
	///
	/// Starting from the graph enter point, greedily descend every upper
	/// layer (ef = 1), keeping the closest element found so far. Returns the
	/// final `(distance, element_id)` pair, or `None` when the graph is empty.
	async fn search_ep(
		&self,
		tx: &Transaction,
		pt: &SharedVector,
	) -> Result<Option<(f64, ElementId)>, Error> {
		if let Some(mut ep_id) = self.state.enter_point {
			if let Some(mut ep_dist) = self.elements.get_distance(tx, pt, &ep_id).await? {
				// Descend from the top layer, refining the enter point each time
				for layer in self.layers.iter().rev() {
					if let Some(ep_dist_id) = layer
						.search_single(tx, &self.elements, pt, ep_dist, ep_id, 1)
						.await?
						.peek_first()
					{
						(ep_dist, ep_id) = ep_dist_id;
					} else {
						#[cfg(debug_assertions)]
						unreachable!()
					}
				}
				return Ok(Some((ep_dist, ep_id)));
			} else {
				// The enter point is expected to always have a stored vector
				#[cfg(debug_assertions)]
				unreachable!()
			}
		}
		Ok(None)
	}
408
	/// Fetch the vector associated with the given element id, if any.
	async fn get_vector(
		&self,
		tx: &Transaction,
		e_id: &ElementId,
	) -> Result<Option<SharedVector>, Error> {
		self.elements.get_vector(tx, e_id).await
	}
	// Test-only: assert the structural invariants of the graph.
	#[cfg(test)]
	fn check_hnsw_properties(&self, expected_count: usize) {
		check_hnsw_props(self, expected_count);
	}
420}
421
/// Test helper asserting the graph invariants: the element store holds
/// exactly `expected_count` vectors and every upper layer passes its own
/// structural checks.
#[cfg(test)]
fn check_hnsw_props<L0, L>(h: &Hnsw<L0, L>, expected_count: usize)
where
	L0: DynamicSet,
	L: DynamicSet,
{
	assert_eq!(h.elements.len(), expected_count);
	h.layers.iter().for_each(|layer| layer.check_props(&h.elements));
}
433
434#[cfg(test)]
435mod tests {
436 use crate::ctx::{Context, MutableContext};
437 use crate::err::Error;
438 use crate::idx::docids::DocId;
439 use crate::idx::planner::checker::HnswConditionChecker;
440 use crate::idx::trees::hnsw::flavor::HnswFlavor;
441 use crate::idx::trees::hnsw::index::HnswIndex;
442 use crate::idx::trees::hnsw::{ElementId, HnswSearch};
443 use crate::idx::trees::knn::tests::{new_vectors_from_file, TestCollection};
444 use crate::idx::trees::knn::{Ids64, KnnResult, KnnResultBuilder};
445 use crate::idx::trees::vector::{SharedVector, Vector};
446 use crate::idx::IndexKeyBase;
447 use crate::kvs::LockType::Optimistic;
448 use crate::kvs::{Datastore, Transaction, TransactionType};
449 use crate::sql::index::{Distance, HnswParams, VectorType};
450 use crate::sql::{Id, Value};
451 use ahash::{HashMap, HashSet};
452 use ndarray::Array1;
453 use reblessive::tree::Stk;
454 use roaring::RoaringTreemap;
455 use std::collections::hash_map::Entry;
456 use std::ops::Deref;
457 use std::sync::Arc;
458 use test_log::test;
459
	/// Insert every vector of the collection into the graph, checking the
	/// graph invariants after each insertion. Returns the element-id to
	/// vector mapping for later verification.
	async fn insert_collection_hnsw(
		tx: &Transaction,
		h: &mut HnswFlavor,
		collection: &TestCollection,
	) -> HashMap<ElementId, SharedVector> {
		let mut map = HashMap::default();
		for (_, obj) in collection.to_vec_ref() {
			let obj: SharedVector = obj.clone();
			let e_id = h.insert(tx, obj.clone_vector()).await.unwrap();
			map.insert(e_id, obj);
			h.check_hnsw_properties(map.len());
		}
		map
	}
474
	/// For every vector in the collection, run KNN searches with k = 1..max_knn
	/// and verify that (a) for unique collections the queried vector itself is
	/// among the results, and (b) the result count matches the expectation.
	async fn find_collection_hnsw(tx: &Transaction, h: &HnswFlavor, collection: &TestCollection) {
		let max_knn = 20.min(collection.len());
		for (_, obj) in collection.to_vec_ref() {
			for knn in 1..max_knn {
				let search = HnswSearch::new(obj.clone(), knn, 80);
				let res = h.knn_search(tx, &search).await.unwrap();
				if collection.is_unique() {
					// The queried vector must be present in the result set
					let mut found = false;
					for (_, e_id) in &res {
						if let Some(v) = h.get_vector(tx, e_id).await.unwrap() {
							if v.eq(obj) {
								found = true;
								break;
							}
						}
					}
					assert!(
						found,
						"Search: {:?} - Knn: {} - Vector not found - Got: {:?} - Coll: {}",
						obj,
						knn,
						res,
						collection.len(),
					);
				}
				let expected_len = collection.len().min(knn);
				if expected_len != res.len() {
					info!("expected_len != res.len()")
				}
				assert_eq!(
					expected_len,
					res.len(),
					"Wrong knn count - Expected: {} - Got: {} - Collection: {} - - Res: {:?}",
					expected_len,
					res.len(),
					collection.len(),
					res,
				)
			}
		}
	}
516
	/// Remove every element from the graph, checking the invariants after
	/// each removal.
	async fn delete_collection_hnsw(
		tx: &Transaction,
		h: &mut HnswFlavor,
		mut map: HashMap<ElementId, SharedVector>,
	) {
		let element_ids: Vec<ElementId> = map.keys().copied().collect();
		for e_id in element_ids {
			assert!(h.remove(tx, e_id).await.unwrap());
			map.remove(&e_id);
			h.check_hnsw_properties(map.len());
		}
	}
529
	/// Full lifecycle test on an in-memory datastore: insert the collection,
	/// search it, then delete everything — each phase in its own transaction.
	async fn test_hnsw_collection(p: &HnswParams, collection: &TestCollection) {
		let ds = Datastore::new("memory").await.unwrap();
		let mut h = HnswFlavor::new(IndexKeyBase::default(), p).unwrap();
		// Insert phase (write transaction)
		let map = {
			let tx = ds.transaction(TransactionType::Write, Optimistic).await.unwrap();
			let map = insert_collection_hnsw(&tx, &mut h, collection).await;
			tx.commit().await.unwrap();
			map
		};
		// Search phase (read transaction)
		{
			let tx = ds.transaction(TransactionType::Read, Optimistic).await.unwrap();
			find_collection_hnsw(&tx, &h, collection).await;
			tx.cancel().await.unwrap();
		}
		// Delete phase (write transaction)
		{
			let tx = ds.transaction(TransactionType::Write, Optimistic).await.unwrap();
			delete_collection_hnsw(&tx, &mut h, map).await;
			tx.commit().await.unwrap();
		}
	}
550
551 fn new_params(
552 dimension: usize,
553 vector_type: VectorType,
554 distance: Distance,
555 m: usize,
556 efc: usize,
557 extend_candidates: bool,
558 keep_pruned_connections: bool,
559 ) -> HnswParams {
560 let m = m as u8;
561 let m0 = m * 2;
562 HnswParams::new(
563 dimension as u16,
564 distance,
565 vector_type,
566 m,
567 m0,
568 (1.0 / (m as f64).ln()).into(),
569 efc as u16,
570 extend_candidates,
571 keep_pruned_connections,
572 )
573 }
574
	/// Run the full lifecycle test on a freshly generated unique collection.
	async fn test_hnsw(collection_size: usize, p: HnswParams) {
		info!("Collection size: {collection_size} - Params: {p:?}");
		let collection = TestCollection::new(
			true,
			collection_size,
			p.vector_type,
			p.dimension as usize,
			&p.distance,
		);
		test_hnsw_collection(&p, &collection).await;
	}
586
	/// Smoke test across all distances, vector types and heuristic flag
	/// combinations, each combination running as its own tokio task.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn tests_hnsw() -> Result<(), Error> {
		let mut futures = Vec::new();
		for (dist, dim) in [
			(Distance::Chebyshev, 5),
			(Distance::Cosine, 5),
			(Distance::Euclidean, 5),
			(Distance::Hamming, 20),
			(Distance::Manhattan, 5),
			(Distance::Minkowski(2.into()), 5),
		] {
			for vt in [
				VectorType::F64,
				VectorType::F32,
				VectorType::I64,
				VectorType::I32,
				VectorType::I16,
			] {
				for (extend, keep) in [(false, false), (true, false), (false, true), (true, true)] {
					let p = new_params(dim, vt, dist.clone(), 24, 500, extend, keep);
					let f = tokio::spawn(async move {
						test_hnsw(30, p).await;
					});
					futures.push(f);
				}
			}
		}
		for f in futures {
			f.await.expect("Task error");
		}
		Ok(())
	}
621
	/// Index every document of the collection through the `HnswIndex` API,
	/// building the vector to doc-ids ground-truth map and checking the
	/// graph invariants after each insertion.
	async fn insert_collection_hnsw_index(
		tx: &Transaction,
		h: &mut HnswIndex,
		collection: &TestCollection,
	) -> Result<HashMap<SharedVector, HashSet<DocId>>, Error> {
		let mut map: HashMap<SharedVector, HashSet<DocId>> = HashMap::default();
		for (doc_id, obj) in collection.to_vec_ref() {
			let content = vec![Value::from(obj.deref())];
			h.index_document(tx, &Id::Number(*doc_id as i64), &content).await.unwrap();
			// A vector may be shared by several documents (non-unique collections)
			match map.entry(obj.clone()) {
				Entry::Occupied(mut e) => {
					e.get_mut().insert(*doc_id);
				}
				Entry::Vacant(e) => {
					e.insert(HashSet::from_iter([*doc_id]));
				}
			}
			h.check_hnsw_properties(map.len());
		}
		Ok(map)
	}
643
	/// Search the index for every document's vector and verify that the
	/// matching document is returned (for unique collections) and that the
	/// result count matches the expectation.
	async fn find_collection_hnsw_index(
		tx: &Transaction,
		stk: &mut Stk,
		h: &mut HnswIndex,
		collection: &TestCollection,
	) {
		let max_knn = 20.min(collection.len());
		for (doc_id, obj) in collection.to_vec_ref() {
			for knn in 1..max_knn {
				let mut chk = HnswConditionChecker::new();
				let search = HnswSearch::new(obj.clone(), knn, 500);
				let res = h.search(tx, stk, &search, &mut chk).await.unwrap();
				if knn == 1 && res.docs.len() == 1 && res.docs[0].1 > 0.0 {
					let docs: Vec<DocId> = res.docs.iter().map(|(d, _)| *d).collect();
					if collection.is_unique() {
						assert!(
							docs.contains(doc_id),
							"Search: {:?} - Knn: {} - Wrong Doc - Expected: {} - Got: {:?}",
							obj,
							knn,
							doc_id,
							res.docs
						);
					}
				}
				let expected_len = collection.len().min(knn);
				assert_eq!(
					expected_len,
					res.docs.len(),
					"Wrong knn count - Expected: {} - Got: {} - - Docs: {:?} - Collection: {}",
					expected_len,
					res.docs.len(),
					res.docs,
					collection.len(),
				)
			}
		}
	}
682
	/// Remove every document from the index, keeping the ground-truth map in
	/// sync and checking the graph invariants after each removal.
	async fn delete_hnsw_index_collection(
		tx: &Transaction,
		h: &mut HnswIndex,
		collection: &TestCollection,
		mut map: HashMap<SharedVector, HashSet<DocId>>,
	) -> Result<(), Error> {
		for (doc_id, obj) in collection.to_vec_ref() {
			let content = vec![Value::from(obj.deref())];
			h.remove_document(tx, Id::Number(*doc_id as i64), &content).await?;
			// Drop the vector from the map once no document references it
			if let Entry::Occupied(mut e) = map.entry(obj.clone()) {
				let set = e.get_mut();
				set.remove(doc_id);
				if set.is_empty() {
					e.remove();
				}
			}
			h.check_hnsw_properties(map.len());
		}
		Ok(())
	}
704
705 async fn new_ctx(ds: &Datastore, tt: TransactionType) -> Context {
706 let tx = Arc::new(ds.transaction(tt, Optimistic).await.unwrap());
707 let mut ctx = MutableContext::default();
708 ctx.set_transaction(tx);
709 ctx.freeze()
710 }
711
	/// Full lifecycle test through the `HnswIndex` API: index a generated
	/// collection, search it, then remove every document.
	async fn test_hnsw_index(collection_size: usize, unique: bool, p: HnswParams) {
		info!("test_hnsw_index - coll size: {collection_size} - params: {p:?}");

		let ds = Datastore::new("memory").await.unwrap();

		let collection = TestCollection::new(
			unique,
			collection_size,
			p.vector_type,
			p.dimension as usize,
			&p.distance,
		);

		// Indexing phase (write transaction)
		let (mut h, map) = {
			let ctx = new_ctx(&ds, TransactionType::Write).await;
			let tx = ctx.tx();
			let mut h =
				HnswIndex::new(&tx, IndexKeyBase::default(), "test".to_string(), &p).await.unwrap();
			let map = insert_collection_hnsw_index(&tx, &mut h, &collection).await.unwrap();
			tx.commit().await.unwrap();
			(h, map)
		};

		// Search phase (read transaction)
		{
			let mut stack = reblessive::tree::TreeStack::new();
			let ctx = new_ctx(&ds, TransactionType::Read).await;
			let tx = ctx.tx();
			stack
				.enter(|stk| async {
					find_collection_hnsw_index(&tx, stk, &mut h, &collection).await;
				})
				.finish()
				.await;
		}

		// Deletion phase (write transaction)
		{
			let ctx = new_ctx(&ds, TransactionType::Write).await;
			let tx = ctx.tx();
			delete_hnsw_index_collection(&tx, &mut h, &collection, map).await.unwrap();
			tx.commit().await.unwrap();
		}
	}
758
	/// Smoke test of the index API across all distances, vector types,
	/// heuristic flags and unique/non-unique collections, each combination
	/// running as its own tokio task.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn tests_hnsw_index() -> Result<(), Error> {
		let mut futures = Vec::new();
		for (dist, dim) in [
			(Distance::Chebyshev, 5),
			(Distance::Cosine, 5),
			(Distance::Euclidean, 5),
			(Distance::Hamming, 20),
			(Distance::Manhattan, 5),
			(Distance::Minkowski(2.into()), 5),
		] {
			for vt in [
				VectorType::F64,
				VectorType::F32,
				VectorType::I64,
				VectorType::I32,
				VectorType::I16,
			] {
				for (extend, keep) in [(false, false), (true, false), (false, true), (true, true)] {
					for unique in [true, false] {
						let p = new_params(dim, vt, dist.clone(), 8, 150, extend, keep);
						let f = tokio::spawn(async move {
							test_hnsw_index(30, unique, p).await;
						});
						futures.push(f);
					}
				}
			}
		}
		for f in futures {
			f.await.expect("Task error");
		}
		Ok(())
	}
795
	/// Minimal end-to-end check on a small hand-written 2D collection:
	/// insert 11 vectors and verify that a k=10 search returns 10 results.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn test_simple_hnsw() {
		let collection = TestCollection::Unique(vec![
			(0, new_i16_vec(-2, -3)),
			(1, new_i16_vec(-2, 1)),
			(2, new_i16_vec(-4, 3)),
			(3, new_i16_vec(-3, 1)),
			(4, new_i16_vec(-1, 1)),
			(5, new_i16_vec(-2, 3)),
			(6, new_i16_vec(3, 0)),
			(7, new_i16_vec(-1, -2)),
			(8, new_i16_vec(-2, 2)),
			(9, new_i16_vec(-4, -2)),
			(10, new_i16_vec(0, 3)),
		]);
		let ikb = IndexKeyBase::default();
		let p = new_params(2, VectorType::I16, Distance::Euclidean, 3, 500, true, true);
		let mut h = HnswFlavor::new(ikb, &p).unwrap();
		let ds = Arc::new(Datastore::new("memory").await.unwrap());
		{
			let tx = ds.transaction(TransactionType::Write, Optimistic).await.unwrap();
			insert_collection_hnsw(&tx, &mut h, &collection).await;
			tx.commit().await.unwrap();
		}
		{
			let tx = ds.transaction(TransactionType::Read, Optimistic).await.unwrap();
			let search = HnswSearch::new(new_i16_vec(-2, -3), 10, 501);
			let res = h.knn_search(&tx, &search).await.unwrap();
			assert_eq!(res.len(), 10);
		}
	}
827
	/// Recall test: index vectors from `embeddings_file`, then for each
	/// `(efs, expected_recall)` pair run the queries from `queries_file`
	/// and assert that the average recall against a brute-force ground
	/// truth reaches the expected value.
	async fn test_recall(
		embeddings_file: &str,
		ingest_limit: usize,
		queries_file: &str,
		query_limit: usize,
		p: HnswParams,
		tests_ef_recall: &[(usize, f64)],
	) -> Result<(), Error> {
		info!("Build data collection");

		let ds = Arc::new(Datastore::new("memory").await?);

		let collection: Arc<TestCollection> =
			Arc::new(TestCollection::NonUnique(new_vectors_from_file(
				p.vector_type,
				&format!("../../tests/data/{embeddings_file}"),
				Some(ingest_limit),
			)?));

		let ctx = new_ctx(&ds, TransactionType::Write).await;
		let tx = ctx.tx();
		let mut h = HnswIndex::new(&tx, IndexKeyBase::default(), "Index".to_string(), &p).await?;
		info!("Insert collection");
		for (doc_id, obj) in collection.to_vec_ref() {
			let content = vec![Value::from(obj.deref())];
			h.index_document(&tx, &Id::Number(*doc_id as i64), &content).await?;
		}
		tx.commit().await?;

		let h = Arc::new(h);

		info!("Build query collection");
		let queries = Arc::new(TestCollection::NonUnique(new_vectors_from_file(
			p.vector_type,
			&format!("../../tests/data/{queries_file}"),
			Some(query_limit),
		)?));

		info!("Check recall");
		// One task per (efs, expected_recall) pair
		let mut futures = Vec::with_capacity(tests_ef_recall.len());
		for &(efs, expected_recall) in tests_ef_recall {
			let queries = queries.clone();
			let collection = collection.clone();
			let h = h.clone();
			let ds = ds.clone();
			let f = tokio::spawn(async move {
				let mut stack = reblessive::tree::TreeStack::new();
				stack
					.enter(|stk| async {
						let mut total_recall = 0.0;
						for (_, pt) in queries.to_vec_ref() {
							let knn = 10;
							let mut chk = HnswConditionChecker::new();
							let search = HnswSearch::new(pt.clone(), knn, efs);
							let ctx = new_ctx(&ds, TransactionType::Read).await;
							let tx = ctx.tx();
							let hnsw_res = h.search(&tx, stk, &search, &mut chk).await.unwrap();
							assert_eq!(hnsw_res.docs.len(), knn, "Different size - knn: {knn}",);
							// Brute-force exact KNN as the ground truth
							let brute_force_res = collection.knn(pt, Distance::Euclidean, knn);
							let rec = brute_force_res.recall(&hnsw_res);
							if rec == 1.0 {
								assert_eq!(brute_force_res.docs, hnsw_res.docs);
							}
							total_recall += rec;
						}
						let recall = total_recall / queries.to_vec_ref().len() as f64;
						info!("EFS: {efs} - Recall: {recall}");
						assert!(
							recall >= expected_recall,
							"EFS: {efs} - Recall: {recall} - Expected: {expected_recall}"
						);
					})
					.finish()
					.await;
			});
			futures.push(f);
		}
		for f in futures {
			f.await.expect("Task failure");
		}
		Ok(())
	}
910
	/// Recall benchmark, Euclidean distance, no heuristic extensions.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn test_recall_euclidean() -> Result<(), Error> {
		let p = new_params(20, VectorType::F32, Distance::Euclidean, 8, 100, false, false);
		test_recall(
			"hnsw-random-9000-20-euclidean.gz",
			1000,
			"hnsw-random-5000-20-euclidean.gz",
			300,
			p,
			&[(10, 0.98), (40, 1.0)],
		)
		.await
	}
924
	/// Recall benchmark, Euclidean distance, `keep_pruned_connections` enabled.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn test_recall_euclidean_keep_pruned_connections() -> Result<(), Error> {
		let p = new_params(20, VectorType::F32, Distance::Euclidean, 8, 100, false, true);
		test_recall(
			"hnsw-random-9000-20-euclidean.gz",
			750,
			"hnsw-random-5000-20-euclidean.gz",
			200,
			p,
			&[(10, 0.98), (40, 1.0)],
		)
		.await
	}
938
	/// Recall benchmark, Euclidean distance, both heuristic extensions enabled.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn test_recall_euclidean_full() -> Result<(), Error> {
		let p = new_params(20, VectorType::F32, Distance::Euclidean, 8, 100, true, true);
		test_recall(
			"hnsw-random-9000-20-euclidean.gz",
			500,
			"hnsw-random-5000-20-euclidean.gz",
			100,
			p,
			&[(10, 0.98), (40, 1.0)],
		)
		.await
	}
952
	impl TestCollection {
		/// Brute-force exact KNN over the whole collection, used as the
		/// ground truth for the recall tests.
		fn knn(&self, pt: &SharedVector, dist: Distance, n: usize) -> KnnResult {
			let mut b = KnnResultBuilder::new(n);
			for (doc_id, doc_pt) in self.to_vec_ref() {
				let d = dist.calculate(doc_pt, pt);
				if b.check_add(d) {
					b.add(d, Ids64::One(*doc_id));
				}
			}
			b.build(
				#[cfg(debug_assertions)]
				HashMap::default(),
			)
		}
	}
968
969 impl KnnResult {
970 fn recall(&self, res: &KnnResult) -> f64 {
971 let mut bits = RoaringTreemap::new();
972 for &(doc_id, _) in &self.docs {
973 bits.insert(doc_id);
974 }
975 let mut found = 0;
976 for &(doc_id, _) in &res.docs {
977 if bits.contains(doc_id) {
978 found += 1;
979 }
980 }
981 found as f64 / bits.len() as f64
982 }
983 }
984
985 fn new_i16_vec(x: isize, y: isize) -> SharedVector {
986 let vec = Vector::I16(Array1::from_vec(vec![x as i16, y as i16]));
987 vec.into()
988 }
989}