surrealdb_core/idx/trees/hnsw/
mod.rs

1pub(in crate::idx) mod docs;
2mod elements;
3mod flavor;
4mod heuristic;
5pub mod index;
6mod layer;
7
8use crate::err::Error;
9use crate::idx::planner::checker::HnswConditionChecker;
10use crate::idx::trees::dynamicset::DynamicSet;
11use crate::idx::trees::hnsw::docs::HnswDocs;
12use crate::idx::trees::hnsw::docs::VecDocs;
13use crate::idx::trees::hnsw::elements::HnswElements;
14use crate::idx::trees::hnsw::heuristic::Heuristic;
15use crate::idx::trees::hnsw::index::HnswCheckedSearchContext;
16
17use crate::idx::trees::hnsw::layer::{HnswLayer, LayerState};
18use crate::idx::trees::knn::DoublePriorityQueue;
19use crate::idx::trees::vector::{SerializedVector, SharedVector, Vector};
20use crate::idx::{IndexKeyBase, VersionedStore};
21use crate::kvs::{Key, Transaction, Val};
22use crate::sql::index::HnswParams;
23use rand::prelude::SmallRng;
24use rand::{Rng, SeedableRng};
25use reblessive::tree::Stk;
26use revision::revisioned;
27use serde::{Deserialize, Serialize};
28
/// Parameters for a single HNSW k-nearest-neighbour search.
struct HnswSearch {
	// The query vector
	pt: SharedVector,
	// The number of neighbours requested
	k: usize,
	// The size of the dynamic candidate list used during the search
	ef: usize,
}
34
35impl HnswSearch {
36	pub(super) fn new(pt: SharedVector, k: usize, ef: usize) -> Self {
37		Self {
38			pt,
39			k,
40			ef,
41		}
42	}
43}
44
/// Persisted state of an HNSW graph, stored as a single versioned value
/// under the index's state key.
#[revisioned(revision = 1)]
#[derive(Default, Serialize, Deserialize)]
pub(super) struct HnswState {
	// The element currently used as the search entry point (None while the graph is empty)
	enter_point: Option<ElementId>,
	// The id that will be attributed to the next inserted element
	next_element_id: ElementId,
	// Persisted state of layer 0 (see `LayerState`)
	layer0: LayerState,
	// Persisted state of the upper layers (index 0 corresponds to layer 1)
	layers: Vec<LayerState>,
}

impl VersionedStore for HnswState {}
55
/// An HNSW (Hierarchical Navigable Small World) graph.
/// `L0` is the set implementation used for layer 0 and `L` the one used for
/// the upper layers (layer 0 allows more neighbours per node: `m0` vs `m`).
struct Hnsw<L0, L>
where
	L0: DynamicSet,
	L: DynamicSet,
{
	// Key base used to build the KV keys of this index
	ikb: IndexKeyBase,
	// KV key under which `state` is persisted
	state_key: Key,
	// The last known persisted state
	state: HnswState,
	// Maximum number of neighbours per node on the upper layers
	m: usize,
	// `ef_construction`: candidate list size used during insertion
	efc: usize,
	// Level-generation factor (typically 1/ln(m))
	ml: f64,
	// Layer 0 (contains every element)
	layer0: HnswLayer<L0>,
	// The upper layers (index 0 is layer 1)
	layers: Vec<HnswLayer<L>>,
	// The vector store (element id -> vector)
	elements: HnswElements,
	// RNG used to draw the level of each inserted element
	rng: SmallRng,
	// Neighbour-selection heuristic derived from the index parameters
	heuristic: Heuristic,
}

