1use std::any::Any;
19use std::fmt;
20use std::hash::Hash;
21use std::sync::Arc;
22
23use crate::physical_expr::PhysicalExpr;
24
25use arrow::compute::{can_cast_types, CastOptions};
26use arrow::datatypes::{DataType, DataType::*, Schema};
27use arrow::record_batch::RecordBatch;
28use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
29use datafusion_common::{not_impl_err, Result};
30use datafusion_expr_common::columnar_value::ColumnarValue;
31use datafusion_expr_common::interval_arithmetic::Interval;
32use datafusion_expr_common::sort_properties::ExprProperties;
33
34const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions {
35 safe: false,
36 format_options: DEFAULT_FORMAT_OPTIONS,
37};
38
39const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
40 safe: true,
41 format_options: DEFAULT_FORMAT_OPTIONS,
42};
43
/// Physical expression that casts the value produced by a child expression
/// to a target [`DataType`].
#[derive(Debug, Clone, Eq)]
pub struct CastExpr {
    /// The child expression whose output is cast.
    pub expr: Arc<dyn PhysicalExpr>,
    /// The data type to cast to.
    cast_type: DataType,
    /// Options forwarded to the cast when evaluating and when computing bounds.
    cast_options: CastOptions<'static>,
}
54
55impl PartialEq for CastExpr {
57 fn eq(&self, other: &Self) -> bool {
58 self.expr.eq(&other.expr)
59 && self.cast_type.eq(&other.cast_type)
60 && self.cast_options.eq(&other.cast_options)
61 }
62}
63
64impl Hash for CastExpr {
65 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
66 self.expr.hash(state);
67 self.cast_type.hash(state);
68 self.cast_options.hash(state);
69 }
70}
71
72impl CastExpr {
73 pub fn new(
75 expr: Arc<dyn PhysicalExpr>,
76 cast_type: DataType,
77 cast_options: Option<CastOptions<'static>>,
78 ) -> Self {
79 Self {
80 expr,
81 cast_type,
82 cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
83 }
84 }
85
86 pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
88 &self.expr
89 }
90
91 pub fn cast_type(&self) -> &DataType {
93 &self.cast_type
94 }
95
96 pub fn cast_options(&self) -> &CastOptions<'static> {
98 &self.cast_options
99 }
100 pub fn is_bigger_cast(&self, src: DataType) -> bool {
101 if src == self.cast_type {
102 return true;
103 }
104 matches!(
105 (src, &self.cast_type),
106 (Int8, Int16 | Int32 | Int64)
107 | (Int16, Int32 | Int64)
108 | (Int32, Int64)
109 | (UInt8, UInt16 | UInt32 | UInt64)
110 | (UInt16, UInt32 | UInt64)
111 | (UInt32, UInt64)
112 | (
113 Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32,
114 Float32 | Float64
115 )
116 | (Int64 | UInt64, Float64)
117 | (Utf8, LargeUtf8)
118 )
119 }
120}
121
122impl fmt::Display for CastExpr {
123 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124 write!(f, "CAST({} AS {:?})", self.expr, self.cast_type)
125 }
126}
127
128impl PhysicalExpr for CastExpr {
129 fn as_any(&self) -> &dyn Any {
131 self
132 }
133
134 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
135 Ok(self.cast_type.clone())
136 }
137
138 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
139 self.expr.nullable(input_schema)
140 }
141
142 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
143 let value = self.expr.evaluate(batch)?;
144 value.cast_to(&self.cast_type, Some(&self.cast_options))
145 }
146
147 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
148 vec![&self.expr]
149 }
150
151 fn with_new_children(
152 self: Arc<Self>,
153 children: Vec<Arc<dyn PhysicalExpr>>,
154 ) -> Result<Arc<dyn PhysicalExpr>> {
155 Ok(Arc::new(CastExpr::new(
156 Arc::clone(&children[0]),
157 self.cast_type.clone(),
158 Some(self.cast_options.clone()),
159 )))
160 }
161
162 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
163 children[0].cast_to(&self.cast_type, &self.cast_options)
165 }
166
167 fn propagate_constraints(
168 &self,
169 interval: &Interval,
170 children: &[&Interval],
171 ) -> Result<Option<Vec<Interval>>> {
172 let child_interval = children[0];
173 let cast_type = child_interval.data_type();
175 Ok(Some(vec![
176 interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)?
177 ]))
178 }
179
180 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
183 let source_datatype = children[0].range.data_type();
184 let target_type = &self.cast_type;
185
186 let unbounded = Interval::make_unbounded(target_type)?;
187 if (source_datatype.is_numeric() || source_datatype == Boolean)
188 && target_type.is_numeric()
189 || source_datatype.is_temporal() && target_type.is_temporal()
190 || source_datatype.eq(target_type)
191 {
192 Ok(children[0].clone().with_range(unbounded))
193 } else {
194 Ok(ExprProperties::new_unknown().with_range(unbounded))
195 }
196 }
197}
198
199pub fn cast_with_options(
204 expr: Arc<dyn PhysicalExpr>,
205 input_schema: &Schema,
206 cast_type: DataType,
207 cast_options: Option<CastOptions<'static>>,
208) -> Result<Arc<dyn PhysicalExpr>> {
209 let expr_type = expr.data_type(input_schema)?;
210 if expr_type == cast_type {
211 Ok(Arc::clone(&expr))
212 } else if can_cast_types(&expr_type, &cast_type) {
213 Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
214 } else {
215 not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}")
216 }
217}
218
219pub fn cast(
224 expr: Arc<dyn PhysicalExpr>,
225 input_schema: &Schema,
226 cast_type: DataType,
227) -> Result<Arc<dyn PhysicalExpr>> {
228 cast_with_options(expr, input_schema, cast_type, None)
229}
230
// Unit tests for `CastExpr`, `cast` and `cast_with_options`.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::expressions::column::col;

    use arrow::{
        array::{
            Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array,
            Int64Array, Int8Array, StringArray, Time64NanosecondArray,
            TimestampNanosecondArray, UInt32Array,
        },
        datatypes::*,
    };
    use datafusion_common::assert_contains;

    // Casts an already-built (decimal) array and checks the result.
    //
    // Arguments, in order:
    // - the input array (an identifier bound to the array value)
    // - the data type of the input column
    // - the concrete array type the result is downcast to
    // - the target data type of the cast
    // - the expected values, as a collection of `Option`s (`None` = null)
    // - the cast options (`None` = defaults)
    //
    // Uses `?`, so it must be invoked inside a function returning `Result`.
    macro_rules! generic_decimal_to_other_test_cast {
        ($DECIMAL_ARRAY:ident, $A_TYPE:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr,$CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new($DECIMAL_ARRAY)],
            )?;
            // Build the cast over column "a".
            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // The display form must name the column and the target type.
            assert_eq!(
                format!("CAST(a@0 AS {:?})", $TYPE),
                format!("{}", expression)
            );

            // The expression must report the target type.
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            // Evaluate the cast on the batch.
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // The produced array must have the target type.
            assert_eq!(*result.data_type(), $TYPE);

            // The produced array must downcast to the expected concrete type.
            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // Compare element by element; `None` means the slot must be null.
            for (i, x) in $VEC.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }

    // Builds an array from a vector of values, casts it and checks the result.
    //
    // Arguments, in order:
    // - the concrete array type of the input
    // - the data type of the input column
    // - the input values (passed to `<input array type>::from`)
    // - the concrete array type the result is downcast to
    // - the target data type of the cast
    // - the expected values, as a collection of `Option`s (`None` = null)
    // - the cast options (`None` = defaults)
    //
    // Uses `?`, so it must be invoked inside a function returning `Result`.
    macro_rules! generic_test_cast {
        ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $CAST_OPTIONS:expr) => {{
            let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
            // Capture the length before the vector is consumed by `from`.
            let a_vec_len = $A_VEC.len();
            let a = $A_ARRAY::from($A_VEC);
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

            // Build the cast over column "a".
            let expression =
                cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;

            // The display form must name the column and the target type.
            assert_eq!(
                format!("CAST(a@0 AS {:?})", $TYPE),
                format!("{}", expression)
            );

            // The expression must report the target type.
            assert_eq!(expression.data_type(&schema)?, $TYPE);

            // Evaluate the cast on the batch.
            let result = expression
                .evaluate(&batch)?
                .into_array(batch.num_rows())
                .expect("Failed to convert to array");

            // The produced array must have the target type.
            assert_eq!(*result.data_type(), $TYPE);

            // The cast must not change the number of rows.
            assert_eq!(result.len(), a_vec_len);

            // The produced array must downcast to the expected concrete type.
            let result = result
                .as_any()
                .downcast_ref::<$TYPEARRAY>()
                .expect("failed to downcast");

            // Compare element by element; `None` means the slot must be null.
            for (i, x) in $VEC.iter().enumerate() {
                match x {
                    Some(x) => assert_eq!(result.value(i), *x),
                    None => assert!(!result.is_valid(i)),
                }
            }
        }};
    }

    // Decimal -> decimal casts, once increasing and once decreasing the scale.
    #[test]
    fn test_cast_decimal_to_decimal() -> Result<()> {
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];

        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        // Scale 3 -> 6: every stored value is multiplied by 1000.
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Decimal128Array,
            Decimal128(20, 6),
            [
                Some(1_234_000),
                Some(2_222_000),
                Some(3_000),
                Some(4_000_000),
                Some(5_000_000),
                None
            ],
            None
        );

        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        // Scale 3 -> 2: the last digit is dropped (e.g. 3 becomes 0).
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Decimal128Array,
            Decimal128(10, 2),
            [Some(123), Some(222), Some(0), Some(400), Some(500), None],
            None
        );

        Ok(())
    }

    // A decimal value that does not fit the target precision: the default
    // options produce an error, the safe options produce a null.
    #[test]
    fn test_cast_decimal_to_decimal_overflow() -> Result<()> {
        let array = vec![Some(123456789)];

        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;

        let schema = Schema::new(vec![Field::new("a", Decimal128(10, 3), false)]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(decimal_array)],
        )?;
        let expression =
            cast_with_options(col("a", &schema)?, &schema, Decimal128(6, 2), None)?;
        let e = expression.evaluate(&batch).unwrap_err(); // panics if evaluation succeeds
        assert_contains!(
            e.to_string(),
            "Arrow error: Invalid argument error: 12345679 is too large to store in a Decimal128 of precision 6. Max is 999999"
        );

        // Same cast with safe options: no error, the value becomes null.
        let expression_safe = cast_with_options(
            col("a", &schema)?,
            &schema,
            Decimal128(6, 2),
            Some(DEFAULT_SAFE_CAST_OPTIONS),
        )?;
        let result_safe = expression_safe
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("failed to convert to array");

        assert!(result_safe.is_null(0));

        Ok(())
    }

    // Decimal -> integer and decimal -> float casts.
    #[test]
    fn test_cast_decimal_to_numeric() -> Result<()> {
        let array = vec![Some(1), Some(2), Some(3), Some(4), Some(5), None];
        // decimal -> Int8
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int8Array,
            Int8,
            [
                Some(1_i8),
                Some(2_i8),
                Some(3_i8),
                Some(4_i8),
                Some(5_i8),
                None
            ],
            None
        );

        // decimal -> Int16
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int16Array,
            Int16,
            [
                Some(1_i16),
                Some(2_i16),
                Some(3_i16),
                Some(4_i16),
                Some(5_i16),
                None
            ],
            None
        );

        // decimal -> Int32
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int32Array,
            Int32,
            [
                Some(1_i32),
                Some(2_i32),
                Some(3_i32),
                Some(4_i32),
                Some(5_i32),
                None
            ],
            None
        );

        // decimal -> Int64
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 0)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 0),
            Int64Array,
            Int64,
            [
                Some(1_i64),
                Some(2_i64),
                Some(3_i64),
                Some(4_i64),
                Some(5_i64),
                None
            ],
            None
        );

        // decimal -> Float32 (scale 3, so 1234 represents 1.234)
        let array = vec![
            Some(1234),
            Some(2222),
            Some(3),
            Some(4000),
            Some(5000),
            None,
        ];
        let decimal_array = array
            .clone()
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(10, 3)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(10, 3),
            Float32Array,
            Float32,
            [
                Some(1.234_f32),
                Some(2.222_f32),
                Some(0.003_f32),
                Some(4.0_f32),
                Some(5.0_f32),
                None
            ],
            None
        );

        // decimal -> Float64 (same raw values at scale 6, so 1234 represents 0.001234)
        let decimal_array = array
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(20, 6)?;
        generic_decimal_to_other_test_cast!(
            decimal_array,
            Decimal128(20, 6),
            Float64Array,
            Float64,
            [
                Some(0.001234_f64),
                Some(0.002222_f64),
                Some(0.000003_f64),
                Some(0.004_f64),
                Some(0.005_f64),
                None
            ],
            None
        );
        Ok(())
    }

    // Integer -> decimal and float -> decimal casts.
    #[test]
    fn test_cast_numeric_to_decimal() -> Result<()> {
        // Int8 -> decimal
        generic_test_cast!(
            Int8Array,
            Int8,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(3, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // Int16 -> decimal
        generic_test_cast!(
            Int16Array,
            Int16,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(5, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // Int32 -> decimal
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(10, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // Int64 -> decimal with scale 0
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 0),
            [Some(1), Some(2), Some(3), Some(4), Some(5)],
            None
        );

        // Int64 -> decimal with scale 2 (stored values are multiplied by 100)
        generic_test_cast!(
            Int64Array,
            Int64,
            vec![1, 2, 3, 4, 5],
            Decimal128Array,
            Decimal128(20, 2),
            [Some(100), Some(200), Some(300), Some(400), Some(500)],
            None
        );

        // Float32 -> decimal with scale 2
        generic_test_cast!(
            Float32Array,
            Float32,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(10, 2),
            [Some(150), Some(250), Some(300), Some(112), Some(550)],
            None
        );

        // Float64 -> decimal with scale 4
        generic_test_cast!(
            Float64Array,
            Float64,
            vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50],
            Decimal128Array,
            Decimal128(20, 4),
            [
                Some(15000),
                Some(25000),
                Some(30000),
                Some(11235),
                Some(55000)
            ],
            None
        );
        Ok(())
    }

    // Int32 -> UInt32 for small non-negative values.
    #[test]
    fn test_cast_i32_u32() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            UInt32Array,
            UInt32,
            [
                Some(1_u32),
                Some(2_u32),
                Some(3_u32),
                Some(4_u32),
                Some(5_u32)
            ],
            None
        );
        Ok(())
    }

    // Int32 -> Utf8: each number becomes its decimal string.
    #[test]
    fn test_cast_i32_utf8() -> Result<()> {
        generic_test_cast!(
            Int32Array,
            Int32,
            vec![1, 2, 3, 4, 5],
            StringArray,
            Utf8,
            [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")],
            None
        );
        Ok(())
    }

    // Int64 -> nanosecond timestamp.
    #[test]
    fn test_cast_i64_t64() -> Result<()> {
        let original = vec![1, 2, 3, 4, 5];
        // Expected values are read back from single-element
        // `Time64NanosecondArray`s; presumably these equal the raw i64 inputs.
        let expected: Vec<Option<i64>> = original
            .iter()
            .map(|i| Some(Time64NanosecondArray::from(vec![*i]).value(0)))
            .collect();
        generic_test_cast!(
            Int64Array,
            Int64,
            original,
            TimestampNanosecondArray,
            Timestamp(TimeUnit::Nanosecond, None),
            expected,
            None
        );
        Ok(())
    }

    // An unsupported cast must be rejected when the expression is built.
    #[test]
    fn invalid_cast() {
        let schema = Schema::new(vec![Field::new("a", Int32, false)]);

        // Int32 -> Interval(MonthDayNano) is expected to be rejected by `cast`.
        let result = cast(
            col("a", &schema).unwrap(),
            &schema,
            Interval(IntervalUnit::MonthDayNano),
        );
        result.expect_err("expected Invalid CAST");
    }

    // A supported cast whose input value cannot be converted must fail at
    // evaluation time when the default options are used.
    #[test]
    fn invalid_cast_with_options_error() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", Utf8, false)]);
        // "9.1" is not a valid Int32 literal.
        let a = StringArray::from(vec!["9.1"]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?;
        let result = expression.evaluate(&batch);

        match result {
            Ok(_) => panic!("expected error"),
            Err(e) => {
                assert!(e
                    .to_string()
                    .contains("Cannot cast string '9.1' to value of Int32 type"))
            }
        }
        Ok(())
    }

    // Casts the Int64 value 100 to Decimal128(38, 38) and expects evaluation
    // to succeed.
    #[test]
    #[ignore] // NOTE(review): no reason is recorded for ignoring this test; presumably 100 does not fit Decimal128(38, 38) - confirm before re-enabling
    fn test_cast_decimal() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", Int64, false)]);
        let a = Int64Array::from(vec![100]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        let expression =
            cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?;
        expression.evaluate(&batch)?;
        Ok(())
    }
}