datafusion_functions/core/
arrow_cast.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`ArrowCastFunc`]: Implementation of the `arrow_cast` function
19
20use arrow::datatypes::DataType;
21use arrow::error::ArrowError;
22use datafusion_common::{
23    arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue,
24};
25use datafusion_common::{
26    exec_datafusion_err, utils::take_function_args, DataFusionError,
27};
28use std::any::Any;
29
30use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
31use datafusion_expr::{
32    ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs,
33    ScalarUDFImpl, Signature, Volatility,
34};
35use datafusion_macros::user_doc;
36
/// Implements casting to arbitrary Arrow types (rather than SQL types).
///
/// Note that the `arrow_cast` function is somewhat special in that its
/// *return type* depends only on the *value* of its second argument (not on
/// that argument's type). The second argument must therefore be a constant
/// string naming an Arrow [`DataType`].
///
/// The function is never evaluated directly: during expression
/// simplification a call is rewritten into an ordinary [`Expr::Cast`] (or
/// into its first argument alone, when that argument already has the
/// requested type). It is thus executed by the same underlying Arrow `cast`
/// kernel as normal SQL casts.
///
/// For example, to cast to `int` using SQL (which is then mapped to the Arrow
/// type `Int32`):
///
/// ```sql
/// select cast(column_x as int) ...
/// ```
///
/// Use the `arrow_cast` function to cast to a specific Arrow type instead.
///
/// For example:
/// ```sql
/// select arrow_cast(column_x, 'Float64')
/// ```
#[user_doc(
    doc_section(label = "Other Functions"),
    description = "Casts a value to a specific Arrow data type.",
    syntax_example = "arrow_cast(expression, datatype)",
    sql_example = r#"```sql
> select arrow_cast(-5, 'Int8') as a,
  arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b,
  arrow_cast('bar', 'LargeUtf8') as c,
  arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d
  ;
+----+-----+-----+---------------------------+
| a  | b   | c   | d                         |
+----+-----+-----+---------------------------+
| -5 | foo | bar | 2023-01-02T12:53:02+08:00 |
+----+-----+-----+---------------------------+
```"#,
    argument(
        name = "expression",
        description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators."
    ),
    argument(
        name = "datatype",
        description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]"
    )
)]
#[derive(Debug)]
pub struct ArrowCastFunc {
    /// Exactly two arguments of any type, immutable volatility
    /// (built in [`ArrowCastFunc::new`]).
    signature: Signature,
}
87
88impl Default for ArrowCastFunc {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94impl ArrowCastFunc {
95    pub fn new() -> Self {
96        Self {
97            signature: Signature::any(2, Volatility::Immutable),
98        }
99    }
100}
101
102impl ScalarUDFImpl for ArrowCastFunc {
103    fn as_any(&self) -> &dyn Any {
104        self
105    }
106
107    fn name(&self) -> &str {
108        "arrow_cast"
109    }
110
111    fn signature(&self) -> &Signature {
112        &self.signature
113    }
114
115    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
116        internal_err!("return_type_from_args should be called instead")
117    }
118
119    fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
120        let nullable = args.nullables.iter().any(|&nullable| nullable);
121
122        let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?;
123
124        type_arg
125            .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
126            .map_or_else(
127                || {
128                    exec_err!(
129                        "{} requires its second argument to be a non-empty constant string",
130                        self.name()
131                    )
132                },
133                |casted_type| match casted_type.parse::<DataType>() {
134                    Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)),
135                    Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
136                    Err(e) => Err(arrow_datafusion_err!(e)),
137                },
138            )
139    }
140
141    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
142        internal_err!("arrow_cast should have been simplified to cast")
143    }
144
145    fn simplify(
146        &self,
147        mut args: Vec<Expr>,
148        info: &dyn SimplifyInfo,
149    ) -> Result<ExprSimplifyResult> {
150        // convert this into a real cast
151        let target_type = data_type_from_args(&args)?;
152        // remove second (type) argument
153        args.pop().unwrap();
154        let arg = args.pop().unwrap();
155
156        let source_type = info.get_data_type(&arg)?;
157        let new_expr = if source_type == target_type {
158            // the argument's data type is already the correct type
159            arg
160        } else {
161            // Use an actual cast to get the correct type
162            Expr::Cast(datafusion_expr::Cast {
163                expr: Box::new(arg),
164                data_type: target_type,
165            })
166        };
167        // return the newly written argument to DataFusion
168        Ok(ExprSimplifyResult::Simplified(new_expr))
169    }
170
171    fn documentation(&self) -> Option<&Documentation> {
172        self.doc()
173    }
174}
175
176/// Returns the requested type from the arguments
177fn data_type_from_args(args: &[Expr]) -> Result<DataType> {
178    let [_, type_arg] = take_function_args("arrow_cast", args)?;
179
180    let Expr::Literal(ScalarValue::Utf8(Some(val))) = type_arg else {
181        return exec_err!(
182            "arrow_cast requires its second argument to be a constant string, got {:?}",
183            type_arg
184        );
185    };
186
187    val.parse().map_err(|e| match e {
188        // If the data type cannot be parsed, return a Plan error to signal an
189        // error in the input rather than a more general ArrowError
190        ArrowError::ParseError(e) => exec_datafusion_err!("{e}"),
191        e => arrow_datafusion_err!(e),
192    })
193}