/// Identifier attributed to each vector stored in the HNSW graph.
pub(crate) type ElementId = u64;
75
76impl<L0, L> Hnsw<L0, L>
77where
78	L0: DynamicSet,
79	L: DynamicSet,
80{
	/// Creates a new, empty HNSW graph from the index parameters.
	fn new(ikb: IndexKeyBase, p: &HnswParams) -> Result<Self, Error> {
		let m0 = p.m0 as usize;
		let state_key = ikb.new_hs_key()?;
		Ok(Self {
			state_key,
			state: Default::default(),
			m: p.m as usize,
			efc: p.ef_construction as usize,
			ml: p.ml.to_float(),
			// Layer 0 allows up to m0 neighbours per node
			layer0: HnswLayer::new(ikb.clone(), 0, m0),
			layers: Vec::default(),
			elements: HnswElements::new(ikb.clone(), p.distance.clone()),
			rng: SmallRng::from_entropy(),
			heuristic: p.into(),
			ikb,
		})
	}
98
99	async fn check_state(&mut self, tx: &Transaction) -> Result<(), Error> {
100		// Read the state
101		let st: HnswState = if let Some(val) = tx.get(self.state_key.clone(), None).await? {
102			VersionedStore::try_from(val)?
103		} else {
104			Default::default()
105		};
106		// Compare versions
107		if st.layer0.version != self.state.layer0.version {
108			self.layer0.load(tx, &st.layer0).await?;
109		}
110		for ((new_stl, stl), layer) in
111			st.layers.iter().zip(self.state.layers.iter_mut()).zip(self.layers.iter_mut())
112		{
113			if new_stl.version != stl.version {
114				layer.load(tx, new_stl).await?;
115			}
116		}
117		// Retrieve missing layers
118		for i in self.layers.len()..st.layers.len() {
119			let mut l = HnswLayer::new(self.ikb.clone(), i + 1, self.m);
120			l.load(tx, &st.layers[i]).await?;
121			self.layers.push(l);
122		}
123		// Remove non-existing layers
124		for _ in self.layers.len()..st.layers.len() {
125			self.layers.pop();
126		}
127		// Set the enter_point
128		self.elements.set_next_element_id(st.next_element_id);
129		self.state = st;
130		Ok(())
131	}
132
	/// Inserts the vector `q_pt` with `q_level` as its top layer and returns
	/// the element id attributed to it.
	async fn insert_level(
		&mut self,
		tx: &Transaction,
		q_pt: Vector,
		q_level: usize,
	) -> Result<ElementId, Error> {
		// Attributes an ID to the vector
		let q_id = self.elements.next_element_id();
		let top_up_layers = self.layers.len();

		// Be sure we have existing (up) layers if required
		for i in top_up_layers..q_level {
			self.layers.push(HnswLayer::new(self.ikb.clone(), i + 1, self.m));
			self.state.layers.push(LayerState::default());
		}

		// Store the vector
		let pt_ser = SerializedVector::from(&q_pt);
		let q_pt = self.elements.insert(tx, q_id, q_pt, &pt_ser).await?;

		if let Some(ep_id) = self.state.enter_point {
			// We already have an enter_point, let's insert the element in the layers
			self.insert_element(tx, q_id, &q_pt, q_level, ep_id, top_up_layers).await?;
		} else {
			// Otherwise it is the first element
			self.insert_first_element(tx, q_id, q_level).await?;
		}

		self.state.next_element_id = self.elements.inc_next_element_id();
		Ok(q_id)
	}
164
	/// Draws a random top layer for a new element from the distribution
	/// floor(-ln(U) * ml), U uniform in [0, 1) — the level-assignment scheme
	/// of the HNSW paper.
	fn get_random_level(&mut self) -> usize {
		let unif: f64 = self.rng.gen(); // generate a uniform random number between 0 and 1
		(-unif.ln() * self.ml).floor() as usize // calculate the layer
	}
169
170	async fn insert_first_element(
171		&mut self,
172		tx: &Transaction,
173		id: ElementId,
174		level: usize,
175	) -> Result<(), Error> {
176		if level > 0 {
177			// Insert in up levels
178			for (layer, state) in
179				self.layers.iter_mut().zip(self.state.layers.iter_mut()).take(level)
180			{
181				layer.add_empty_node(tx, id, state).await?;
182			}
183		}
184		// Insert in layer 0
185		self.layer0.add_empty_node(tx, id, &mut self.state.layer0).await?;
186		// Update the enter point
187		self.state.enter_point = Some(id);
188		//
189		Ok(())
190	}
191
	/// Inserts element `q_id` into the graph, descending from the current
	/// enter point `ep_id` (HNSW insertion algorithm). `top_up_layers` is the
	/// number of upper layers that existed before this insertion started.
	async fn insert_element(
		&mut self,
		tx: &Transaction,
		q_id: ElementId,
		q_pt: &SharedVector,
		q_level: usize,
		mut ep_id: ElementId,
		top_up_layers: usize,
	) -> Result<(), Error> {
		if let Some(mut ep_dist) = self.elements.get_distance(tx, q_pt, &ep_id).await? {
			// Phase 1: greedy descent through the layers above q_level to
			// refine the enter point used for the insertion layers
			if q_level < top_up_layers {
				for layer in self.layers[q_level..top_up_layers].iter_mut().rev() {
					if let Some(ep_dist_id) = layer
						.search_single(tx, &self.elements, q_pt, ep_dist, ep_id, 1)
						.await?
						.peek_first()
					{
						(ep_dist, ep_id) = ep_dist_id;
					} else {
						// The enter point is expected to be present in every upper layer
						#[cfg(debug_assertions)]
						unreachable!()
					}
				}
			}

			let mut eps = DoublePriorityQueue::from(ep_dist, ep_id);

			// Phase 2: insert into the pre-existing upper layers, from the
			// highest insertion layer down, carrying the candidate set `eps`
			let insert_to_up_layers = q_level.min(top_up_layers);
			if insert_to_up_layers > 0 {
				for (layer, st) in self
					.layers
					.iter_mut()
					.zip(self.state.layers.iter_mut())
					.take(insert_to_up_layers)
					.rev()
				{
					eps = layer
						.insert(
							(tx, st),
							&self.elements,
							&self.heuristic,
							self.efc,
							(q_id, q_pt),
							eps,
						)
						.await?;
				}
			}

			// Every element belongs to layer 0
			self.layer0
				.insert(
					(tx, &mut self.state.layer0),
					&self.elements,
					&self.heuristic,
					self.efc,
					(q_id, q_pt),
					eps,
				)
				.await?;

			// Phase 3: layers newly created for this element (above the
			// previous top) receive it as their single, empty node
			if top_up_layers < q_level {
				for (layer, st) in self.layers[top_up_layers..q_level]
					.iter_mut()
					.zip(self.state.layers[top_up_layers..q_level].iter_mut())
				{
					if !layer.add_empty_node(tx, q_id, st).await? {
						#[cfg(debug_assertions)]
						unreachable!("Already there {}", q_id);
					}
				}
			}

			// If the element tops the graph, it becomes the new enter point
			if q_level > top_up_layers {
				self.state.enter_point = Some(q_id);
			}
		} else {
			// The enter point is expected to always have a stored vector
			#[cfg(debug_assertions)]
			unreachable!()
		}
		Ok(())
	}
273
	/// Serializes the current HNSW state and persists it under the state key.
	async fn save_state(&self, tx: &Transaction) -> Result<(), Error> {
		let val: Val = VersionedStore::try_into(&self.state)?;
		tx.set(self.state_key.clone(), val, None).await?;
		Ok(())
	}
279
	/// Inserts a vector at a randomly drawn level and persists the updated
	/// state. Returns the element id attributed to the vector.
	async fn insert(&mut self, tx: &Transaction, q_pt: Vector) -> Result<ElementId, Error> {
		let q_level = self.get_random_level();
		let res = self.insert_level(tx, q_pt, q_level).await?;
		self.save_state(tx).await?;
		Ok(res)
	}
286
	/// Removes element `e_id` from every layer and from the element store,
	/// recomputing the enter point when the removed element was the current
	/// one. Returns true if the element was removed from at least one layer.
	async fn remove(&mut self, tx: &Transaction, e_id: ElementId) -> Result<bool, Error> {
		let mut removed = false;

		// Do we have the vector?
		if let Some(e_pt) = self.elements.get_vector(tx, &e_id).await? {
			// Check if we are deleting the current enter_point
			let mut new_enter_point = if Some(e_id) == self.state.enter_point {
				None
			} else {
				self.state.enter_point
			};

			// Remove from the up layers, looking for a replacement enter
			// point (ignoring e_id) on the way down if we need one
			for (layer, st) in self.layers.iter_mut().zip(self.state.layers.iter_mut()).rev() {
				if new_enter_point.is_none() {
					new_enter_point = layer
						.search_single_with_ignore(tx, &self.elements, &e_pt, e_id, self.efc)
						.await?;
				}
				if layer.remove(tx, st, &self.elements, &self.heuristic, e_id, self.efc).await? {
					removed = true;
				}
			}

			// Check possible new enter_point at layer0
			if new_enter_point.is_none() {
				new_enter_point = self
					.layer0
					.search_single_with_ignore(tx, &self.elements, &e_pt, e_id, self.efc)
					.await?;
			}

			// Remove from layer 0
			if self
				.layer0
				.remove(tx, &mut self.state.layer0, &self.elements, &self.heuristic, e_id, self.efc)
				.await?
			{
				removed = true;
			}

			// Remove the stored vector itself
			self.elements.remove(tx, e_id).await?;

			self.state.enter_point = new_enter_point;
		}

		self.save_state(tx).await?;
		Ok(removed)
	}
336
	/// k-nearest-neighbour search: descends the upper layers to find an
	/// enter point, then runs the actual search on layer 0.
	/// Returns at most `search.k` pairs of (distance, element id).
	async fn knn_search(
		&self,
		tx: &Transaction,
		search: &HnswSearch,
	) -> Result<Vec<(f64, ElementId)>, Error> {
		if let Some((ep_dist, ep_id)) = self.search_ep(tx, &search.pt).await? {
			let w = self
				.layer0
				.search_single(tx, &self.elements, &search.pt, ep_dist, ep_id, search.ef)
				.await?;
			Ok(w.to_vec_limit(search.k))
		} else {
			// Empty graph: no enter point, no results
			Ok(vec![])
		}
	}
352
	/// Like `knn_search`, but evaluates the query condition through `chk`
	/// while searching layer 0, so non-matching documents can be filtered
	/// out during the traversal.
	async fn knn_search_checked(
		&self,
		tx: &Transaction,
		stk: &mut Stk,
		search: &HnswSearch,
		hnsw_docs: &HnswDocs,
		vec_docs: &VecDocs,
		chk: &mut HnswConditionChecker<'_>,
	) -> Result<Vec<(f64, ElementId)>, Error> {
		if let Some((ep_dist, ep_id)) = self.search_ep(tx, &search.pt).await? {
			if let Some(ep_pt) = self.elements.get_vector(tx, &ep_id).await? {
				let search_ctx = HnswCheckedSearchContext::new(
					&self.elements,
					hnsw_docs,
					vec_docs,
					&search.pt,
					search.ef,
				);
				let w = self
					.layer0
					.search_single_checked(tx, stk, &search_ctx, &ep_pt, ep_dist, ep_id, chk)
					.await?;
				return Ok(w.to_vec_limit(search.k));
			}
		}
		// No enter point (empty graph) or its vector is missing
		Ok(vec![])
	}
380
	/// Finds the enter point for a layer-0 search by greedily descending all
	/// the upper layers from the stored enter point.
	/// Returns `None` when the graph is empty.
	async fn search_ep(
		&self,
		tx: &Transaction,
		pt: &SharedVector,
	) -> Result<Option<(f64, ElementId)>, Error> {
		if let Some(mut ep_id) = self.state.enter_point {
			if let Some(mut ep_dist) = self.elements.get_distance(tx, pt, &ep_id).await? {
				// Greedy search (ef = 1) on each upper layer, top to bottom
				for layer in self.layers.iter().rev() {
					if let Some(ep_dist_id) = layer
						.search_single(tx, &self.elements, pt, ep_dist, ep_id, 1)
						.await?
						.peek_first()
					{
						(ep_dist, ep_id) = ep_dist_id;
					} else {
						// The enter point is expected to exist in every upper layer
						#[cfg(debug_assertions)]
						unreachable!()
					}
				}
				return Ok(Some((ep_dist, ep_id)));
			} else {
				// The enter point is expected to always have a stored vector
				#[cfg(debug_assertions)]
				unreachable!()
			}
		}
		Ok(None)
	}
408
	/// Returns the vector stored for the given element id, if any.
	async fn get_vector(
		&self,
		tx: &Transaction,
		e_id: &ElementId,
	) -> Result<Option<SharedVector>, Error> {
		self.elements.get_vector(tx, e_id).await
	}
	// Test-only: asserts the graph invariants (see `check_hnsw_props`)
	#[cfg(test)]
	fn check_hnsw_properties(&self, expected_count: usize) {
		check_hnsw_props(self, expected_count);
	}
420}
421
/// Test-only invariant check: the element store must hold exactly
/// `expected_count` vectors, and every upper layer must pass its own
/// `check_props` invariants against the element store.
#[cfg(test)]
fn check_hnsw_props<L0, L>(h: &Hnsw<L0, L>, expected_count: usize)
where
	L0: DynamicSet,
	L: DynamicSet,
{
	assert_eq!(h.elements.len(), expected_count);
	for layer in h.layers.iter() {
		layer.check_props(&h.elements);
	}
}
433
434#[cfg(test)]
435mod tests {
436	use crate::ctx::{Context, MutableContext};
437	use crate::err::Error;
438	use crate::idx::docids::DocId;
439	use crate::idx::planner::checker::HnswConditionChecker;
440	use crate::idx::trees::hnsw::flavor::HnswFlavor;
441	use crate::idx::trees::hnsw::index::HnswIndex;
442	use crate::idx::trees::hnsw::{ElementId, HnswSearch};
443	use crate::idx::trees::knn::tests::{new_vectors_from_file, TestCollection};
444	use crate::idx::trees::knn::{Ids64, KnnResult, KnnResultBuilder};
445	use crate::idx::trees::vector::{SharedVector, Vector};
446	use crate::idx::IndexKeyBase;
447	use crate::kvs::LockType::Optimistic;
448	use crate::kvs::{Datastore, Transaction, TransactionType};
449	use crate::sql::index::{Distance, HnswParams, VectorType};
450	use crate::sql::{Id, Value};
451	use ahash::{HashMap, HashSet};
452	use ndarray::Array1;
453	use reblessive::tree::Stk;
454	use roaring::RoaringTreemap;
455	use std::collections::hash_map::Entry;
456	use std::ops::Deref;
457	use std::sync::Arc;
458	use test_log::test;
459
	/// Inserts every vector of the collection into the HNSW graph, checking
	/// the graph invariants after each insertion.
	/// Returns the mapping from attributed element ids to their vectors.
	async fn insert_collection_hnsw(
		tx: &Transaction,
		h: &mut HnswFlavor,
		collection: &TestCollection,
	) -> HashMap<ElementId, SharedVector> {
		let mut map = HashMap::default();
		for (_, obj) in collection.to_vec_ref() {
			let obj: SharedVector = obj.clone();
			let e_id = h.insert(tx, obj.clone_vector()).await.unwrap();
			map.insert(e_id, obj);
			h.check_hnsw_properties(map.len());
		}
		map
	}
474
	/// Runs a knn search for every vector of the collection (for k = 1..20)
	/// and checks that the vector itself is found (when the collection is
	/// unique) and that the expected number of results is returned.
	async fn find_collection_hnsw(tx: &Transaction, h: &HnswFlavor, collection: &TestCollection) {
		let max_knn = 20.min(collection.len());
		for (_, obj) in collection.to_vec_ref() {
			for knn in 1..max_knn {
				let search = HnswSearch::new(obj.clone(), knn, 80);
				let res = h.knn_search(tx, &search).await.unwrap();
				if collection.is_unique() {
					// The queried vector itself must be among the results
					let mut found = false;
					for (_, e_id) in &res {
						if let Some(v) = h.get_vector(tx, e_id).await.unwrap() {
							if v.eq(obj) {
								found = true;
								break;
							}
						}
					}
					assert!(
						found,
						"Search: {:?} - Knn: {} - Vector not found - Got: {:?} - Coll: {}",
						obj,
						knn,
						res,
						collection.len(),
					);
				}
				let expected_len = collection.len().min(knn);
				if expected_len != res.len() {
					info!("expected_len != res.len()")
				}
				assert_eq!(
					expected_len,
					res.len(),
					"Wrong knn count - Expected: {} - Got: {} - Collection: {} - - Res: {:?}",
					expected_len,
					res.len(),
					collection.len(),
					res,
				)
			}
		}
	}
516
	/// Removes every inserted element from the graph, checking the graph
	/// invariants after each removal.
	async fn delete_collection_hnsw(
		tx: &Transaction,
		h: &mut HnswFlavor,
		mut map: HashMap<ElementId, SharedVector>,
	) {
		let element_ids: Vec<ElementId> = map.keys().copied().collect();
		for e_id in element_ids {
			assert!(h.remove(tx, e_id).await.unwrap());
			map.remove(&e_id);
			h.check_hnsw_properties(map.len());
		}
	}
529
	/// Full lifecycle test on a raw HNSW graph: insert the collection,
	/// search it, then delete everything — each phase in its own transaction.
	async fn test_hnsw_collection(p: &HnswParams, collection: &TestCollection) {
		let ds = Datastore::new("memory").await.unwrap();
		let mut h = HnswFlavor::new(IndexKeyBase::default(), p).unwrap();
		let map = {
			let tx = ds.transaction(TransactionType::Write, Optimistic).await.unwrap();
			let map = insert_collection_hnsw(&tx, &mut h, collection).await;
			tx.commit().await.unwrap();
			map
		};
		{
			let tx = ds.transaction(TransactionType::Read, Optimistic).await.unwrap();
			find_collection_hnsw(&tx, &h, collection).await;
			tx.cancel().await.unwrap();
		}
		{
			let tx = ds.transaction(TransactionType::Write, Optimistic).await.unwrap();
			delete_collection_hnsw(&tx, &mut h, map).await;
			tx.commit().await.unwrap();
		}
	}
550
	/// Builds `HnswParams` for the tests, deriving `m0 = 2 * m` and
	/// `ml = 1 / ln(m)` as recommended by the HNSW paper.
	fn new_params(
		dimension: usize,
		vector_type: VectorType,
		distance: Distance,
		m: usize,
		efc: usize,
		extend_candidates: bool,
		keep_pruned_connections: bool,
	) -> HnswParams {
		let m = m as u8;
		let m0 = m * 2;
		HnswParams::new(
			dimension as u16,
			distance,
			vector_type,
			m,
			m0,
			(1.0 / (m as f64).ln()).into(),
			efc as u16,
			extend_candidates,
			keep_pruned_connections,
		)
	}
574
	/// Generates a unique random collection of the given size and runs the
	/// full lifecycle test against it.
	async fn test_hnsw(collection_size: usize, p: HnswParams) {
		info!("Collection size: {collection_size} - Params: {p:?}");
		let collection = TestCollection::new(
			true,
			collection_size,
			p.vector_type,
			p.dimension as usize,
			&p.distance,
		);
		test_hnsw_collection(&p, &collection).await;
	}
586
	/// Runs the raw-graph lifecycle test over a matrix of distances, vector
	/// types and heuristic flags, one spawned task per combination.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn tests_hnsw() -> Result<(), Error> {
		let mut futures = Vec::new();
		for (dist, dim) in [
			(Distance::Chebyshev, 5),
			(Distance::Cosine, 5),
			(Distance::Euclidean, 5),
			(Distance::Hamming, 20),
			// (Distance::Jaccard, 100),
			(Distance::Manhattan, 5),
			(Distance::Minkowski(2.into()), 5),
			// (Distance::Pearson, 5),
		] {
			for vt in [
				VectorType::F64,
				VectorType::F32,
				VectorType::I64,
				VectorType::I32,
				VectorType::I16,
			] {
				for (extend, keep) in [(false, false), (true, false), (false, true), (true, true)] {
					let p = new_params(dim, vt, dist.clone(), 24, 500, extend, keep);
					let f = tokio::spawn(async move {
						test_hnsw(30, p).await;
					});
					futures.push(f);
				}
			}
		}
		for f in futures {
			f.await.expect("Task error");
		}
		Ok(())
	}
