surrealdb_core/fnc/vector.rs

1use crate::err::Error;
2use crate::fnc::util::math::vector::{
3	Add, Angle, CrossProduct, Divide, DotProduct, Magnitude, Multiply, Normalize, Project, Scale,
4	Subtract,
5};
6use crate::sql::{Number, Value};
7
8pub fn add((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
9	Ok(a.add(&b)?.into())
10}
11
12pub fn angle((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
13	Ok(a.angle(&b)?.into())
14}
15
16pub fn divide((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
17	Ok(a.divide(&b)?.into())
18}
19
20pub fn cross((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
21	Ok(a.cross(&b)?.into())
22}
23
24pub fn dot((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
25	Ok(a.dot(&b)?.into())
26}
27
28pub fn magnitude((a,): (Vec<Number>,)) -> Result<Value, Error> {
29	Ok(a.magnitude().into())
30}
31
32pub fn multiply((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
33	Ok(a.multiply(&b)?.into())
34}
35
36pub fn normalize((a,): (Vec<Number>,)) -> Result<Value, Error> {
37	Ok(a.normalize().into())
38}
39
40pub fn project((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
41	Ok(a.project(&b)?.into())
42}
43
44pub fn subtract((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
45	Ok(a.subtract(&b)?.into())
46}
47
48pub fn scale((a, b): (Vec<Number>, Number)) -> Result<Value, Error> {
49	Ok(a.scale(&b)?.into())
50}
51
52pub mod distance {
53	use crate::ctx::Context;
54	use crate::doc::CursorDoc;
55	use crate::err::Error;
56	use crate::fnc::get_execution_context;
57	use crate::fnc::util::math::vector::{
58		ChebyshevDistance, EuclideanDistance, HammingDistance, ManhattanDistance, MinkowskiDistance,
59	};
60	use crate::idx::planner::IterationStage;
61	use crate::sql::{Number, Value};
62
63	pub fn chebyshev((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
64		Ok(a.chebyshev_distance(&b)?.into())
65	}
66
67	pub fn euclidean((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
68		Ok(a.euclidean_distance(&b)?.into())
69	}
70
71	pub fn hamming((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
72		Ok(a.hamming_distance(&b)?.into())
73	}
74
75	pub fn knn(
76		(ctx, doc): (&Context, Option<&CursorDoc>),
77		(knn_ref,): (Option<Value>,),
78	) -> Result<Value, Error> {
79		if let Some((_exe, doc, thg)) = get_execution_context(ctx, doc) {
80			if let Some(ir) = &doc.ir {
81				if let Some(d) = ir.dist() {
82					return Ok(d.into());
83				}
84			}
85			if let Some(IterationStage::Iterate(Some(results))) = ctx.get_iteration_stage() {
86				let n = if let Some(Value::Number(n)) = knn_ref {
87					n.as_usize()
88				} else {
89					0
90				};
91				if let Some(d) = results.get_dist(n, thg) {
92					return Ok(d.into());
93				}
94			}
95		}
96		Ok(Value::None)
97	}
98
99	pub fn mahalanobis((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
100		Err(Error::FeatureNotYetImplemented {
101			feature: "vector::distance::mahalanobis() function".to_string(),
102		})
103	}
104
105	pub fn manhattan((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
106		Ok(a.manhattan_distance(&b)?.into())
107	}
108
109	pub fn minkowski((a, b, o): (Vec<Number>, Vec<Number>, Number)) -> Result<Value, Error> {
110		Ok(a.minkowski_distance(&b, &o)?.into())
111	}
112}
113
114pub mod similarity {
115
116	use crate::err::Error;
117	use crate::fnc::util::math::vector::{CosineSimilarity, JaccardSimilarity, PearsonSimilarity};
118	use crate::sql::{Number, Value};
119
120	pub fn cosine((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
121		Ok(a.cosine_similarity(&b)?.into())
122	}
123
124	pub fn jaccard((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
125		Ok(a.jaccard_similarity(&b)?.into())
126	}
127
128	pub fn pearson((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
129		Ok(a.pearson_similarity(&b)?.into())
130	}
131
132	pub fn spearman((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
133		Err(Error::FeatureNotYetImplemented {
134			feature: "vector::similarity::spearman() function".to_string(),
135		})
136	}
137}
138
139impl TryFrom<&Value> for Vec<Number> {
140	type Error = Error;
141
142	fn try_from(val: &Value) -> Result<Self, Self::Error> {
143		if let Value::Array(a) = val {
144			a.iter()
145				.map(|v| v.try_into())
146				.collect::<Result<Self, Error>>()
147				.map_err(|e| Error::InvalidVectorValue(e.to_string()))
148		} else {
149			Err(Error::InvalidVectorValue(val.to_string()))
150		}
151	}
152}
153
154impl TryFrom<Value> for Vec<Number> {
155	type Error = Error;
156
157	fn try_from(val: Value) -> Result<Self, Self::Error> {
158		if let Value::Array(a) = val {
159			a.into_iter()
160				.map(Value::try_into)
161				.collect::<Result<Self, Error>>()
162				.map_err(|e| Error::InvalidVectorValue(e.to_string()))
163		} else {
164			Err(Error::InvalidVectorValue(val.to_string()))
165		}
166	}
167}
168
#[cfg(test)]
mod tests {
	use super::*;
	use crate::sql::Number;
	use rust_decimal::Decimal;

	/// The `[1, 2, 3, 4]` integer vector used as input by every `scale` test.
	fn sample_vector() -> Vec<Number> {
		vec![1, 2, 3, 4].into_iter().map(Number::Int).collect()
	}

	#[test]
	fn vector_scale_int() {
		// `scale` takes its input by value and the vector is not reused, so no clone is needed.
		let result: Result<Value, Error> = scale((sample_vector(), Number::Int(2)));
		let expected_output: Vec<Number> = vec![2, 4, 6, 8].into_iter().map(Number::Int).collect();
		assert_eq!(result.unwrap(), expected_output.into());
	}

	#[test]
	fn vector_scale_float() {
		let result: Result<Value, Error> = scale((sample_vector(), Number::Float(1.51)));
		// NOTE(review): exact f64 equality — relies on `scale` producing precisely these values.
		let expected_output: Vec<Number> =
			vec![1.51, 3.02, 4.53, 6.04].into_iter().map(Number::Float).collect();
		assert_eq!(result.unwrap(), expected_output.into());
	}

	#[test]
	fn vector_scale_decimal() {
		let result: Result<Value, Error> =
			scale((sample_vector(), Number::Decimal(Decimal::new(3141, 3))));
		let expected_output: Vec<Number> = vec![
			Number::Decimal(Decimal::new(3141, 3)),  // 3.141 * 1
			Number::Decimal(Decimal::new(6282, 3)),  // 3.141 * 2
			Number::Decimal(Decimal::new(9423, 3)),  // 3.141 * 3
			Number::Decimal(Decimal::new(12564, 3)), // 3.141 * 4
		];
		assert_eq!(result.unwrap(), expected_output.into());
	}
}