1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! [`ArrowCastFunc`]: Implementation of the `arrow_cast` function
use std::any::Any;
use arrow::datatypes::DataType;
use datafusion_common::{
arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError,
ExprSchema, Result, ScalarValue,
};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility};
/// Implements casting to arbitrary arrow types (rather than SQL types)
///
/// Note that the `arrow_cast` function is somewhat special in that its
/// return *type* depends only on the *value* of its second argument (not on
/// the type of that argument). That second argument must be a constant
/// string naming an arrow [`DataType`].
///
/// Calls to `arrow_cast` are rewritten during expression simplification into
/// a regular [`Expr::Cast`], so they use the same underlying arrow `cast`
/// kernel as normal SQL casts. If the input already has the requested type,
/// the call is replaced by its first argument unchanged.
///
/// For example to cast to `int` using SQL (which is then mapped to the arrow
/// type `Int32`)
///
/// ```sql
/// select cast(column_x as int) ...
/// ```
///
/// Use the `arrow_cast` function to cast to a specific arrow type
///
/// For example
/// ```sql
/// select arrow_cast(column_x, 'Float64')
/// ```
#[derive(Debug)]
pub struct ArrowCastFunc {
    /// Accepts exactly two arguments of any type; the function is immutable.
    signature: Signature,
}
impl Default for ArrowCastFunc {
fn default() -> Self {
Self::new()
}
}
impl ArrowCastFunc {
    /// Creates the `arrow_cast` UDF: it takes exactly two arguments of any
    /// type and is declared [`Volatility::Immutable`].
    pub fn new() -> Self {
        let signature = Signature::any(2, Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for ArrowCastFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "arrow_cast"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always fails: the output type is decided by the *value* of the second
    /// argument, so the type must be obtained via `return_type_from_exprs`.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("arrow_cast should return type from exprs")
    }

    /// Reads the target type from the constant string passed as the second
    /// argument; the schema and the argument types are not consulted.
    fn return_type_from_exprs(
        &self,
        args: &[Expr],
        _schema: &dyn ExprSchema,
        _arg_types: &[DataType],
    ) -> Result<DataType> {
        data_type_from_args(args)
    }

    /// Always fails: `simplify` is expected to have rewritten every call into
    /// a regular cast before execution.
    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
        internal_err!("arrow_cast should have been simplified to cast")
    }

    /// Rewrites `arrow_cast(expr, 'Type')` into `CAST(expr AS Type)`, or into
    /// plain `expr` when `expr` already has the requested type.
    fn simplify(
        &self,
        args: Vec<Expr>,
        info: &dyn SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        let target_type = data_type_from_args(&args)?;

        // `data_type_from_args` has verified there are exactly two arguments:
        // keep the value being cast and discard the type-name literal.
        let arg = args.into_iter().next().unwrap();

        let rewritten = match info.get_data_type(&arg)? {
            // Already the requested type: no cast needed.
            source_type if source_type == target_type => arg,
            // Otherwise wrap the value in an ordinary cast expression.
            _ => Expr::Cast(datafusion_expr::Cast {
                expr: Box::new(arg),
                data_type: target_type,
            }),
        };

        Ok(ExprSimplifyResult::Simplified(rewritten))
    }
}
/// Returns the arrow [`DataType`] requested by an `arrow_cast` call.
///
/// `args` must hold exactly two expressions. The second must be a non-null
/// string literal (`Utf8` or `LargeUtf8`) holding a data type name such as
/// `'Float64'`, which is parsed via `DataType`'s `FromStr` implementation.
///
/// # Errors
///
/// * A plan error if there are not exactly two arguments.
/// * A plan error if the second argument is not a constant string literal.
/// * A plan error if parsing fails with `ArrowError::ParseError`.
/// * An arrow error for any other parse failure.
fn data_type_from_args(args: &[Expr]) -> Result<DataType> {
    if args.len() != 2 {
        return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len());
    }
    // Accept both string literal flavours so that a `LargeUtf8` literal
    // (e.g. one constructed through the expression API) is not rejected.
    let Expr::Literal(ScalarValue::Utf8(Some(val)) | ScalarValue::LargeUtf8(Some(val))) =
        &args[1]
    else {
        return plan_err!(
            "arrow_cast requires its second argument to be a constant string, got {:?}",
            &args[1]
        );
    };
    val.parse().map_err(|e| match e {
        // If the data type cannot be parsed, return a Plan error to signal an
        // error in the input rather than a more general ArrowError
        arrow::error::ArrowError::ParseError(e) => plan_datafusion_err!("{e}"),
        e => arrow_datafusion_err!(e),
    })
}