621
	/// Indexes every document of the collection through the `HnswIndex`
	/// front-end, checking the graph invariants after each insertion.
	/// Returns the expected mapping from vector to the doc ids holding it.
	async fn insert_collection_hnsw_index(
		tx: &Transaction,
		h: &mut HnswIndex,
		collection: &TestCollection,
	) -> Result<HashMap<SharedVector, HashSet<DocId>>, Error> {
		let mut map: HashMap<SharedVector, HashSet<DocId>> = HashMap::default();
		for (doc_id, obj) in collection.to_vec_ref() {
			let content = vec![Value::from(obj.deref())];
			h.index_document(tx, &Id::Number(*doc_id as i64), &content).await.unwrap();
			match map.entry(obj.clone()) {
				Entry::Occupied(mut e) => {
					e.get_mut().insert(*doc_id);
				}
				Entry::Vacant(e) => {
					e.insert(HashSet::from_iter([*doc_id]));
				}
			}
			h.check_hnsw_properties(map.len());
		}
		Ok(map)
	}
643
	/// Searches the index for every document's vector (for k = 1..20) and
	/// verifies the returned doc ids and result counts.
	async fn find_collection_hnsw_index(
		tx: &Transaction,
		stk: &mut Stk,
		h: &mut HnswIndex,
		collection: &TestCollection,
	) {
		let max_knn = 20.min(collection.len());
		for (doc_id, obj) in collection.to_vec_ref() {
			for knn in 1..max_knn {
				let mut chk = HnswConditionChecker::new();
				let search = HnswSearch::new(obj.clone(), knn, 500);
				let res = h.search(tx, stk, &search, &mut chk).await.unwrap();
				// With k = 1 and a strictly positive distance, the single
				// result must be the document itself (unique collections)
				if knn == 1 && res.docs.len() == 1 && res.docs[0].1 > 0.0 {
					let docs: Vec<DocId> = res.docs.iter().map(|(d, _)| *d).collect();
					if collection.is_unique() {
						assert!(
							docs.contains(doc_id),
							"Search: {:?} - Knn: {} - Wrong Doc - Expected: {} - Got: {:?}",
							obj,
							knn,
							doc_id,
							res.docs
						);
					}
				}
				let expected_len = collection.len().min(knn);
				assert_eq!(
					expected_len,
					res.docs.len(),
					"Wrong knn count - Expected: {} - Got: {} - - Docs: {:?} - Collection: {}",
					expected_len,
					res.docs.len(),
					res.docs,
					collection.len(),
				)
			}
		}
	}
