datafusion_functions_nested/
map_extract.rsuse arrow::array::{ArrayRef, Capacities, MutableArrayData};
use arrow_array::{make_array, ListArray};
use arrow::datatypes::DataType;
use arrow_array::{Array, MapArray};
use arrow_buffer::OffsetBuffer;
use arrow_schema::Field;
use datafusion_common::{cast::as_map_array, exec_err, Result};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use std::any::Any;
use std::sync::{Arc, OnceLock};
use std::vec;
use crate::utils::{get_map_entry_field, make_scalar_function};
make_udf_expr_and_func!(
MapExtract,
map_extract,
map key,
"Return a list containing the value for a given key or an empty list if the key is not contained in the map.",
map_extract_udf
);
#[derive(Debug)]
pub(super) struct MapExtract {
signature: Signature,
aliases: Vec<String>,
}
impl MapExtract {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec![String::from("element_at")],
}
}
}
impl ScalarUDFImpl for MapExtract {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"map_extract"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 2 {
return exec_err!("map_extract expects two arguments");
}
let map_type = &arg_types[0];
let map_fields = get_map_entry_field(map_type)?;
Ok(DataType::List(Arc::new(Field::new(
"item",
map_fields.last().unwrap().data_type().clone(),
true,
))))
}
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(map_extract_inner)(args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return exec_err!("map_extract expects two arguments");
}
let field = get_map_entry_field(&arg_types[0])?;
Ok(vec![
arg_types[0].clone(),
field.first().unwrap().data_type().clone(),
])
}
fn documentation(&self) -> Option<&Documentation> {
Some(get_map_extract_doc())
}
}
static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
fn get_map_extract_doc() -> &'static Documentation {
DOCUMENTATION.get_or_init(|| {
Documentation::builder()
.with_doc_section(DOC_SECTION_MAP)
.with_description(
"Returns a list containing the value for the given key or an empty list if the key is not present in the map.",
)
.with_syntax_example("map_extract(map, key)")
.with_sql_example(
r#"```sql
SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a');
----
[1]
SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2);
----
['two']
SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y');
----
[]
```"#,
)
.with_argument(
"map",
"Map expression. Can be a constant, column, or function, and any combination of map operators.",
)
.with_argument(
"key",
"Key to extract from the map. Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed.",
)
.build()
.unwrap()
})
}
fn general_map_extract_inner(
map_array: &MapArray,
query_keys_array: &dyn Array,
) -> Result<ArrayRef> {
let keys = map_array.keys();
let mut offsets = vec![0_i32];
let values = map_array.values();
let original_data = values.to_data();
let capacity = Capacities::Array(original_data.len());
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true, capacity);
for (row_index, offset_window) in map_array.value_offsets().windows(2).enumerate() {
let start = offset_window[0] as usize;
let end = offset_window[1] as usize;
let len = end - start;
let query_key = query_keys_array.slice(row_index, 1);
let value_index =
(0..len).find(|&i| keys.slice(start + i, 1).as_ref() == query_key.as_ref());
match value_index {
Some(index) => {
mutable.extend(0, start + index, start + index + 1);
}
None => {
mutable.extend_nulls(1);
}
}
offsets.push(offsets[row_index] + 1);
}
let data = mutable.freeze();
Ok(Arc::new(ListArray::new(
Arc::new(Field::new("item", map_array.value_type().clone(), true)),
OffsetBuffer::<i32>::new(offsets.into()),
Arc::new(make_array(data)),
None,
)))
}
fn map_extract_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("map_extract expects two arguments");
}
let map_array = match args[0].data_type() {
DataType::Map(_, _) => as_map_array(&args[0])?,
_ => return exec_err!("The first argument in map_extract must be a map"),
};
let key_type = map_array.key_type();
if key_type != args[1].data_type() {
return exec_err!(
"The key type {} does not match the map key type {}",
args[1].data_type(),
key_type
);
}
general_map_extract_inner(map_array, &args[1])
}