// datafusion_functions/core/arrow_cast.rs
use arrow::datatypes::DataType;
use arrow::error::ArrowError;
use datafusion_common::{
arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue,
};
use datafusion_common::{exec_datafusion_err, DataFusionError};
use std::any::Any;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl,
Signature, Volatility,
};
use datafusion_macros::user_doc;
/// Scalar UDF implementing `arrow_cast(expression, datatype)`, which casts a
/// value to the Arrow data type named by a constant string.
///
/// The return type depends on the *value* of the second argument, so it is
/// resolved in `return_type_from_args` rather than `return_type`.
/// `simplify` rewrites every call into a plain cast expression (or into the
/// first argument itself when it already has the requested type), so
/// `invoke_batch` is not expected to run and returns an internal error.
#[user_doc(
doc_section(label = "Other Functions"),
description = "Casts a value to a specific Arrow data type.",
syntax_example = "arrow_cast(expression, datatype)",
sql_example = r#"```sql
> select arrow_cast(-5, 'Int8') as a,
arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b,
arrow_cast('bar', 'LargeUtf8') as c,
arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d
;
+----+-----+-----+---------------------------+
| a | b | c | d |
+----+-----+-----+---------------------------+
| -5 | foo | bar | 2023-01-02T12:53:02+08:00 |
+----+-----+-----+---------------------------+
```"#,
argument(
name = "expression",
description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators."
),
argument(
name = "datatype",
description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]"
)
)]
#[derive(Debug)]
pub struct ArrowCastFunc {
/// Signature returned by `signature()`; `new()` sets it to any two
/// arguments with immutable volatility.
signature: Signature,
}
impl Default for ArrowCastFunc {
fn default() -> Self {
Self::new()
}
}
impl ArrowCastFunc {
    /// Creates the `arrow_cast` function.
    ///
    /// The signature accepts any two arguments and is marked immutable.
    pub fn new() -> Self {
        let signature = Signature::any(2, Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for ArrowCastFunc {
    /// Exposes `self` as `&dyn Any` so callers can downcast.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name the function is registered under.
    fn name(&self) -> &str {
        "arrow_cast"
    }

    /// Signature stored on this instance.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always fails: the return type cannot be derived from argument types
    /// alone, so `return_type_from_args` must be used.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_type_from_args should be called instead")
    }

    /// Derives the return type from the constant string passed as the second
    /// argument.
    ///
    /// The result is nullable when any argument is nullable.
    ///
    /// # Errors
    /// Returns an execution error when the second argument is not a
    /// non-empty constant string, or when the string does not parse as an
    /// Arrow [`DataType`]. Other Arrow errors are forwarded as Arrow errors.
    fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
        let nullable = args.nullables.contains(&true);
        debug_assert_eq!(args.scalar_arguments.len(), 2);

        // Only a known, non-empty string constant can name the target type.
        let type_name = args.scalar_arguments[1]
            .and_then(|sv| sv.try_as_str().flatten())
            .filter(|s| !s.is_empty());
        let Some(type_name) = type_name else {
            return exec_err!(
                "{} requires its second argument to be a non-empty constant string",
                self.name()
            );
        };

        type_name
            .parse::<DataType>()
            .map(|data_type| ReturnInfo::new(data_type, nullable))
            .map_err(|err| match err {
                // Parse failures surface as plain execution errors.
                ArrowError::ParseError(e) => exec_datafusion_err!("{e}"),
                e => arrow_datafusion_err!(e),
            })
    }

    /// Never expected to run: `simplify` replaces every call before execution.
    fn invoke_batch(
        &self,
        _args: &[ColumnarValue],
        _number_rows: usize,
    ) -> Result<ColumnarValue> {
        internal_err!("arrow_cast should have been simplified to cast")
    }

    /// Rewrites `arrow_cast(expr, 'Type')` into a cast of `expr` to `Type`,
    /// or into `expr` itself when it already has that type.
    fn simplify(
        &self,
        mut args: Vec<Expr>,
        info: &dyn SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        let target_type = data_type_from_args(&args)?;

        // `data_type_from_args` verified there are exactly two arguments:
        // drop the type-name literal and take ownership of the value.
        args.truncate(1);
        let value_expr = args.pop().unwrap();

        if info.get_data_type(&value_expr)? == target_type {
            // Already the requested type: no cast needed.
            return Ok(ExprSimplifyResult::Simplified(value_expr));
        }

        let cast = datafusion_expr::Cast {
            expr: Box::new(value_expr),
            data_type: target_type,
        };
        Ok(ExprSimplifyResult::Simplified(Expr::Cast(cast)))
    }

    /// Documentation produced by the `user_doc` attribute on the struct.
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
/// Resolves the target [`DataType`] named by the second argument of an
/// `arrow_cast` call.
///
/// `args` must be exactly `[expression, datatype]`, where `datatype` is a
/// non-null string literal. Any string-typed literal recognised by
/// [`ScalarValue::try_as_str`] is accepted (for example `LargeUtf8`), so this
/// agrees with the check `return_type_from_args` performs during planning.
/// Previously only `Utf8` literals were accepted here, so other string
/// literals passed planning and then failed in `simplify`.
///
/// # Errors
/// Returns an execution error when `args` does not hold exactly two
/// expressions, when the second is not a non-null string literal, or when
/// the string does not parse as an Arrow data type. Other Arrow errors are
/// forwarded as Arrow errors.
fn data_type_from_args(args: &[Expr]) -> Result<DataType> {
    // The slice pattern checks the arity and binds the type argument at
    // once, so no indexing (and no panic path) is needed.
    let [_, type_arg] = args else {
        return exec_err!("arrow_cast needs 2 arguments, {} provided", args.len());
    };

    // `try_as_str` yields `Some(Some(_))` only for non-null string scalars.
    let type_name = match type_arg {
        Expr::Literal(value) => value.try_as_str().flatten(),
        _ => None,
    };
    let Some(type_name) = type_name else {
        return exec_err!(
            "arrow_cast requires its second argument to be a constant string, got {:?}",
            type_arg
        );
    };

    type_name.parse::<DataType>().map_err(|e| match e {
        // Parse failures surface as plain execution errors.
        ArrowError::ParseError(e) => exec_datafusion_err!("{e}"),
        e => arrow_datafusion_err!(e),
    })
}