682
	/// Removes every document from the index, maintaining the expected
	/// vector -> doc ids mapping and checking the graph invariants.
	async fn delete_hnsw_index_collection(
		tx: &Transaction,
		h: &mut HnswIndex,
		collection: &TestCollection,
		mut map: HashMap<SharedVector, HashSet<DocId>>,
	) -> Result<(), Error> {
		for (doc_id, obj) in collection.to_vec_ref() {
			let content = vec![Value::from(obj.deref())];
			h.remove_document(tx, Id::Number(*doc_id as i64), &content).await?;
			if let Entry::Occupied(mut e) = map.entry(obj.clone()) {
				let set = e.get_mut();
				set.remove(doc_id);
				if set.is_empty() {
					e.remove();
				}
			}
			// Check properties
			h.check_hnsw_properties(map.len());
		}
		Ok(())
	}
704
705	async fn new_ctx(ds: &Datastore, tt: TransactionType) -> Context {
706		let tx = Arc::new(ds.transaction(tt, Optimistic).await.unwrap());
707		let mut ctx = MutableContext::default();
708		ctx.set_transaction(tx);
709		ctx.freeze()
710	}
711
	/// Full lifecycle test through the `HnswIndex` front-end: create and
	/// fill the index, search it, then delete the whole collection.
	async fn test_hnsw_index(collection_size: usize, unique: bool, p: HnswParams) {
		info!("test_hnsw_index - coll size: {collection_size} - params: {p:?}");

		let ds = Datastore::new("memory").await.unwrap();

		let collection = TestCollection::new(
			unique,
			collection_size,
			p.vector_type,
			p.dimension as usize,
			&p.distance,
		);

		// Create index
		let (mut h, map) = {
			let ctx = new_ctx(&ds, TransactionType::Write).await;
			let tx = ctx.tx();
			let mut h =
				HnswIndex::new(&tx, IndexKeyBase::default(), "test".to_string(), &p).await.unwrap();
			// Fill index
			let map = insert_collection_hnsw_index(&tx, &mut h, &collection).await.unwrap();
			tx.commit().await.unwrap();
			(h, map)
		};

		// Search index
		{
			let mut stack = reblessive::tree::TreeStack::new();
			let ctx = new_ctx(&ds, TransactionType::Read).await;
			let tx = ctx.tx();
			stack
				.enter(|stk| async {
					find_collection_hnsw_index(&tx, stk, &mut h, &collection).await;
				})
				.finish()
				.await;
		}

		// Delete collection
		{
			let ctx = new_ctx(&ds, TransactionType::Write).await;
			let tx = ctx.tx();
			delete_hnsw_index_collection(&tx, &mut h, &collection, map).await.unwrap();
			tx.commit().await.unwrap();
		}
	}
