1use crate::ast::*;
18use crate::entities::SchemaType;
19use crate::evaluator;
20use std::any::Any;
21use std::collections::{BTreeSet, HashMap};
22use std::fmt::Debug;
23use std::panic::{RefUnwindSafe, UnwindSafe};
24use std::sync::Arc;
25
26pub struct Extension {
33 name: Name,
35 functions: HashMap<Name, ExtensionFunction>,
37 types_with_operator_overloading: BTreeSet<Name>,
39}
40
41impl Extension {
42 pub fn new(
44 name: Name,
45 functions: impl IntoIterator<Item = ExtensionFunction>,
46 types_with_operator_overloading: impl IntoIterator<Item = Name>,
47 ) -> Self {
48 Self {
49 name,
50 functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(),
51 types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(),
52 }
53 }
54
55 pub fn name(&self) -> &Name {
57 &self.name
58 }
59
60 pub fn get_func(&self, name: &Name) -> Option<&ExtensionFunction> {
63 self.functions.get(name)
64 }
65
66 pub fn funcs(&self) -> impl Iterator<Item = &ExtensionFunction> {
68 self.functions.values()
69 }
70
71 pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
74 self.funcs().flat_map(|func| func.ext_types())
75 }
76
77 pub fn types_with_operator_overloading(&self) -> impl Iterator<Item = &Name> + '_ {
79 self.types_with_operator_overloading.iter()
80 }
81}
82
83impl std::fmt::Debug for Extension {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(f, "<extension {}>", self.name())
86 }
87}
88
89#[derive(Debug, Clone)]
91pub enum ExtensionOutputValue {
92 Known(Value),
94 Unknown(Unknown),
96}
97
98impl<T> From<T> for ExtensionOutputValue
99where
100 T: Into<Value>,
101{
102 fn from(v: T) -> Self {
103 ExtensionOutputValue::Known(v.into())
104 }
105}
106
/// The syntactic form in which an extension function is invoked.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum CallStyle {
    /// Invoked with function-call syntax (presumably `name(args...)`).
    FunctionStyle,
    /// Invoked with method-call syntax (presumably `receiver.name(args...)`).
    MethodStyle,
}
116
/// Expands to the boxed, thread-safe closure type used for extension function
/// implementations taking the given argument types and returning
/// `evaluator::Result<ExtensionOutputValue>`.
macro_rules! extension_function_object {
    ( $( $arg:ty ),* ) => {
        Box<dyn Fn($( $arg ),*) -> evaluator::Result<ExtensionOutputValue> + Sync + Send + 'static>
    }
}
124
125pub type ExtensionFunctionObject = extension_function_object!(&[Value]);
127pub type NullaryExtensionFunctionObject = extension_function_object!();
129pub type UnaryExtensionFunctionObject = extension_function_object!(&Value);
131pub type BinaryExtensionFunctionObject = extension_function_object!(&Value, &Value);
133pub type TernaryExtensionFunctionObject = extension_function_object!(&Value, &Value, &Value);
135
136pub struct ExtensionFunction {
139 name: Name,
141 style: CallStyle,
143 func: ExtensionFunctionObject,
146 return_type: Option<SchemaType>,
154 arg_types: Vec<SchemaType>,
156}
157
158impl ExtensionFunction {
159 fn new(
161 name: Name,
162 style: CallStyle,
163 func: ExtensionFunctionObject,
164 return_type: Option<SchemaType>,
165 arg_types: Vec<SchemaType>,
166 ) -> Self {
167 Self {
168 name,
169 func,
170 style,
171 return_type,
172 arg_types,
173 }
174 }
175
176 pub fn nullary(
178 name: Name,
179 style: CallStyle,
180 func: NullaryExtensionFunctionObject,
181 return_type: SchemaType,
182 ) -> Self {
183 Self::new(
184 name.clone(),
185 style,
186 Box::new(move |args: &[Value]| {
187 if args.is_empty() {
188 func()
189 } else {
190 Err(evaluator::EvaluationError::wrong_num_arguments(
191 name.clone(),
192 0,
193 args.len(),
194 None, ))
196 }
197 }),
198 Some(return_type),
199 vec![],
200 )
201 }
202
203 pub fn partial_eval_unknown(
206 name: Name,
207 style: CallStyle,
208 func: UnaryExtensionFunctionObject,
209 arg_type: SchemaType,
210 ) -> Self {
211 Self::new(
212 name.clone(),
213 style,
214 Box::new(move |args: &[Value]| match args.first() {
215 Some(arg) => func(arg),
216 None => Err(evaluator::EvaluationError::wrong_num_arguments(
217 name.clone(),
218 1,
219 args.len(),
220 None, )),
222 }),
223 None,
224 vec![arg_type],
225 )
226 }
227
228 #[allow(clippy::type_complexity)]
230 pub fn unary(
231 name: Name,
232 style: CallStyle,
233 func: UnaryExtensionFunctionObject,
234 return_type: SchemaType,
235 arg_type: SchemaType,
236 ) -> Self {
237 Self::new(
238 name.clone(),
239 style,
240 Box::new(move |args: &[Value]| match &args {
241 &[arg] => func(arg),
242 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
243 name.clone(),
244 1,
245 args.len(),
246 None, )),
248 }),
249 Some(return_type),
250 vec![arg_type],
251 )
252 }
253
254 #[allow(clippy::type_complexity)]
256 pub fn binary(
257 name: Name,
258 style: CallStyle,
259 func: BinaryExtensionFunctionObject,
260 return_type: SchemaType,
261 arg_types: (SchemaType, SchemaType),
262 ) -> Self {
263 Self::new(
264 name.clone(),
265 style,
266 Box::new(move |args: &[Value]| match &args {
267 &[first, second] => func(first, second),
268 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
269 name.clone(),
270 2,
271 args.len(),
272 None, )),
274 }),
275 Some(return_type),
276 vec![arg_types.0, arg_types.1],
277 )
278 }
279
280 #[allow(clippy::type_complexity)]
282 pub fn ternary(
283 name: Name,
284 style: CallStyle,
285 func: TernaryExtensionFunctionObject,
286 return_type: SchemaType,
287 arg_types: (SchemaType, SchemaType, SchemaType),
288 ) -> Self {
289 Self::new(
290 name.clone(),
291 style,
292 Box::new(move |args: &[Value]| match &args {
293 &[first, second, third] => func(first, second, third),
294 _ => Err(evaluator::EvaluationError::wrong_num_arguments(
295 name.clone(),
296 3,
297 args.len(),
298 None, )),
300 }),
301 Some(return_type),
302 vec![arg_types.0, arg_types.1, arg_types.2],
303 )
304 }
305
306 pub fn name(&self) -> &Name {
308 &self.name
309 }
310
311 pub fn style(&self) -> CallStyle {
313 self.style
314 }
315
316 pub fn return_type(&self) -> Option<&SchemaType> {
320 self.return_type.as_ref()
321 }
322
323 pub fn arg_types(&self) -> &[SchemaType] {
325 &self.arg_types
326 }
327
328 pub fn is_constructor(&self) -> bool {
333 matches!(self.return_type(), Some(SchemaType::Extension { .. }))
335 && !self.arg_types().iter().any(|ty| matches!(ty, SchemaType::Extension { .. }))
337 }
338
339 pub fn call(&self, args: &[Value]) -> evaluator::Result<PartialValue> {
341 match (self.func)(args)? {
342 ExtensionOutputValue::Known(v) => Ok(PartialValue::Value(v)),
343 ExtensionOutputValue::Unknown(u) => Ok(PartialValue::Residual(Expr::unknown(u))),
344 }
345 }
346
347 pub fn ext_types(&self) -> impl Iterator<Item = &Name> + '_ {
350 self.return_type
351 .iter()
352 .flat_map(|ret_ty| ret_ty.contained_ext_types())
353 }
354}
355
356impl std::fmt::Debug for ExtensionFunction {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 write!(f, "<extension function {}>", self.name())
359 }
360}
361
362pub trait ExtensionValue: Debug + Send + Sync + UnwindSafe + RefUnwindSafe {
368 fn typename(&self) -> Name;
373
374 fn supports_operator_overloading(&self) -> bool;
376}
377
378impl<V: ExtensionValue> StaticallyTyped for V {
379 fn type_of(&self) -> Type {
380 Type::Extension {
381 name: self.typename(),
382 }
383 }
384}
385
386#[derive(Debug, Clone)]
387pub struct RepresentableExtensionValue {
394 pub(crate) func: Name,
395 pub(crate) args: Vec<RestrictedExpr>,
396 pub(crate) value: Arc<dyn InternalExtensionValue>,
397}
398
399impl RepresentableExtensionValue {
400 pub fn new(
402 value: Arc<dyn InternalExtensionValue + Send + Sync>,
403 func: Name,
404 args: Vec<RestrictedExpr>,
405 ) -> Self {
406 Self { value, func, args }
407 }
408
409 pub fn value(&self) -> &(dyn InternalExtensionValue) {
411 self.value.as_ref()
412 }
413
414 pub fn typename(&self) -> Name {
416 self.value.typename()
417 }
418
419 pub(crate) fn supports_operator_overloading(&self) -> bool {
421 self.value.supports_operator_overloading()
422 }
423}
424
425impl From<RepresentableExtensionValue> for RestrictedExpr {
426 fn from(val: RepresentableExtensionValue) -> Self {
427 RestrictedExpr::call_extension_fn(val.func, val.args)
428 }
429}
430
431impl StaticallyTyped for RepresentableExtensionValue {
432 fn type_of(&self) -> Type {
433 self.value.type_of()
434 }
435}
436
437impl PartialEq for RepresentableExtensionValue {
438 fn eq(&self, other: &Self) -> bool {
439 self.value.as_ref() == other.value.as_ref()
441 }
442}
443
444impl Eq for RepresentableExtensionValue {}
445
446impl PartialOrd for RepresentableExtensionValue {
447 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
448 Some(self.cmp(other))
449 }
450}
451
452impl Ord for RepresentableExtensionValue {
453 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
454 self.value.cmp(&other.value)
455 }
456}
457
458pub trait InternalExtensionValue: ExtensionValue {
471 fn as_any(&self) -> &dyn Any;
473 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool;
476 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering;
479}
480
481impl<V: 'static + Eq + Ord + ExtensionValue + Send + Sync + Clone> InternalExtensionValue for V {
482 fn as_any(&self) -> &dyn Any {
483 self
484 }
485
486 fn equals_extvalue(&self, other: &dyn InternalExtensionValue) -> bool {
487 other
488 .as_any()
489 .downcast_ref::<V>()
490 .map(|v| self == v)
491 .unwrap_or(false) }
493
494 fn cmp_extvalue(&self, other: &dyn InternalExtensionValue) -> std::cmp::Ordering {
495 other
496 .as_any()
497 .downcast_ref::<V>()
498 .map(|v| self.cmp(v))
499 .unwrap_or_else(|| {
500 self.typename().cmp(&other.typename())
503 })
504 }
505}
506
507impl PartialEq for dyn InternalExtensionValue {
508 fn eq(&self, other: &Self) -> bool {
509 self.equals_extvalue(other)
510 }
511}
512
513impl Eq for dyn InternalExtensionValue {}
514
515impl PartialOrd for dyn InternalExtensionValue {
516 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
517 Some(self.cmp(other))
518 }
519}
520
521impl Ord for dyn InternalExtensionValue {
522 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
523 self.cmp_extvalue(other)
524 }
525}
526
527impl StaticallyTyped for dyn InternalExtensionValue {
528 fn type_of(&self) -> Type {
529 Type::Extension {
530 name: self.typename(),
531 }
532 }
533}