use std::any::Any;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::physical_expr::down_cast_any_ref;
use crate::sort_properties::SortProperties;
use crate::PhysicalExpr;
use arrow::{
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Expr};
/// Physical expression that wraps a single constant [`ScalarValue`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Literal {
// The constant this expression stands for.
value: ScalarValue,
}
impl Literal {
pub fn new(value: ScalarValue) -> Self {
Self { value }
}
pub fn value(&self) -> &ScalarValue {
&self.value
}
}
impl std::fmt::Display for Literal {
    /// Writes the wrapped scalar using its own `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_fmt(format_args!("{}", self.value))
    }
}
impl PhysicalExpr for Literal {
    /// Exposes this expression as `&dyn Any` so callers can downcast it.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Reports the data type of the wrapped scalar; the schema is not consulted.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        let scalar_type = self.value.get_datatype();
        Ok(scalar_type)
    }

    /// Reports `true` exactly when the wrapped scalar is null; the schema is
    /// not consulted.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        let scalar_is_null = self.value.is_null();
        Ok(scalar_is_null)
    }

    /// Produces a `ColumnarValue::Scalar` holding a clone of the wrapped
    /// scalar; the batch is not read.
    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
        let scalar = self.value.clone();
        Ok(ColumnarValue::Scalar(scalar))
    }

    /// A literal has no child expressions, so this is always empty.
    fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        Vec::new()
    }

    /// Returns this same expression; the supplied children are ignored.
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let unchanged: Arc<dyn PhysicalExpr> = self;
        Ok(unchanged)
    }

    /// Feeds this literal's derived `Hash` implementation into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher) {
        // `&mut dyn Hasher` is itself a `Hasher`, so rebinding it mutably lets
        // the generic `Hash::hash` accept the trait object.
        let mut hasher = state;
        Hash::hash(self, &mut hasher);
    }

    /// Always reports `SortProperties::Singleton`, whatever the children say.
    fn get_ordering(&self, _children: &[SortProperties]) -> SortProperties {
        SortProperties::Singleton
    }
}
impl PartialEq<dyn Any> for Literal {
    /// Compares against a type-erased value: equal only when `other`
    /// downcasts to a `Literal` that equals `self`.
    fn eq(&self, other: &dyn Any) -> bool {
        match down_cast_any_ref(other).downcast_ref::<Self>() {
            Some(rhs) => self == rhs,
            // Not a `Literal` at all, so it cannot be equal.
            None => false,
        }
    }
}
/// Converts `value` into a physical [`Literal`] expression.
///
/// # Panics
///
/// Hits `unreachable!()` if `value.lit()` yields anything other than
/// `Expr::Literal`.
pub fn lit<T: datafusion_expr::Literal>(value: T) -> Arc<dyn PhysicalExpr> {
    let scalar = match value.lit() {
        Expr::Literal(scalar) => scalar,
        _ => unreachable!(),
    };
    Arc::new(Literal::new(scalar))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;
    use arrow::datatypes::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_common::Result;

    /// An `i32` literal displays as its value and evaluates to that value for
    /// every row of the batch.
    #[test]
    fn literal_i32() -> Result<()> {
        // Five-row batch with a single nullable Int32 column.
        let column = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
        let fields = vec![Field::new("a", DataType::Int32, true)];
        let batch =
            RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(column)])?;

        let literal_expr = lit(42i32);
        assert_eq!("42", format!("{literal_expr}"));

        // Expand the evaluated result to one entry per row of the batch.
        let row_count = batch.num_rows();
        let expanded = literal_expr.evaluate(&batch)?.into_array(row_count);
        let literal_array = as_int32_array(&expanded)?;

        // Only the length comes from the batch; every entry is the literal.
        assert_eq!(literal_array.len(), 5);
        (0..literal_array.len()).for_each(|idx| {
            assert_eq!(literal_array.value(idx), 42);
        });

        Ok(())
    }
}