758
	/// Runs the index-level lifecycle test over a matrix of distances,
	/// vector types, heuristic flags and unique/non-unique collections.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn tests_hnsw_index() -> Result<(), Error> {
		let mut futures = Vec::new();
		for (dist, dim) in [
			(Distance::Chebyshev, 5),
			(Distance::Cosine, 5),
			(Distance::Euclidean, 5),
			(Distance::Hamming, 20),
			// (Distance::Jaccard, 100),
			(Distance::Manhattan, 5),
			(Distance::Minkowski(2.into()), 5),
			// (Distance::Pearson, 5),
		] {
			for vt in [
				VectorType::F64,
				VectorType::F32,
				VectorType::I64,
				VectorType::I32,
				VectorType::I16,
			] {
				for (extend, keep) in [(false, false), (true, false), (false, true), (true, true)] {
					for unique in [true, false] {
						let p = new_params(dim, vt, dist.clone(), 8, 150, extend, keep);
						let f = tokio::spawn(async move {
							test_hnsw_index(30, unique, p).await;
						});
						futures.push(f);
					}
				}
			}
		}
		for f in futures {
			f.await.expect("Task error");
		}
		Ok(())
	}
795
	/// Small deterministic smoke test: insert 11 fixed 2D points and verify
	/// a knn search returns the 10 requested neighbours.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn test_simple_hnsw() {
		let collection = TestCollection::Unique(vec![
			(0, new_i16_vec(-2, -3)),
			(1, new_i16_vec(-2, 1)),
			(2, new_i16_vec(-4, 3)),
			(3, new_i16_vec(-3, 1)),
			(4, new_i16_vec(-1, 1)),
			(5, new_i16_vec(-2, 3)),
			(6, new_i16_vec(3, 0)),
			(7, new_i16_vec(-1, -2)),
			(8, new_i16_vec(-2, 2)),
			(9, new_i16_vec(-4, -2)),
			(10, new_i16_vec(0, 3)),
		]);
		let ikb = IndexKeyBase::default();
		let p = new_params(2, VectorType::I16, Distance::Euclidean, 3, 500, true, true);
		let mut h = HnswFlavor::new(ikb, &p).unwrap();
		let ds = Arc::new(Datastore::new("memory").await.unwrap());
		{
			let tx = ds.transaction(TransactionType::Write, Optimistic).await.unwrap();
			insert_collection_hnsw(&tx, &mut h, &collection).await;
			tx.commit().await.unwrap();
		}
		{
			let tx = ds.transaction(TransactionType::Read, Optimistic).await.unwrap();
			let search = HnswSearch::new(new_i16_vec(-2, -3), 10, 501);
			let res = h.knn_search(&tx, &search).await.unwrap();
			assert_eq!(res.len(), 10);
		}
	}
