datafusion_physical_expr/expressions/
cast.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt;
20use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::physical_expr::PhysicalExpr;
24
25use arrow::compute::{can_cast_types, CastOptions};
26use arrow::datatypes::{DataType, DataType::*, Schema};
27use arrow::record_batch::RecordBatch;
28use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
29use datafusion_common::{not_impl_err, Result};
30use datafusion_expr_common::columnar_value::ColumnarValue;
31use datafusion_expr_common::interval_arithmetic::Interval;
32use datafusion_expr_common::sort_properties::ExprProperties;
33
34const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions {
35    safe: false,
36    format_options: DEFAULT_FORMAT_OPTIONS,
37};
38
39const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
40    safe: true,
41    format_options: DEFAULT_FORMAT_OPTIONS,
42};
43
/// CAST expression: casts the result of a child expression to a specific data type.
///
/// With the default options (`safe: false`) an invalid cast is reported as a
/// runtime error when the expression is evaluated.
#[derive(Debug, Clone, Eq)]
pub struct CastExpr {
    /// The child expression whose result is cast
    pub expr: Arc<dyn PhysicalExpr>,
    /// The data type to cast to
    cast_type: DataType,
    /// Cast options handed to the cast when the expression is evaluated
    cast_options: CastOptions<'static>,
}
54
55// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
56impl PartialEq for CastExpr {
57    fn eq(&self, other: &Self) -> bool {
58        self.expr.eq(&other.expr)
59            && self.cast_type.eq(&other.cast_type)
60            && self.cast_options.eq(&other.cast_options)
61    }
62}
63
64impl Hash for CastExpr {
65    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
66        self.expr.hash(state);
67        self.cast_type.hash(state);
68        self.cast_options.hash(state);
69    }
70}
71
72impl CastExpr {
73    /// Create a new CastExpr
74    pub fn new(
75        expr: Arc<dyn PhysicalExpr>,
76        cast_type: DataType,
77        cast_options: Option<CastOptions<'static>>,
78    ) -> Self {
79        Self {
80            expr,
81            cast_type,
82            cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
83        }
84    }
85
86    /// The expression to cast
87    pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
88        &self.expr
89    }
90
91    /// The data type to cast to
92    pub fn cast_type(&self) -> &DataType {
93        &self.cast_type
94    }
95
96    /// The cast options
97    pub fn cast_options(&self) -> &CastOptions<'static> {
98        &self.cast_options
99    }
100    pub fn is_bigger_cast(&self, src: DataType) -> bool {
101        if src == self.cast_type {
102            return true;
103        }
104        matches!(
105            (src, &self.cast_type),
106            (Int8, Int16 | Int32 | Int64)
107                | (Int16, Int32 | Int64)
108                | (Int32, Int64)
109                | (UInt8, UInt16 | UInt32 | UInt64)
110                | (UInt16, UInt32 | UInt64)
111                | (UInt32, UInt64)
112                | (
113                    Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32,
114                    Float32 | Float64
115                )
116                | (Int64 | UInt64, Float64)
117                | (Utf8, LargeUtf8)
118        )
119    }
120}
121
122impl fmt::Display for CastExpr {
123    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124        write!(f, "CAST({} AS {:?})", self.expr, self.cast_type)
125    }
126}
127
128impl PhysicalExpr for CastExpr {
129    /// Return a reference to Any that can be used for downcasting
130    fn as_any(&self) -> &dyn Any {
131        self
132    }
133
134    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
135        Ok(self.cast_type.clone())
136    }
137
138    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
139        self.expr.nullable(input_schema)
140    }
141
142    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
143        let value = self.expr.evaluate(batch)?;
144        value.cast_to(&self.cast_type, Some(&self.cast_options))
145    }
146
147    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
148        vec![&self.expr]
149    }
150
151    fn with_new_children(
152        self: Arc<Self>,
153        children: Vec<Arc<dyn PhysicalExpr>>,
154    ) -> Result<Arc<dyn PhysicalExpr>> {
155        Ok(Arc::new(CastExpr::new(
156            Arc::clone(&children[0]),
157            self.cast_type.clone(),
158            Some(self.cast_options.clone()),
159        )))
160    }
161
162    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
163        // Cast current node's interval to the right type:
164        children[0].cast_to(&self.cast_type, &self.cast_options)
165    }
166
167    fn propagate_constraints(
168        &self,
169        interval: &Interval,
170        children: &[&Interval],
171    ) -> Result<Option<Vec<Interval>>> {
172        let child_interval = children[0];
173        // Get child's datatype:
174        let cast_type = child_interval.data_type();
175        Ok(Some(vec![
176            interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)?
177        ]))
178    }
179
180    /// A [`CastExpr`] preserves the ordering of its child if the cast is done
181    /// under the same datatype family.
182    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
183        let source_datatype = children[0].range.data_type();
184        let target_type = &self.cast_type;
185
186        let unbounded = Interval::make_unbounded(target_type)?;
187        if (source_datatype.is_numeric() || source_datatype == Boolean)
188            && target_type.is_numeric()
189            || source_datatype.is_temporal() && target_type.is_temporal()
190            || source_datatype.eq(target_type)
191        {
192            Ok(children[0].clone().with_range(unbounded))
193        } else {
194            Ok(ExprProperties::new_unknown().with_range(unbounded))
195        }
196    }
197}
198
199/// Return a PhysicalExpression representing `expr` casted to
200/// `cast_type`, if any casting is needed.
201///
202/// Note that such casts may lose type information
203pub fn cast_with_options(
204    expr: Arc<dyn PhysicalExpr>,
205    input_schema: &Schema,
206    cast_type: DataType,
207    cast_options: Option<CastOptions<'static>>,
208) -> Result<Arc<dyn PhysicalExpr>> {
209    let expr_type = expr.data_type(input_schema)?;
210    if expr_type == cast_type {
211        Ok(Arc::clone(&expr))
212    } else if can_cast_types(&expr_type, &cast_type) {
213        Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
214    } else {
215        not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}")
216    }
217}
218
219/// Return a PhysicalExpression representing `expr` casted to
220/// `cast_type`, if any casting is needed.
221///
222/// Note that such casts may lose type information
223pub fn cast(
224    expr: Arc<dyn PhysicalExpr>,
225    input_schema: &Schema,
226    cast_type: DataType,
227) -> Result<Arc<dyn PhysicalExpr>> {
228    cast_with_options(expr, input_schema, cast_type, None)
229}
230
231#[cfg(test)]
232mod tests {
233    use super::*;
234
235    use crate::expressions::column::col;
236
237    use arrow::{
238        array::{
239            Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array,
240            Int64Array, Int8Array, StringArray, Time64NanosecondArray,
241            TimestampNanosecondArray, UInt32Array,
242        },
243        datatypes::*,
244    };
245    use datafusion_common::assert_contains;
246
247    // runs an end-to-end test of physical type cast
248    // 1. construct a record batch with a column "a" of type A
249    // 2. construct a physical expression of CAST(a AS B)
250    // 3. evaluate the expression
251    // 4. verify that the resulting expression is of type B
252    // 5. verify that the resulting values are downcastable and correct
253    macro_rules! generic_decimal_to_other_test_cast {
254        ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr,$CAST_OPTIONS:expr) => {{
255            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
256            let batch = RecordBatch::try_new(
257                Arc::new(schema.clone()),
258                vec![Arc::new($DECIMAL_ARRAY)],
259            )?;
260            // verify that we can construct the expression
261            let expression =
262                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;
263
264            // verify that its display is correct
265            assert_eq!(
266                format!("CAST(a@0 AS {:?})", $TYPE),
267                format!("{}", expression)
268            );
269
270            // verify that the expression's type is correct
271            assert_eq!(expression.data_type(&schema)?, $TYPE);
272
273            // compute
274            let result = expression
275                .evaluate(&batch)?
276                .into_array(batch.num_rows())
277                .expect("Failed to convert to array");
278
279            // verify that the array's data_type is correct
280            assert_eq!(*result.data_type(), $TYPE);
281
282            // verify that the data itself is downcastable
283            let result = result
284                .as_any()
285                .downcast_ref::<$TYPEARRAY>()
286                .expect("failed to downcast");
287
288            // verify that the result itself is correct
289            for (i, x) in $VEC.iter().enumerate() {
290                match x {
291                    Some(x) => assert_eq!(result.value(i), *x),
292                    None => assert!(!result.is_valid(i)),
293                }
294            }
295        }};
296    }
297
298    // runs an end-to-end test of physical type cast
299    // 1. construct a record batch with a column "a" of type A
300    // 2. construct a physical expression of CAST(a AS B)
301    // 3. evaluate the expression
302    // 4. verify that the resulting expression is of type B
303    // 5. verify that the resulting values are downcastable and correct
304    macro_rules! generic_test_cast {
305        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{
306            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
307            let a_vec_len = $A_VEC.len();
308            let a = $A_ARRAY::from($A_VEC);
309            let batch =
310                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
311
312            // verify that we can construct the expression
313            let expression =
314                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;
315
316            // verify that its display is correct
317            assert_eq!(
318                format!("CAST(a@0 AS {:?})", $TYPE),
319                format!("{}", expression)
320            );
321
322            // verify that the expression's type is correct
323            assert_eq!(expression.data_type(&schema)?, $TYPE);
324
325            // compute
326            let result = expression
327                .evaluate(&batch)?
328                .into_array(batch.num_rows())
329                .expect("Failed to convert to array");
330
331            // verify that the array's data_type is correct
332            assert_eq!(*result.data_type(), $TYPE);
333
334            // verify that the len is correct
335            assert_eq!(result.len(), a_vec_len);
336
337            // verify that the data itself is downcastable
338            let result = result
339                .as_any()
340                .downcast_ref::<$TYPEARRAY>()
341                .expect("failed to downcast");
342
343            // verify that the result itself is correct
344            for (i, x) in $VEC.iter().enumerate() {
345                match x {
346                    Some(x) => assert_eq!(result.value(i), *x),
347                    None => assert!(!result.is_valid(i)),
348                }
349            }
350        }};
351    }
352
353    #[test]
354    fn test_cast_decimal_to_decimal() -> Result<()> {
355        let array = vec![
356            Some(1234),
357            Some(2222),
358            Some(3),
359            Some(4000),
360            Some(5000),
361            None,
362        ];
363
364        let decimal_array = array
365            .clone()
366            .into_iter()
367            .collect::<Decimal128Array>()
368            .with_precision_and_scale(10, 3)?;
369
370        generic_decimal_to_other_test_cast!(
371            decimal_array,
372            Decimal128(10, 3),
373            Decimal128Array,
374            Decimal128(20, 6),
375            [
376                Some(1_234_000),
377                Some(2_222_000),
378                Some(3_000),
379                Some(4_000_000),
380                Some(5_000_000),
381                None
382            ],
383            None
384        );
385
386        let decimal_array = array
387            .into_iter()
388            .collect::<Decimal128Array>()
389            .with_precision_and_scale(10, 3)?;
390
391        generic_decimal_to_other_test_cast!(
392            decimal_array,
393            Decimal128(10, 3),
394            Decimal128Array,
395            Decimal128(10, 2),
396            [Some(123), Some(222), Some(0), Some(400), Some(500), None],
397            None
398        );
399
400        Ok(())
401    }
402
403    #[test]
404    fn test_cast_decimal_to_decimal_overflow() -> Result<()> {
405        let array = vec![Some(123456789)];
406
407        let decimal_array = array
408            .clone()
409            .into_iter()
410            .collect::<Decimal128Array>()
411            .with_precision_and_scale(10, 3)?;
412
413        let schema = Schema::new(vec![Field::new("a", Decimal128(10, 3), false)]);
414        let batch = RecordBatch::try_new(
415            Arc::new(schema.clone()),
416            vec![Arc::new(decimal_array)],
417        )?;
418        let expression =
419            cast_with_options(col("a", &schema)?, &schema, Decimal128(6, 2), None)?;
420        let e = expression.evaluate(&batch).unwrap_err(); // panics on OK
421        assert_contains!(
422            e.to_string(),
423            "Arrow error: Invalid argument error: 12345679 is too large to store in a Decimal128 of precision 6. Max is 999999"
424        );
425
426        let expression_safe = cast_with_options(
427            col("a", &schema)?,
428            &schema,
429            Decimal128(6, 2),
430            Some(DEFAULT_SAFE_CAST_OPTIONS),
431        )?;
432        let result_safe = expression_safe
433            .evaluate(&batch)?
434            .into_array(batch.num_rows())
435            .expect("failed to convert to array");
436
437        assert!(result_safe.is_null(0));
438
439        Ok(())
440    }
441
442    #[test]
443    fn test_cast_decimal_to_numeric() -> Result<()> {
444        let array = vec![Some(1), Some(2), Some(3), Some(4), Some(5), None];
445        // decimal to i8
446        let decimal_array = array
447            .clone()
448            .into_iter()
449            .collect::<Decimal128Array>()
450            .with_precision_and_scale(10, 0)?;
451        generic_decimal_to_other_test_cast!(
452            decimal_array,
453            Decimal128(10, 0),
454            Int8Array,
455            Int8,
456            [
457                Some(1_i8),
458                Some(2_i8),
459                Some(3_i8),
460                Some(4_i8),
461                Some(5_i8),
462                None
463            ],
464            None
465        );
466
467        // decimal to i16
468        let decimal_array = array
469            .clone()
470            .into_iter()
471            .collect::<Decimal128Array>()
472            .with_precision_and_scale(10, 0)?;
473        generic_decimal_to_other_test_cast!(
474            decimal_array,
475            Decimal128(10, 0),
476            Int16Array,
477            Int16,
478            [
479                Some(1_i16),
480                Some(2_i16),
481                Some(3_i16),
482                Some(4_i16),
483                Some(5_i16),
484                None
485            ],
486            None
487        );
488
489        // decimal to i32
490        let decimal_array = array
491            .clone()
492            .into_iter()
493            .collect::<Decimal128Array>()
494            .with_precision_and_scale(10, 0)?;
495        generic_decimal_to_other_test_cast!(
496            decimal_array,
497            Decimal128(10, 0),
498            Int32Array,
499            Int32,
500            [
501                Some(1_i32),
502                Some(2_i32),
503                Some(3_i32),
504                Some(4_i32),
505                Some(5_i32),
506                None
507            ],
508            None
509        );
510
511        // decimal to i64
512        let decimal_array = array
513            .into_iter()
514            .collect::<Decimal128Array>()
515            .with_precision_and_scale(10, 0)?;
516        generic_decimal_to_other_test_cast!(
517            decimal_array,
518            Decimal128(10, 0),
519            Int64Array,
520            Int64,
521            [
522                Some(1_i64),
523                Some(2_i64),
524                Some(3_i64),
525                Some(4_i64),
526                Some(5_i64),
527                None
528            ],
529            None
530        );
531
532        // decimal to float32
533        let array = vec![
534            Some(1234),
535            Some(2222),
536            Some(3),
537            Some(4000),
538            Some(5000),
539            None,
540        ];
541        let decimal_array = array
542            .clone()
543            .into_iter()
544            .collect::<Decimal128Array>()
545            .with_precision_and_scale(10, 3)?;
546        generic_decimal_to_other_test_cast!(
547            decimal_array,
548            Decimal128(10, 3),
549            Float32Array,
550            Float32,
551            [
552                Some(1.234_f32),
553                Some(2.222_f32),
554                Some(0.003_f32),
555                Some(4.0_f32),
556                Some(5.0_f32),
557                None
558            ],
559            None
560        );
561
562        // decimal to float64
563        let decimal_array = array
564            .into_iter()
565            .collect::<Decimal128Array>()
566            .with_precision_and_scale(20, 6)?;
567        generic_decimal_to_other_test_cast!(
568            decimal_array,
569            Decimal128(20, 6),
570            Float64Array,
571            Float64,
572            [
573                Some(0.001234_f64),
574                Some(0.002222_f64),
575                Some(0.000003_f64),
576                Some(0.004_f64),
577                Some(0.005_f64),
578                None
579            ],
580            None
581        );
582        Ok(())
583    }
584
585    #[test]
586    fn test_cast_numeric_to_decimal() -> Result<()> {
587        // int8
588        generic_test_cast!(
589            Int8Array,
590            Int8,
591            vec![1, 2, 3, 4, 5],
592            Decimal128Array,
593            Decimal128(3, 0),
594            [Some(1), Some(2), Some(3), Some(4), Some(5)],
595            None
596        );
597
598        // int16
599        generic_test_cast!(
600            Int16Array,
601            Int16,
602            vec![1, 2, 3, 4, 5],
603            Decimal128Array,
604            Decimal128(5, 0),
605            [Some(1), Some(2), Some(3), Some(4), Some(5)],
606            None
607        );
608
609        // int32
610        generic_test_cast!(
611            Int32Array,
612            Int32,
613            vec![1, 2, 3, 4, 5],
614            Decimal128Array,
615            Decimal128(10, 0),
616            [Some(1), Some(2), Some(3), Some(4), Some(5)],
617            None
618        );
619
620        // int64
621        generic_test_cast!(
622            Int64Array,
623            Int64,
624            vec![1, 2, 3, 4, 5],
625            Decimal128Array,
626            Decimal128(20, 0),
627            [Some(1), Some(2), Some(3), Some(4), Some(5)],
628            None
629        );
630
631        // int64 to different scale
632        generic_test_cast!(
633            Int64Array,
634            Int64,
635            vec![1, 2, 3, 4, 5],
636            Decimal128Array,
637            Decimal128(20, 2),
638            [Some(100), Some(200), Some(300), Some(400), Some(500)],
639            None
640        );
641
642        // float32
643        generic_test_cast!(
644            Float32Array,
645            Float32,
646            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
647            Decimal128Array,
648            Decimal128(10, 2),
649            [Some(150), Some(250), Some(300), Some(112), Some(550)],
650            None
651        );
652
653        // float64
654        generic_test_cast!(
655            Float64Array,
656            Float64,
657            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
658            Decimal128Array,
659            Decimal128(20, 4),
660            [
661                Some(15000),
662                Some(25000),
663                Some(30000),
664                Some(11235),
665                Some(55000)
666            ],
667            None
668        );
669        Ok(())
670    }
671
672    #[test]
673    fn test_cast_i32_u32() -> Result<()> {
674        generic_test_cast!(
675            Int32Array,
676            Int32,
677            vec![1, 2, 3, 4, 5],
678            UInt32Array,
679            UInt32,
680            [
681                Some(1_u32),
682                Some(2_u32),
683                Some(3_u32),
684                Some(4_u32),
685                Some(5_u32)
686            ],
687            None
688        );
689        Ok(())
690    }
691
692    #[test]
693    fn test_cast_i32_utf8() -> Result<()> {
694        generic_test_cast!(
695            Int32Array,
696            Int32,
697            vec![1, 2, 3, 4, 5],
698            StringArray,
699            Utf8,
700            [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")],
701            None
702        );
703        Ok(())
704    }
705
706    #[test]
707    fn test_cast_i64_t64() -> Result<()> {
708        let original = vec![1, 2, 3, 4, 5];
709        let expected: Vec<Option<i64>> = original
710            .iter()
711            .map(|i| Some(Time64NanosecondArray::from(vec![*i]).value(0)))
712            .collect();
713        generic_test_cast!(
714            Int64Array,
715            Int64,
716            original,
717            TimestampNanosecondArray,
718            Timestamp(TimeUnit::Nanosecond, None),
719            expected,
720            None
721        );
722        Ok(())
723    }
724
725    #[test]
726    fn invalid_cast() {
727        // Ensure a useful error happens at plan time if invalid casts are used
728        let schema = Schema::new(vec![Field::new("a", Int32, false)]);
729
730        let result = cast(
731            col("a", &schema).unwrap(),
732            &schema,
733            Interval(IntervalUnit::MonthDayNano),
734        );
735        result.expect_err("expected Invalid CAST");
736    }
737
738    #[test]
739    fn invalid_cast_with_options_error() -> Result<()> {
740        // Ensure a useful error happens at plan time if invalid casts are used
741        let schema = Schema::new(vec![Field::new("a", Utf8, false)]);
742        let a = StringArray::from(vec!["9.1"]);
743        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
744        let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?;
745        let result = expression.evaluate(&batch);
746
747        match result {
748            Ok(_) => panic!("expected error"),
749            Err(e) => {
750                assert!(e
751                    .to_string()
752                    .contains("Cannot cast string '9.1' to value of Int32 type"))
753            }
754        }
755        Ok(())
756    }
757
758    #[test]
759    #[ignore] // TODO: https://github.com/apache/datafusion/issues/5396
760    fn test_cast_decimal() -> Result<()> {
761        let schema = Schema::new(vec![Field::new("a", Int64, false)]);
762        let a = Int64Array::from(vec![100]);
763        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
764        let expression =
765            cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?;
766        expression.evaluate(&batch)?;
767        Ok(())
768    }
769}