datafusion_functions/core/
getfield.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData,
20    Scalar,
21};
22use arrow::compute::SortOptions;
23use arrow::datatypes::DataType;
24use arrow_buffer::NullBuffer;
25use datafusion_common::cast::{as_map_array, as_struct_array};
26use datafusion_common::{
27    exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result,
28    ScalarValue,
29};
30use datafusion_expr::{
31    ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs,
32};
33use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
34use datafusion_macros::user_doc;
35use std::any::Any;
36use std::sync::Arc;
37
38#[user_doc(
39    doc_section(label = "Other Functions"),
40    description = r#"Returns a field within a map or a struct with the given key.
41    Note: most users invoke `get_field` indirectly via field access
42    syntax such as `my_struct_col['field_name']` which results in a call to
43    `get_field(my_struct_col, 'field_name')`."#,
44    syntax_example = "get_field(expression1, expression2)",
45    sql_example = r#"```sql
46> create table t (idx varchar, v varchar) as values ('data','fusion'), ('apache', 'arrow');
47> select struct(idx, v) from t as c;
48+-------------------------+
49| struct(c.idx,c.v)       |
50+-------------------------+
51| {c0: data, c1: fusion}  |
52| {c0: apache, c1: arrow} |
53+-------------------------+
54> select get_field((select struct(idx, v) from t), 'c0');
55+-----------------------+
56| struct(t.idx,t.v)[c0] |
57+-----------------------+
58| data                  |
59| apache                |
60+-----------------------+
61> select get_field((select struct(idx, v) from t), 'c1');
62+-----------------------+
63| struct(t.idx,t.v)[c1] |
64+-----------------------+
65| fusion                |
66| arrow                 |
67+-----------------------+
68```"#,
69    argument(
70        name = "expression1",
71        description = "The map or struct to retrieve a field for."
72    ),
73    argument(
74        name = "expression2",
75        description = "The field name in the map or struct to retrieve data for. Must evaluate to a string."
76    )
77)]
78#[derive(Debug)]
79pub struct GetFieldFunc {
80    signature: Signature,
81}
82
83impl Default for GetFieldFunc {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89impl GetFieldFunc {
90    pub fn new() -> Self {
91        Self {
92            signature: Signature::any(2, Volatility::Immutable),
93        }
94    }
95}
96
97// get_field(struct_array, field_name)
98impl ScalarUDFImpl for GetFieldFunc {
99    fn as_any(&self) -> &dyn Any {
100        self
101    }
102
103    fn name(&self) -> &str {
104        "get_field"
105    }
106
107    fn display_name(&self, args: &[Expr]) -> Result<String> {
108        let [base, field_name] = take_function_args(self.name(), args)?;
109
110        let name = match field_name {
111            Expr::Literal(name) => name,
112            other => &ScalarValue::Utf8(Some(other.schema_name().to_string())),
113        };
114
115        Ok(format!("{base}[{name}]"))
116    }
117
118    fn schema_name(&self, args: &[Expr]) -> Result<String> {
119        let [base, field_name] = take_function_args(self.name(), args)?;
120        let name = match field_name {
121            Expr::Literal(name) => name,
122            other => &ScalarValue::Utf8(Some(other.schema_name().to_string())),
123        };
124
125        Ok(format!("{}[{}]", base.schema_name(), name))
126    }
127
128    fn signature(&self) -> &Signature {
129        &self.signature
130    }
131
132    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
133        internal_err!("return_type_from_args should be called instead")
134    }
135
136    fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
137        // Length check handled in the signature
138        debug_assert_eq!(args.scalar_arguments.len(), 2);
139
140        match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) {
141            (DataType::Map(fields, _), _) => {
142                match fields.data_type() {
143                    DataType::Struct(fields) if fields.len() == 2 => {
144                        // Arrow's MapArray is essentially a ListArray of structs with two columns. They are
145                        // often named "key", and "value", but we don't require any specific naming here;
146                        // instead, we assume that the second column is the "value" column both here and in
147                        // execution.
148                        let value_field = fields.get(1).expect("fields should have exactly two members");
149                        Ok(ReturnInfo::new_nullable(value_field.data_type().clone()))
150                    },
151                    _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"),
152                }
153            }
154            (DataType::Struct(fields),sv) => {
155                sv.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
156                .map_or_else(
157                    || exec_err!("Field name must be a non-empty string"),
158                    |field_name| {
159                    fields.iter().find(|f| f.name() == field_name)
160                    .ok_or(plan_datafusion_err!("Field {field_name} not found in struct"))
161                    .map(|f| ReturnInfo::new_nullable(f.data_type().to_owned()))
162                })
163            },
164            (DataType::Null, _) => Ok(ReturnInfo::new_nullable(DataType::Null)),
165            (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"),
166        }
167    }
168
169    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
170        let [base, field_name] = take_function_args(self.name(), args.args)?;
171
172        if base.data_type().is_null() {
173            return Ok(ColumnarValue::Scalar(ScalarValue::Null));
174        }
175
176        let arrays =
177            ColumnarValue::values_to_arrays(&[base.clone(), field_name.clone()])?;
178        let array = Arc::clone(&arrays[0]);
179        let name = match field_name {
180            ColumnarValue::Scalar(name) => name,
181            _ => {
182                return exec_err!(
183                    "get_field function requires the argument field_name to be a string"
184                );
185            }
186        };
187
188        fn process_map_array(
189            array: Arc<dyn Array>,
190            key_array: Arc<dyn Array>,
191        ) -> Result<ColumnarValue> {
192            let map_array = as_map_array(array.as_ref())?;
193            let keys = if key_array.data_type().is_nested() {
194                let comparator = make_comparator(
195                    map_array.keys().as_ref(),
196                    key_array.as_ref(),
197                    SortOptions::default(),
198                )?;
199                let len = map_array.keys().len().min(key_array.len());
200                let values = (0..len).map(|i| comparator(i, i).is_eq()).collect();
201                let nulls =
202                    NullBuffer::union(map_array.keys().nulls(), key_array.nulls());
203                BooleanArray::new(values, nulls)
204            } else {
205                let be_compared = Scalar::new(key_array);
206                arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())?
207            };
208
209            let original_data = map_array.entries().column(1).to_data();
210            let capacity = Capacities::Array(original_data.len());
211            let mut mutable =
212                MutableArrayData::with_capacities(vec![&original_data], true, capacity);
213
214            for entry in 0..map_array.len() {
215                let start = map_array.value_offsets()[entry] as usize;
216                let end = map_array.value_offsets()[entry + 1] as usize;
217
218                let maybe_matched = keys
219                    .slice(start, end - start)
220                    .iter()
221                    .enumerate()
222                    .find(|(_, t)| t.unwrap());
223
224                if maybe_matched.is_none() {
225                    mutable.extend_nulls(1);
226                    continue;
227                }
228                let (match_offset, _) = maybe_matched.unwrap();
229                mutable.extend(0, start + match_offset, start + match_offset + 1);
230            }
231
232            let data = mutable.freeze();
233            let data = make_array(data);
234            Ok(ColumnarValue::Array(data))
235        }
236
237        match (array.data_type(), name) {
238            (DataType::Map(_, _), ScalarValue::List(arr)) => {
239                let key_array: Arc<dyn Array> = arr;
240                process_map_array(array, key_array)
241            }
242            (DataType::Map(_, _), ScalarValue::Struct(arr)) => {
243                process_map_array(array, arr as Arc<dyn Array>)
244            }
245            (DataType::Map(_, _), other) => {
246                let data_type = other.data_type();
247                if data_type.is_nested() {
248                    exec_err!("unsupported type {:?} for map access", data_type)
249                } else {
250                    process_map_array(array, other.to_array()?)
251                }
252            }
253            (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
254                let as_struct_array = as_struct_array(&array)?;
255                match as_struct_array.column_by_name(&k) {
256                    None => exec_err!("get indexed field {k} not found in struct"),
257                    Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))),
258                }
259            }
260            (DataType::Struct(_), name) => exec_err!(
261                "get_field is only possible on struct with utf8 indexes. \
262                             Received with {name:?} index"
263            ),
264            (DataType::Null, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)),
265            (dt, name) => exec_err!(
266                "get_field is only possible on maps with utf8 indexes or struct \
267                                         with utf8 indexes. Received {dt:?} with {name:?} index"
268            ),
269        }
270    }
271
272    fn documentation(&self) -> Option<&Documentation> {
273        self.doc()
274    }
275}