827
	/// Measures the recall of the HNSW index against a brute-force knn on
	/// real embedding files: ingests `ingest_limit` vectors, then for each
	/// `(efs, expected_recall)` pair runs `query_limit` queries and asserts
	/// the average recall meets the expectation.
	async fn test_recall(
		embeddings_file: &str,
		ingest_limit: usize,
		queries_file: &str,
		query_limit: usize,
		p: HnswParams,
		tests_ef_recall: &[(usize, f64)],
	) -> Result<(), Error> {
		info!("Build data collection");

		let ds = Arc::new(Datastore::new("memory").await?);

		let collection: Arc<TestCollection> =
			Arc::new(TestCollection::NonUnique(new_vectors_from_file(
				p.vector_type,
				&format!("../../tests/data/{embeddings_file}"),
				Some(ingest_limit),
			)?));

		let ctx = new_ctx(&ds, TransactionType::Write).await;
		let tx = ctx.tx();
		let mut h = HnswIndex::new(&tx, IndexKeyBase::default(), "Index".to_string(), &p).await?;
		info!("Insert collection");
		for (doc_id, obj) in collection.to_vec_ref() {
			let content = vec![Value::from(obj.deref())];
			h.index_document(&tx, &Id::Number(*doc_id as i64), &content).await?;
		}
		tx.commit().await?;

		let h = Arc::new(h);

		info!("Build query collection");
		let queries = Arc::new(TestCollection::NonUnique(new_vectors_from_file(
			p.vector_type,
			&format!("../../tests/data/{queries_file}"),
			Some(query_limit),
		)?));

		info!("Check recall");
		// One task per (efs, expected_recall) pair
		let mut futures = Vec::with_capacity(tests_ef_recall.len());
		for &(efs, expected_recall) in tests_ef_recall {
			let queries = queries.clone();
			let collection = collection.clone();
			let h = h.clone();
			let ds = ds.clone();
			let f = tokio::spawn(async move {
				let mut stack = reblessive::tree::TreeStack::new();
				stack
					.enter(|stk| async {
						let mut total_recall = 0.0;
						for (_, pt) in queries.to_vec_ref() {
							let knn = 10;
							let mut chk = HnswConditionChecker::new();
							let search = HnswSearch::new(pt.clone(), knn, efs);
							let ctx = new_ctx(&ds, TransactionType::Read).await;
							let tx = ctx.tx();
							let hnsw_res = h.search(&tx, stk, &search, &mut chk).await.unwrap();
							assert_eq!(hnsw_res.docs.len(), knn, "Different size - knn: {knn}",);
							// Brute-force result is the ground truth
							let brute_force_res = collection.knn(pt, Distance::Euclidean, knn);
							let rec = brute_force_res.recall(&hnsw_res);
							if rec == 1.0 {
								assert_eq!(brute_force_res.docs, hnsw_res.docs);
							}
							total_recall += rec;
						}
						let recall = total_recall / queries.to_vec_ref().len() as f64;
						info!("EFS: {efs} - Recall: {recall}");
						assert!(
							recall >= expected_recall,
							"EFS: {efs} - Recall: {recall} - Expected: {expected_recall}"
						);
					})
					.finish()
					.await;
			});
			futures.push(f);
		}
		for f in futures {
			f.await.expect("Task failure");
		}
		Ok(())
	}
