use crate::utils::make_scalar_function;
use arrow::array::{Capacities, MutableArrayData};
use arrow::compute;
use arrow_array::{
new_null_array, Array, ArrayRef, GenericListArray, Int64Array, ListArray,
OffsetSizeTrait,
};
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List};
use arrow_schema::{DataType, Field};
use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
use datafusion_common::{exec_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;
// Invokes the `make_udf_expr_and_func!` macro (not defined in the visible source) with
// the `ArrayRepeat` type, the function name `array_repeat`, the argument names
// `element` and `count`, a description string, and the identifier `array_repeat_udf`.
// NOTE(review): presumably this generates an `array_repeat(element, count)` expression
// builder plus an `array_repeat_udf` accessor for the UDF - confirm against the macro
// definition.
make_udf_expr_and_func!(
ArrayRepeat,
array_repeat,
element count, "returns an array containing element `count` times.", array_repeat_udf );
/// Scalar UDF definition for `array_repeat` (also reachable through its aliases).
///
/// The struct only carries the metadata handed back by the `ScalarUDFImpl`
/// accessors; the actual work is done by `array_repeat_inner`.
#[derive(Debug)]
pub(super) struct ArrayRepeat {
/// Signature returned from `ScalarUDFImpl::signature`; built in `new` with
/// `Signature::variadic_any`.
signature: Signature,
/// Alternative names returned from `ScalarUDFImpl::aliases`.
aliases: Vec<String>,
}
impl ArrayRepeat {
    /// Creates the `array_repeat` UDF definition.
    ///
    /// The signature is `variadic_any` with `Volatility::Immutable`, and the
    /// function is additionally registered under the alias `list_repeat`.
    pub fn new() -> Self {
        let signature = Signature::variadic_any(Volatility::Immutable);
        let aliases = vec!["list_repeat".to_string()];
        Self { signature, aliases }
    }
}
impl ScalarUDFImpl for ArrayRepeat {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"array_repeat"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(List(Arc::new(Field::new(
"item",
arg_types[0].clone(),
true,
))))
}
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(array_repeat_inner)(args)
}
fn aliases(&self) -> &[String] {
&self.aliases
}
}
/// Array-level implementation of `array_repeat`.
///
/// Expects exactly two arrays: the element column and an `Int64` count column.
/// `List` and `LargeList` elements are repeated as whole lists through
/// `general_list_repeat`; every other element type goes through
/// `general_repeat` with `i32` offsets.
///
/// # Errors
///
/// Fails when the argument count is not two, when the second argument cannot
/// be downcast to an `Int64Array`, or when the selected kernel fails.
pub fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
    let (element, counts) = match args {
        [element, counts] => (element, counts),
        _ => return exec_err!("array_repeat expects two arguments"),
    };
    let count_array = as_int64_array(counts)?;

    if let List(_) = element.data_type() {
        general_list_repeat::<i32>(as_list_array(element)?, count_array)
    } else if let LargeList(_) = element.data_type() {
        general_list_repeat::<i64>(as_large_list_array(element)?, count_array)
    } else {
        general_repeat::<i32>(element, count_array)
    }
}
/// Repeats each element of a non-list `array` the number of times given by
/// the matching entry of `count_array`.
///
/// Row `i` of the returned `GenericListArray<O>` holds `count_array[i]` copies
/// of `array[i]`; a null element yields a list of that many nulls. Negative
/// counts are treated as zero, producing an empty list.
///
/// # Errors
///
/// Propagates failures from `compute::concat` and
/// `GenericListArray::try_new`.
fn general_repeat<O: OffsetSizeTrait>(
    array: &ArrayRef,
    count_array: &Int64Array,
) -> Result<ArrayRef> {
    let data_type = array.data_type();

    // A bare `as usize` cast on a negative `i64` wraps to an enormous value,
    // which would then be used as an allocation size and list length. Clamp
    // to zero first so a negative count simply yields an empty list.
    let count_vec = count_array
        .values()
        .iter()
        .map(|&count| count.max(0) as usize)
        .collect::<Vec<_>>();

    // The source data does not change between rows, so materialize it once.
    let original_data = array.to_data();

    let mut new_values = Vec::with_capacity(count_vec.len());
    for (row_index, &count) in count_vec.iter().enumerate() {
        let repeated_array = if array.is_null(row_index) {
            new_null_array(data_type, count)
        } else {
            let capacity = Capacities::Array(count);
            let mut mutable =
                MutableArrayData::with_capacities(vec![&original_data], false, capacity);
            // Copy the single-row slice `[row_index, row_index + 1)` `count` times.
            for _ in 0..count {
                mutable.extend(0, row_index, row_index + 1);
            }
            arrow_array::make_array(mutable.freeze())
        };
        new_values.push(repeated_array);
    }

    // `concat` rejects an empty list of inputs, which is what a zero-row
    // batch produces; build the empty child array directly in that case.
    let values = if new_values.is_empty() {
        arrow_array::new_empty_array(data_type)
    } else {
        let refs: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
        compute::concat(&refs)?
    };

    Ok(Arc::new(GenericListArray::<O>::try_new(
        Arc::new(Field::new("item", data_type.to_owned(), true)),
        OffsetBuffer::from_lengths(count_vec),
        values,
        None,
    )?))
}
/// Repeats each list of `list_array` the number of times given by the
/// matching entry of `count_array`.
///
/// Row `i` of the returned `ListArray` is a list containing `count_array[i]`
/// copies of `list_array[i]`; a null input list yields that many null lists.
/// Negative counts are treated as zero, producing an empty outer list. The
/// inner lists keep the offset type `O`, while the outer list always uses
/// `i32` offsets.
///
/// # Errors
///
/// Propagates failures from `GenericListArray::try_new`, `compute::concat`
/// and `ListArray::try_new`.
fn general_list_repeat<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
    count_array: &Int64Array,
) -> Result<ArrayRef> {
    let data_type = list_array.data_type();
    let value_type = list_array.value_type();

    // A bare `as usize` cast on a negative `i64` wraps to an enormous value,
    // which would then be used as an allocation size and list length. Clamp
    // to zero first so a negative count simply yields an empty list.
    let count_vec = count_array
        .values()
        .iter()
        .map(|&count| count.max(0) as usize)
        .collect::<Vec<_>>();

    let mut new_values = Vec::with_capacity(count_vec.len());
    for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) {
        let repeated_lists = match list_array_row {
            Some(list_array_row) => {
                let original_data = list_array_row.to_data();
                let row_len = original_data.len();
                let capacity = Capacities::Array(row_len * count);
                let mut mutable = MutableArrayData::with_capacities(
                    vec![&original_data],
                    false,
                    capacity,
                );
                // Append the whole row's values `count` times back to back.
                for _ in 0..count {
                    mutable.extend(0, 0, row_len);
                }
                let repeated_values = arrow_array::make_array(mutable.freeze());
                // Re-split the flat values into `count` lists of `row_len` items.
                let list_arr = GenericListArray::<O>::try_new(
                    Arc::new(Field::new("item", value_type.clone(), true)),
                    OffsetBuffer::<O>::from_lengths(vec![row_len; count]),
                    repeated_values,
                    None,
                )?;
                Arc::new(list_arr) as ArrayRef
            }
            None => new_null_array(data_type, count),
        };
        new_values.push(repeated_lists);
    }

    let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();

    // `concat` rejects an empty list of inputs, which is what a zero-row
    // batch produces; build the empty child array directly in that case.
    let values = if new_values.is_empty() {
        arrow_array::new_empty_array(data_type)
    } else {
        let refs: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
        compute::concat(&refs)?
    };

    Ok(Arc::new(ListArray::try_new(
        Arc::new(Field::new("item", data_type.to_owned(), true)),
        OffsetBuffer::<i32>::from_lengths(lengths),
        values,
        None,
    )?))
}