use polars_error::{polars_bail, PolarsResult};
use super::{primitive_as_primitive, primitive_to_primitive, CastOptionsImpl};
use crate::array::{Array, DictionaryArray, DictionaryKey};
use crate::compute::cast::cast;
use crate::datatypes::ArrowDataType;
use crate::match_integer_type;
// Converts dictionary keys to the native key type `$to_type` and rebuilds a boxed
// `DictionaryArray` with data type `$to_datatype` around `$values`.
//
// Returns early with `ComputeError("overflow")` when the key conversion produced more nulls
// than the input keys had, i.e. when some key could not be represented in `$to_type`.
// NOTE: `$array` is accepted for call-site compatibility but is not used by the expansion.
macro_rules! key_cast {
    ($keys:expr, $values:expr, $array:expr, $to_keys_type:expr, $to_type:ty, $to_datatype:expr) => {{
        let converted_keys = primitive_to_primitive::<_, $to_type>($keys, $to_keys_type);
        let introduced_nulls = converted_keys.null_count() > $keys.null_count();
        if introduced_nulls {
            polars_bail!(ComputeError: "overflow")
        }
        // SAFETY: no key overflowed, so every converted key has the same value as before and
        // still indexes into `$values`.
        let rebuilt = unsafe {
            DictionaryArray::try_new_unchecked($to_datatype, converted_keys, $values.clone())
        };
        rebuilt.map(|dict| dict.boxed())
    }};
}
pub fn dictionary_to_dictionary_values<K: DictionaryKey>(
from: &DictionaryArray<K>,
values_type: &ArrowDataType,
) -> PolarsResult<DictionaryArray<K>> {
let keys = from.keys();
let values = from.values();
let length = values.len();
let values = cast(values.as_ref(), values_type, CastOptionsImpl::default())?;
assert_eq!(values.len(), length); unsafe {
DictionaryArray::try_new_unchecked(from.data_type().clone(), keys.clone(), values.clone())
}
}
pub fn wrapping_dictionary_to_dictionary_values<K: DictionaryKey>(
from: &DictionaryArray<K>,
values_type: &ArrowDataType,
) -> PolarsResult<DictionaryArray<K>> {
let keys = from.keys();
let values = from.values();
let length = values.len();
let values = cast(
values.as_ref(),
values_type,
CastOptionsImpl {
wrapped: true,
partial: false,
},
)?;
assert_eq!(values.len(), length); unsafe {
DictionaryArray::try_new_unchecked(from.data_type().clone(), keys.clone(), values.clone())
}
}
/// Re-encodes the keys of a [`DictionaryArray`] from `K1` to `K2`, sharing the same values.
///
/// Keys are converted with `primitive_to_primitive` (a `NumCast`-based conversion). The
/// overflow check below relies on that conversion turning keys that do not fit in `K2` into
/// nulls.
///
/// # Errors
/// Returns `ComputeError("overflow")` if the conversion introduced new nulls, i.e. some key
/// could not be represented as `K2`.
pub fn dictionary_to_dictionary_keys<K1, K2>(
    from: &DictionaryArray<K1>,
) -> PolarsResult<DictionaryArray<K2>>
where
    K1: DictionaryKey + num_traits::NumCast,
    K2: DictionaryKey + num_traits::NumCast,
{
    let old_keys = from.keys();
    let new_keys = primitive_to_primitive::<K1, K2>(old_keys, &K2::PRIMITIVE.into());

    // Any extra null produced by the conversion is a key that overflowed `K2`.
    let overflowed = new_keys.null_count() > old_keys.null_count();
    if overflowed {
        polars_bail!(ComputeError: "overflow")
    }

    let values = from.values();
    let dictionary_type = ArrowDataType::Dictionary(
        K2::KEY_TYPE,
        Box::new(values.data_type().clone()),
        from.is_ordered(),
    );
    // SAFETY: no key overflowed, so each converted key has the same value as before and still
    // indexes into `values`.
    unsafe { DictionaryArray::try_new_unchecked(dictionary_type, new_keys, values.clone()) }
}
/// Re-encodes the keys of a [`DictionaryArray`] from `K1` to `K2` using `as`-style casts,
/// sharing the same values.
///
/// Keys are converted with `primitive_as_primitive` (`num_traits::AsPrimitive`), so
/// out-of-range keys presumably wrap rather than become null. Because wrapped keys may no longer
/// point at valid values, the result is built with the checked `DictionaryArray::try_new`
/// instead of the unchecked constructor.
///
/// # Errors
/// - Returns `ComputeError("overflow")` if the conversion introduced new nulls.
/// - Returns an error if `DictionaryArray::try_new` rejects the converted keys or data type.
pub fn wrapping_dictionary_to_dictionary_keys<K1, K2>(
    from: &DictionaryArray<K1>,
) -> PolarsResult<DictionaryArray<K2>>
where
    K1: DictionaryKey + num_traits::AsPrimitive<K2>,
    K2: DictionaryKey,
{
    let old_keys = from.keys();
    let new_keys = primitive_as_primitive::<K1, K2>(old_keys, &K2::PRIMITIVE.into());

    // Defensive: an `as`-style conversion is not expected to add nulls, but reject it if it did.
    if new_keys.null_count() > old_keys.null_count() {
        polars_bail!(ComputeError: "overflow")
    }

    let values = from.values();
    let dictionary_type = ArrowDataType::Dictionary(
        K2::KEY_TYPE,
        Box::new(values.data_type().clone()),
        from.is_ordered(),
    );
    // Checked constructor on purpose: wrapped keys are not guaranteed to be in bounds.
    DictionaryArray::try_new(dictionary_type, new_keys, values.clone())
}
pub(super) fn dictionary_cast_dyn<K: DictionaryKey + num_traits::NumCast>(
array: &dyn Array,
to_type: &ArrowDataType,
options: CastOptionsImpl,
) -> PolarsResult<Box<dyn Array>> {
let array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
let keys = array.keys();
let values = array.values();
match to_type {
ArrowDataType::Dictionary(to_keys_type, to_values_type, _) => {
let values = cast(values.as_ref(), to_values_type, options)?;
let to_key_type = (*to_keys_type).into();
match_integer_type!(to_keys_type, |$T| {
key_cast!(keys, values, array, &to_key_type, $T, to_type.clone())
})
},
_ => unimplemented!(),
}
}