datafusion_functions/core/
arrow_cast.rs1use arrow::datatypes::DataType;
21use arrow::error::ArrowError;
22use datafusion_common::{
23 arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue,
24};
25use datafusion_common::{
26 exec_datafusion_err, utils::take_function_args, DataFusionError,
27};
28use std::any::Any;
29
30use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
31use datafusion_expr::{
32 ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs,
33 ScalarUDFImpl, Signature, Volatility,
34};
35use datafusion_macros::user_doc;
36
37#[user_doc(
59 doc_section(label = "Other Functions"),
60 description = "Casts a value to a specific Arrow data type.",
61 syntax_example = "arrow_cast(expression, datatype)",
62 sql_example = r#"```sql
63> select arrow_cast(-5, 'Int8') as a,
64 arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b,
65 arrow_cast('bar', 'LargeUtf8') as c,
66 arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d
67 ;
68+----+-----+-----+---------------------------+
69| a | b | c | d |
70+----+-----+-----+---------------------------+
71| -5 | foo | bar | 2023-01-02T12:53:02+08:00 |
72+----+-----+-----+---------------------------+
73```"#,
74 argument(
75 name = "expression",
76 description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators."
77 ),
78 argument(
79 name = "datatype",
80 description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]"
81 )
82)]
/// Scalar UDF backing the SQL function `arrow_cast(expression, datatype)`.
///
/// The target type is supplied as a string literal and parsed into an Arrow
/// [`DataType`]; see the `ScalarUDFImpl` implementation for how the call is
/// planned and rewritten.
#[derive(Debug)]
pub struct ArrowCastFunc {
    // Argument signature reported through `ScalarUDFImpl::signature`.
    signature: Signature,
}
87
88impl Default for ArrowCastFunc {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl ArrowCastFunc {
95 pub fn new() -> Self {
96 Self {
97 signature: Signature::any(2, Volatility::Immutable),
98 }
99 }
100}
101
102impl ScalarUDFImpl for ArrowCastFunc {
103 fn as_any(&self) -> &dyn Any {
104 self
105 }
106
107 fn name(&self) -> &str {
108 "arrow_cast"
109 }
110
111 fn signature(&self) -> &Signature {
112 &self.signature
113 }
114
115 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
116 internal_err!("return_type_from_args should be called instead")
117 }
118
119 fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
120 let nullable = args.nullables.iter().any(|&nullable| nullable);
121
122 let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?;
123
124 type_arg
125 .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
126 .map_or_else(
127 || {
128 exec_err!(
129 "{} requires its second argument to be a non-empty constant string",
130 self.name()
131 )
132 },
133 |casted_type| match casted_type.parse::<DataType>() {
134 Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)),
135 Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
136 Err(e) => Err(arrow_datafusion_err!(e)),
137 },
138 )
139 }
140
141 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
142 internal_err!("arrow_cast should have been simplified to cast")
143 }
144
145 fn simplify(
146 &self,
147 mut args: Vec<Expr>,
148 info: &dyn SimplifyInfo,
149 ) -> Result<ExprSimplifyResult> {
150 let target_type = data_type_from_args(&args)?;
152 args.pop().unwrap();
154 let arg = args.pop().unwrap();
155
156 let source_type = info.get_data_type(&arg)?;
157 let new_expr = if source_type == target_type {
158 arg
160 } else {
161 Expr::Cast(datafusion_expr::Cast {
163 expr: Box::new(arg),
164 data_type: target_type,
165 })
166 };
167 Ok(ExprSimplifyResult::Simplified(new_expr))
169 }
170
171 fn documentation(&self) -> Option<&Documentation> {
172 self.doc()
173 }
174}
175
176fn data_type_from_args(args: &[Expr]) -> Result<DataType> {
178 let [_, type_arg] = take_function_args("arrow_cast", args)?;
179
180 let Expr::Literal(ScalarValue::Utf8(Some(val))) = type_arg else {
181 return exec_err!(
182 "arrow_cast requires its second argument to be a constant string, got {:?}",
183 type_arg
184 );
185 };
186
187 val.parse().map_err(|e| match e {
188 ArrowError::ParseError(e) => exec_datafusion_err!("{e}"),
191 e => arrow_datafusion_err!(e),
192 })
193}