910
	/// Recall test with the plain heuristic (no extension, no pruned
	/// connections kept).
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn test_recall_euclidean() -> Result<(), Error> {
		let p = new_params(20, VectorType::F32, Distance::Euclidean, 8, 100, false, false);
		test_recall(
			"hnsw-random-9000-20-euclidean.gz",
			1000,
			"hnsw-random-5000-20-euclidean.gz",
			300,
			p,
			&[(10, 0.98), (40, 1.0)],
		)
		.await
	}
924
	/// Recall test with `keep_pruned_connections` enabled.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn test_recall_euclidean_keep_pruned_connections() -> Result<(), Error> {
		let p = new_params(20, VectorType::F32, Distance::Euclidean, 8, 100, false, true);
		test_recall(
			"hnsw-random-9000-20-euclidean.gz",
			750,
			"hnsw-random-5000-20-euclidean.gz",
			200,
			p,
			&[(10, 0.98), (40, 1.0)],
		)
		.await
	}
938
	/// Recall test with both `extend_candidates` and
	/// `keep_pruned_connections` enabled.
	#[test(tokio::test(flavor = "multi_thread"))]
	async fn test_recall_euclidean_full() -> Result<(), Error> {
		let p = new_params(20, VectorType::F32, Distance::Euclidean, 8, 100, true, true);
		test_recall(
			"hnsw-random-9000-20-euclidean.gz",
			500,
			"hnsw-random-5000-20-euclidean.gz",
			100,
			p,
			&[(10, 0.98), (40, 1.0)],
		)
		.await
	}
