datafusion_functions/core/
union_extract.rs1use arrow::array::Array;
19use arrow::datatypes::{DataType, FieldRef, UnionFields};
20use datafusion_common::cast::as_union_array;
21use datafusion_common::utils::take_function_args;
22use datafusion_common::{
23 exec_datafusion_err, exec_err, internal_err, Result, ScalarValue,
24};
25use datafusion_doc::Documentation;
26use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs};
27use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
28use datafusion_macros::user_doc;
29
30#[user_doc(
31 doc_section(label = "Union Functions"),
32 description = "Returns the value of the given field in the union when selected, or NULL otherwise.",
33 syntax_example = "union_extract(union, field_name)",
34 sql_example = r#"```sql
35❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union;
36+--------------+----------------------------------+----------------------------------+
37| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') |
38+--------------+----------------------------------+----------------------------------+
39| {a=1} | 1 | |
40| {b=3.0} | | 3.0 |
41| {a=4} | 4 | |
42| {b=} | | |
43| {a=} | | |
44+--------------+----------------------------------+----------------------------------+
45```"#,
46 standard_argument(name = "union", prefix = "Union"),
47 argument(
48 name = "field_name",
49 description = "String expression to operate on. Must be a constant."
50 )
51)]
52#[derive(Debug)]
53pub struct UnionExtractFun {
54 signature: Signature,
55}
56
57impl Default for UnionExtractFun {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl UnionExtractFun {
64 pub fn new() -> Self {
65 Self {
66 signature: Signature::any(2, Volatility::Immutable),
67 }
68 }
69}
70
71impl ScalarUDFImpl for UnionExtractFun {
72 fn as_any(&self) -> &dyn std::any::Any {
73 self
74 }
75
76 fn name(&self) -> &str {
77 "union_extract"
78 }
79
80 fn signature(&self) -> &Signature {
81 &self.signature
82 }
83
84 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
85 internal_err!("union_extract should return type from exprs")
87 }
88
89 fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
90 if args.arg_types.len() != 2 {
91 return exec_err!(
92 "union_extract expects 2 arguments, got {} instead",
93 args.arg_types.len()
94 );
95 }
96
97 let DataType::Union(fields, _) = &args.arg_types[0] else {
98 return exec_err!(
99 "union_extract first argument must be a union, got {} instead",
100 args.arg_types[0]
101 );
102 };
103
104 let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else {
105 return exec_err!(
106 "union_extract second argument must be a non-null string literal, got {} instead",
107 args.arg_types[1]
108 );
109 };
110
111 let field = find_field(fields, field_name)?.1;
112
113 Ok(ReturnInfo::new_nullable(field.data_type().clone()))
114 }
115
116 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
117 let [array, target_name] = take_function_args("union_extract", args.args)?;
118
119 let target_name = match target_name {
120 ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name),
121 ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"),
122 _ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", target_name.data_type()),
123 }?;
124
125 match array {
126 ColumnarValue::Array(array) => {
127 let union_array = as_union_array(&array).map_err(|_| {
128 exec_datafusion_err!(
129 "union_extract first argument must be a union, got {} instead",
130 array.data_type()
131 )
132 })?;
133
134 Ok(ColumnarValue::Array(
135 arrow::compute::kernels::union_extract::union_extract(
136 union_array,
137 &target_name,
138 )?,
139 ))
140 }
141 ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => {
142 let (target_type_id, target) = find_field(&fields, &target_name)?;
143
144 let result = match value {
145 Some((type_id, value)) if target_type_id == type_id => *value,
146 _ => ScalarValue::try_new_null(target.data_type())?,
147 };
148
149 Ok(ColumnarValue::Scalar(result))
150 }
151 other => exec_err!(
152 "union_extract first argument must be a union, got {} instead",
153 other.data_type()
154 ),
155 }
156 }
157
158 fn documentation(&self) -> Option<&Documentation> {
159 self.doc()
160 }
161}
162
163fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> {
164 fields
165 .iter()
166 .find(|field| field.1.name() == name)
167 .ok_or_else(|| exec_datafusion_err!("field {name} not found on union"))
168}
169
#[cfg(test)]
mod tests {

    use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use super::UnionExtractFun;

    /// Invokes `union_extract(<union scalar>, 'str')` on a dense union scalar
    /// built from `value` and `fields`.
    fn extract_str(
        fun: &UnionExtractFun,
        fields: &UnionFields,
        value: Option<(i8, Box<ScalarValue>)>,
    ) -> Result<ColumnarValue> {
        let union_scalar = ScalarValue::Union(value, fields.clone(), UnionMode::Dense);

        fun.invoke_with_args(ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Scalar(union_scalar),
                ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
            ],
            number_rows: 1,
            return_type: &DataType::Utf8,
        })
    }

    /// Covers the scalar code path: a null union, a union whose selected
    /// field differs from the requested one, and a union whose selected field
    /// is the requested one.
    #[test]
    fn test_scalar_value() -> Result<()> {
        let fun = UnionExtractFun::new();

        let fields = UnionFields::new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, false),
                Field::new("int", DataType::Int32, false),
            ],
        );

        // (union payload, expected result of extracting "str")
        let cases = [
            (None, ScalarValue::Utf8(None)),
            (
                Some((3, Box::new(ScalarValue::Int32(Some(42))))),
                ScalarValue::Utf8(None),
            ),
            (
                Some((1, Box::new(ScalarValue::new_utf8("42")))),
                ScalarValue::new_utf8("42"),
            ),
        ];

        for (value, expected) in cases {
            let result = extract_str(&fun, &fields, value)?;
            assert_scalar(result, expected);
        }

        Ok(())
    }

    /// Asserts that `value` is a scalar equal to `expected`; panics when it
    /// is an array.
    fn assert_scalar(value: ColumnarValue, expected: ScalarValue) {
        let scalar = match value {
            ColumnarValue::Scalar(scalar) => scalar,
            ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"),
        };

        assert_eq!(scalar, expected);
    }
}