surrealdb_core/idx/planner/checker.rs

1use crate::ctx::Context;
2use crate::dbs::{Iterable, Options};
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::idx::docids::{DocId, DocIds};
6use crate::idx::planner::iterators::KnnIteratorResult;
7use crate::idx::trees::hnsw::docs::HnswDocs;
8use crate::idx::trees::knn::Ids64;
9use crate::kvs::Transaction;
10use crate::sql::{Cond, Thing, Value};
11use ahash::HashMap;
12use reblessive::tree::Stk;
13use std::collections::hash_map::Entry;
14use std::collections::VecDeque;
15use std::sync::Arc;
16
17pub enum HnswConditionChecker<'a> {
18	Hnsw(HnswChecker),
19	HnswCondition(HnswCondChecker<'a>),
20}
21
22pub enum MTreeConditionChecker<'a> {
23	MTree(MTreeChecker<'a>),
24	MTreeCondition(MTreeCondChecker<'a>),
25}
26
27impl<'a> HnswConditionChecker<'a> {
28	pub(in crate::idx) fn new() -> Self {
29		Self::Hnsw(HnswChecker {})
30	}
31
32	pub(in crate::idx) fn new_cond(ctx: &'a Context, opt: &'a Options, cond: Arc<Cond>) -> Self {
33		Self::HnswCondition(HnswCondChecker {
34			ctx,
35			opt,
36			cond,
37			cache: Default::default(),
38		})
39	}
40
41	pub(in crate::idx) async fn check_truthy(
42		&mut self,
43		tx: &Transaction,
44		stk: &mut Stk,
45		docs: &HnswDocs,
46		doc_ids: Ids64,
47	) -> Result<bool, Error> {
48		match self {
49			Self::HnswCondition(c) => c.check_any_truthy(tx, stk, docs, doc_ids).await,
50			Self::Hnsw(_) => Ok(true),
51		}
52	}
53
54	pub(in crate::idx) fn expire(&mut self, doc_id: u64) {
55		if let Self::HnswCondition(c) = self {
56			c.expire(doc_id)
57		}
58	}
59
60	pub(in crate::idx) fn expires(&mut self, doc_ids: Ids64) {
61		if let Self::HnswCondition(c) = self {
62			c.expires(doc_ids)
63		}
64	}
65
66	pub(in crate::idx) async fn convert_result(
67		&mut self,
68		tx: &Transaction,
69		docs: &HnswDocs,
70		res: VecDeque<(DocId, f64)>,
71	) -> Result<VecDeque<KnnIteratorResult>, Error> {
72		match self {
73			Self::Hnsw(c) => c.convert_result(tx, docs, res).await,
74			Self::HnswCondition(c) => Ok(c.convert_result(res)),
75		}
76	}
77}
78
79impl<'a> MTreeConditionChecker<'a> {
80	pub fn new_cond(ctx: &'a Context, opt: &'a Options, cond: Arc<Cond>) -> Self {
81		if Cond(Value::Bool(true)).ne(cond.as_ref()) {
82			Self::MTreeCondition(MTreeCondChecker {
83				ctx,
84				opt,
85				cond,
86				cache: Default::default(),
87			})
88		} else {
89			Self::new(ctx)
90		}
91	}
92
93	pub fn new(ctx: &'a Context) -> Self {
94		Self::MTree(MTreeChecker {
95			ctx,
96		})
97	}
98
99	pub(in crate::idx) async fn check_truthy(
100		&mut self,
101		stk: &mut Stk,
102		doc_ids: &DocIds,
103		doc_id: DocId,
104	) -> Result<bool, Error> {
105		match self {
106			Self::MTreeCondition(c) => c.check_truthy(stk, doc_ids, doc_id).await,
107			Self::MTree(_) => Ok(true),
108		}
109	}
110
111	pub(in crate::idx) fn expires(&mut self, ids: Ids64) {
112		if let Self::MTreeCondition(c) = self {
113			c.expires(ids)
114		}
115	}
116
117	pub(in crate::idx) async fn convert_result(
118		&mut self,
119		doc_ids: &DocIds,
120		res: VecDeque<(DocId, f64)>,
121	) -> Result<VecDeque<KnnIteratorResult>, Error> {
122		match self {
123			Self::MTree(c) => c.convert_result(doc_ids, res).await,
124			Self::MTreeCondition(c) => Ok(c.convert_result(res)),
125		}
126	}
127}
128
129pub struct MTreeChecker<'a> {
130	ctx: &'a Context,
131}
132
133impl MTreeChecker<'_> {
134	async fn convert_result(
135		&self,
136		doc_ids: &DocIds,
137		res: VecDeque<(DocId, f64)>,
138	) -> Result<VecDeque<KnnIteratorResult>, Error> {
139		if res.is_empty() {
140			return Ok(VecDeque::from([]));
141		}
142		let mut result = VecDeque::with_capacity(res.len());
143		let txn = self.ctx.tx();
144		for (doc_id, dist) in res {
145			if let Some(key) = doc_ids.get_doc_key(&txn, doc_id).await? {
146				let rid: Thing = revision::from_slice(&key)?;
147				result.push_back((rid.into(), dist, None));
148			}
149		}
150		Ok(result)
151	}
152}
153
154struct CheckerCacheEntry {
155	record: Option<(Arc<Thing>, Arc<Value>)>,
156	truthy: bool,
157}
158
159impl CheckerCacheEntry {
160	fn convert_result(
161		res: VecDeque<(DocId, f64)>,
162		cache: &mut HashMap<DocId, CheckerCacheEntry>,
163	) -> VecDeque<KnnIteratorResult> {
164		let mut result = VecDeque::with_capacity(res.len());
165		for (doc_id, dist) in res {
166			if let Some(e) = cache.remove(&doc_id) {
167				if e.truthy {
168					if let Some((rid, value)) = e.record {
169						result.push_back((rid, dist, Some(value)))
170					}
171				}
172			}
173		}
174		result
175	}
176
177	async fn build(
178		stk: &mut Stk,
179		ctx: &Context,
180		opt: &Options,
181		rid: Option<Thing>,
182		cond: &Cond,
183	) -> Result<Self, Error> {
184		if let Some(rid) = rid {
185			let rid = Arc::new(rid);
186			let txn = ctx.tx();
187			let val = Iterable::fetch_thing(&txn, opt, &rid).await?;
188			if !val.is_none_or_null() {
189				let (value, truthy) = {
190					let mut cursor_doc = CursorDoc {
191						rid: Some(rid.clone()),
192						ir: None,
193						doc: val.into(),
194					};
195					let truthy = cond.compute(stk, ctx, opt, Some(&cursor_doc)).await?.is_truthy();
196					(cursor_doc.doc.as_arc(), truthy)
197				};
198				return Ok(CheckerCacheEntry {
199					record: Some((rid, value)),
200					truthy,
201				});
202			}
203		}
204		Ok(CheckerCacheEntry {
205			record: None,
206			truthy: false,
207		})
208	}
209}
210
211pub struct MTreeCondChecker<'a> {
212	ctx: &'a Context,
213	opt: &'a Options,
214	cond: Arc<Cond>,
215	cache: HashMap<DocId, CheckerCacheEntry>,
216}
217
218impl MTreeCondChecker<'_> {
219	async fn check_truthy(
220		&mut self,
221		stk: &mut Stk,
222		doc_ids: &DocIds,
223		doc_id: u64,
224	) -> Result<bool, Error> {
225		match self.cache.entry(doc_id) {
226			Entry::Occupied(e) => Ok(e.get().truthy),
227			Entry::Vacant(e) => {
228				let txn = self.ctx.tx();
229				let rid = doc_ids
230					.get_doc_key(&txn, doc_id)
231					.await?
232					.map(|k| revision::from_slice(&k))
233					.transpose()?;
234				let ent =
235					CheckerCacheEntry::build(stk, self.ctx, self.opt, rid, self.cond.as_ref())
236						.await?;
237				let truthy = ent.truthy;
238				e.insert(ent);
239				Ok(truthy)
240			}
241		}
242	}
243
244	fn expire(&mut self, doc_id: DocId) {
245		self.cache.remove(&doc_id);
246	}
247
248	fn expires(&mut self, doc_ids: Ids64) {
249		for doc_id in doc_ids.iter() {
250			self.expire(doc_id);
251		}
252	}
253
254	fn convert_result(&mut self, res: VecDeque<(DocId, f64)>) -> VecDeque<KnnIteratorResult> {
255		CheckerCacheEntry::convert_result(res, &mut self.cache)
256	}
257}
258
259pub struct HnswChecker {}
260
261impl HnswChecker {
262	async fn convert_result(
263		&self,
264		tx: &Transaction,
265		docs: &HnswDocs,
266		res: VecDeque<(DocId, f64)>,
267	) -> Result<VecDeque<KnnIteratorResult>, Error> {
268		if res.is_empty() {
269			return Ok(VecDeque::from([]));
270		}
271		let mut result = VecDeque::with_capacity(res.len());
272		for (doc_id, dist) in res {
273			if let Some(rid) = docs.get_thing(tx, doc_id).await? {
274				result.push_back((rid.clone().into(), dist, None));
275			}
276		}
277		Ok(result)
278	}
279}
280
281pub struct HnswCondChecker<'a> {
282	ctx: &'a Context,
283	opt: &'a Options,
284	cond: Arc<Cond>,
285	cache: HashMap<DocId, CheckerCacheEntry>,
286}
287
288impl HnswCondChecker<'_> {
289	fn convert_result(&mut self, res: VecDeque<(DocId, f64)>) -> VecDeque<KnnIteratorResult> {
290		CheckerCacheEntry::convert_result(res, &mut self.cache)
291	}
292
293	async fn check_any_truthy(
294		&mut self,
295		tx: &Transaction,
296		stk: &mut Stk,
297		docs: &HnswDocs,
298		doc_ids: Ids64,
299	) -> Result<bool, Error> {
300		let mut res = false;
301		for doc_id in doc_ids.iter() {
302			if match self.cache.entry(doc_id) {
303				Entry::Occupied(e) => e.get().truthy,
304				Entry::Vacant(e) => {
305					let rid = docs.get_thing(tx, doc_id).await?;
306					let ent =
307						CheckerCacheEntry::build(stk, self.ctx, self.opt, rid, self.cond.as_ref())
308							.await?;
309					let truthy = ent.truthy;
310					e.insert(ent);
311					truthy
312				}
313			} {
314				res = true;
315			}
316		}
317		Ok(res)
318	}
319
320	fn expire(&mut self, doc_id: DocId) {
321		self.cache.remove(&doc_id);
322	}
323
324	fn expires(&mut self, doc_ids: Ids64) {
325		for doc_id in doc_ids.iter() {
326			self.expire(doc_id);
327		}
328	}
329}