952
	impl TestCollection {
		/// Brute-force k-nearest-neighbour search over the whole collection,
		/// used as the ground truth when measuring the recall of the index.
		fn knn(&self, pt: &SharedVector, dist: Distance, n: usize) -> KnnResult {
			let mut b = KnnResultBuilder::new(n);
			for (doc_id, doc_pt) in self.to_vec_ref() {
				let d = dist.calculate(doc_pt, pt);
				if b.check_add(d) {
					b.add(d, Ids64::One(*doc_id));
				}
			}
			b.build(
				#[cfg(debug_assertions)]
				HashMap::default(),
			)
		}
	}
968
969	impl KnnResult {
970		fn recall(&self, res: &KnnResult) -> f64 {
971			let mut bits = RoaringTreemap::new();
972			for &(doc_id, _) in &self.docs {
973				bits.insert(doc_id);
974			}
975			let mut found = 0;
976			for &(doc_id, _) in &res.docs {
977				if bits.contains(doc_id) {
978					found += 1;
979				}
980			}
981			found as f64 / bits.len() as f64
982		}
983	}
984
985	fn new_i16_vec(x: isize, y: isize) -> SharedVector {
986		let vec = Vector::I16(Array1::from_vec(vec![x as i16, y as i16]));
987		vec.into()
988	}
989}