use std::cmp::Ordering;
use std::fmt::{Display, Formatter};
use std::ops::Add;
use bytemuck::{Pod, Zeroable};
use polars_error::*;
use polars_utils::min_max::MinMax;
use polars_utils::nulls::IsNull;
use polars_utils::slice::GetSaferUnchecked;
use polars_utils::total_ord::{TotalEq, TotalOrd};
use crate::buffer::Buffer;
use crate::datatypes::PrimitiveType;
use crate::types::NativeType;
/// Largest payload length, in bytes, that a [`View`] stores inline instead of in a data buffer.
pub const INLINE_VIEW_SIZE: u32 = 12;

/// A 16-byte descriptor of one variable-length binary value.
///
/// How the last three fields are used depends on `length`:
/// - `length <= INLINE_VIEW_SIZE`: the payload bytes are stored directly in the 12 bytes that
///   follow `length` (overlaying `prefix`, `buffer_idx` and `offset`). [`View::new_from_bytes`]
///   leaves the unused trailing bytes zeroed.
/// - otherwise: `prefix` holds the first four payload bytes and the payload itself lives at
///   `offset` inside the data buffer with index `buffer_idx`.
///
/// The type is `#[repr(C)]` because the methods below rely on the exact field order and on the
/// struct being 16 contiguous bytes.
#[derive(Debug, Copy, Clone, Default)]
#[repr(C)]
pub struct View {
    /// Length of the payload in bytes.
    pub length: u32,
    /// First four payload bytes (read as a little-endian `u32`) for non-inline views.
    pub prefix: u32,
    /// Index of the data buffer holding the payload of a non-inline view.
    pub buffer_idx: u32,
    /// Byte offset of the payload inside that data buffer for a non-inline view.
    pub offset: u32,
}

impl View {
    /// Reinterprets the 16 bytes of this view as a native-endian `u128`.
    #[inline(always)]
    pub fn as_u128(self) -> u128 {
        // SAFETY: `View` is `#[repr(C)]` with four `u32` fields, i.e. exactly 16 bytes without
        // padding, and every 16-byte pattern is a valid `u128`.
        unsafe { std::mem::transmute(self) }
    }

    /// Builds the view describing `bytes`.
    ///
    /// Payloads of at most [`INLINE_VIEW_SIZE`] bytes are copied into the view itself and
    /// `buffer_idx` / `offset` are ignored. Longer payloads only record their length, their
    /// first four bytes, and the given `buffer_idx` / `offset`; the caller is responsible for
    /// the bytes actually being stored at that location.
    ///
    /// `bytes.len()` must fit in a `u32`; this is checked in debug builds only.
    #[inline]
    pub fn new_from_bytes(bytes: &[u8], buffer_idx: u32, offset: u32) -> Self {
        // The `as u32` casts below would silently truncate a longer slice and produce a view
        // that no longer describes `bytes`.
        debug_assert!(
            bytes.len() <= u32::MAX as usize,
            "payload too long to be described by a View"
        );

        // Compare in `usize` so an oversized slice can never take the inline branch.
        if bytes.len() <= INLINE_VIEW_SIZE as usize {
            let mut ret = Self {
                length: bytes.len() as u32,
                ..Default::default()
            };
            let ret_ptr = &mut ret as *mut Self as *mut u8;
            // SAFETY: `ret` is 16 bytes (`#[repr(C)]`, four `u32`s). We write at most 12 bytes
            // starting at byte 4, which stays inside `ret`, and `bytes` cannot overlap a local.
            unsafe {
                core::ptr::copy_nonoverlapping(bytes.as_ptr(), ret_ptr.add(4), bytes.len());
            }
            ret
        } else {
            // More than 12 bytes are available here, so the first four always exist.
            let prefix_buf = [bytes[0], bytes[1], bytes[2], bytes[3]];
            Self {
                length: bytes.len() as u32,
                prefix: u32::from_le_bytes(prefix_buf),
                buffer_idx,
                offset,
            }
        }
    }

    /// Returns the payload bytes this view refers to, without validating the view.
    ///
    /// Inline views borrow from `self`; all other views borrow from `buffers`.
    ///
    /// # Safety
    /// If `self.length > INLINE_VIEW_SIZE`, then `self.buffer_idx` must be a valid index into
    /// `buffers` and `self.offset..self.offset + self.length` must lie within that buffer.
    /// This function does not check either condition itself.
    pub unsafe fn get_slice_unchecked<'a>(&'a self, buffers: &'a [Buffer<u8>]) -> &'a [u8] {
        unsafe {
            if self.length <= INLINE_VIEW_SIZE {
                // SAFETY: the inline payload occupies bytes `4..4 + length` of this 16-byte
                // struct and `length <= 12`, so the slice stays inside `self`.
                let ptr = self as *const View as *const u8;
                std::slice::from_raw_parts(ptr.add(4), self.length as usize)
            } else {
                // SAFETY: in-bounds access is guaranteed by the caller (see `# Safety`).
                let data = buffers.get_unchecked_release(self.buffer_idx as usize);
                let offset = self.offset as usize;
                data.get_unchecked_release(offset..offset + self.length as usize)
            }
        }
    }
}
/// [`IsNull`] for [`View`]: `HAS_NULLS` is `false` and `is_null` always returns `false`,
/// so the "inner" value is the view itself.
impl IsNull for View {
const HAS_NULLS: bool = false;
type Inner = Self;
fn is_null(&self) -> bool {
false
}
// Nothing is wrapped, so the view is handed back unchanged.
fn unwrap_inner(self) -> Self::Inner {
self
}
}
/// Displays a [`View`] using its derived `Debug` representation.
impl Display for View {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Route through `format_args!` so the `Debug` output is rendered with default
        // formatting options, independent of any flags set on `f`.
        f.write_fmt(format_args!("{:?}", self))
    }
}
// SAFETY: `View` is `#[repr(C)]` and consists solely of four `u32` fields, so the all-zero bit
// pattern is a valid value.
unsafe impl Zeroable for View {}
// SAFETY: `View` is `Copy`, `#[repr(C)]`, and made of four `u32` fields; it therefore has no
// padding bytes and every bit pattern is a valid value.
unsafe impl Pod for View {}
/// Placeholder `Add` implementation: adding two views always panics via `unimplemented!()`.
///
/// NOTE(review): presumably this exists only to satisfy a trait bound (e.g. a supertrait of
/// `num_traits::Zero`, implemented for `View` in this file) — confirm.
impl Add<Self> for View {
type Output = View;
// Panics unconditionally; `_rhs` is never inspected.
fn add(self, _rhs: Self) -> Self::Output {
unimplemented!()
}
}
/// The zero [`View`] is the default one: all four fields are `0`.
impl num_traits::Zero for View {
    fn zero() -> Self {
        Self::default()
    }

    fn is_zero(&self) -> bool {
        // Equality on `View` compares the 128-bit representation, and the zero view is all-zero
        // bits, so testing that representation against `0` is the same check.
        self.as_u128() == 0
    }
}
/// Two views are equal when their 16-byte representations are bit-for-bit identical.
impl PartialEq for View {
    fn eq(&self, other: &Self) -> bool {
        let lhs: u128 = (*self).into();
        let rhs: u128 = (*other).into();
        lhs == rhs
    }
}
/// Placeholder [`TotalOrd`] implementation: no ordering is defined for views here, and
/// calling `tot_cmp` always panics via `unimplemented!()`.
impl TotalOrd for View {
fn tot_cmp(&self, _other: &Self) -> Ordering {
unimplemented!()
}
}
/// Total equality for [`View`] is the same as its `PartialEq` implementation.
impl TotalEq for View {
    fn tot_eq(&self, other: &Self) -> bool {
        self == other
    }
}
/// Placeholder [`MinMax`] implementation: both comparison hooks always panic via
/// `unimplemented!()`.
impl MinMax for View {
fn nan_min_lt(&self, _other: &Self) -> bool {
unimplemented!()
}
fn nan_max_lt(&self, _other: &Self) -> bool {
unimplemented!()
}
}
/// Lets a [`View`] be handled as a native 128-bit value: it is declared as
/// `PrimitiveType::UInt128` and every byte conversion goes through its `u128` representation.
impl NativeType for View {
    const PRIMITIVE: PrimitiveType = PrimitiveType::UInt128;
    type Bytes = [u8; 16];

    /// Little-endian bytes of the view's `u128` representation.
    #[inline]
    fn to_le_bytes(&self) -> Self::Bytes {
        u128::from(*self).to_le_bytes()
    }

    /// Big-endian bytes of the view's `u128` representation.
    #[inline]
    fn to_be_bytes(&self) -> Self::Bytes {
        u128::from(*self).to_be_bytes()
    }

    /// Rebuilds a view from the little-endian bytes of its `u128` representation.
    #[inline]
    fn from_le_bytes(bytes: Self::Bytes) -> Self {
        u128::from_le_bytes(bytes).into()
    }

    /// Rebuilds a view from the big-endian bytes of its `u128` representation.
    #[inline]
    fn from_be_bytes(bytes: Self::Bytes) -> Self {
        u128::from_be_bytes(bytes).into()
    }
}
impl From<u128> for View {
#[inline]
fn from(value: u128) -> Self {
unsafe { std::mem::transmute(value) }
}
}
/// Converts a [`View`] into the `u128` returned by [`View::as_u128`].
impl From<View> for u128 {
#[inline]
fn from(value: View) -> Self {
value.as_u128()
}
}
/// Checks that every view in `views` is well-formed with respect to `buffers`.
///
/// For each view:
/// - inline views (`length <= INLINE_VIEW_SIZE`) must have all bytes after the payload zeroed;
/// - other views must reference an existing buffer, their `offset..offset + length` range must
///   lie inside that buffer, and the buffer data must start with the bytes stored in `prefix`.
///
/// In both cases the payload is additionally passed to `validate_bytes`, whose error (if any)
/// is returned as is.
///
/// # Errors
/// Returns `ComputeError` for non-zero padding or a prefix mismatch, `OutOfBounds` for an
/// invalid buffer index or byte range, and whatever `validate_bytes` reports.
fn validate_view<F>(views: &[View], buffers: &[Buffer<u8>], validate_bytes: F) -> PolarsResult<()>
where
    F: Fn(&[u8]) -> PolarsResult<()>,
{
    for view in views {
        let len = view.length;
        if len <= INLINE_VIEW_SIZE {
            // Shifting out the 32 length bits and the `len` payload bytes leaves only the unused
            // trailing bytes. `len == INLINE_VIEW_SIZE` is excluded because nothing is left to
            // check and a shift by 128 bits would overflow.
            if len < INLINE_VIEW_SIZE && view.as_u128() >> (32 + len * 8) != 0 {
                polars_bail!(ComputeError: "view contained non-zero padding in prefix");
            }
            validate_bytes(&view.to_le_bytes()[4..4 + len as usize])?;
        } else {
            let data = buffers.get(view.buffer_idx as usize).ok_or_else(|| {
                polars_err!(OutOfBounds: "view index out of bounds\n\nGot: {} buffers and index: {}", buffers.len(), view.buffer_idx)
            })?;
            let start = view.offset as usize;
            // `offset + length` can exceed `usize::MAX` on 32-bit targets; treat that as an
            // out-of-bounds range instead of overflowing while validating untrusted input.
            let b = start
                .checked_add(len as usize)
                .and_then(|end| data.as_slice().get(start..end))
                .ok_or_else(|| polars_err!(OutOfBounds: "buffer slice out of bounds"))?;
            polars_ensure!(b.starts_with(&view.prefix.to_le_bytes()), ComputeError: "prefix does not match string data");
            validate_bytes(b)?;
        }
    }
    Ok(())
}
/// Validates the structure of binary `views` against `buffers`.
///
/// Only the structural checks of `validate_view` apply; the payload bytes themselves are
/// accepted unconditionally.
pub(super) fn validate_binary_view(views: &[View], buffers: &[Buffer<u8>]) -> PolarsResult<()> {
    let accept_any_bytes = |_: &[u8]| -> PolarsResult<()> { Ok(()) };
    validate_view(views, buffers, accept_any_bytes)
}
/// Returns `Ok(())` if `b` is valid UTF-8, otherwise a `ComputeError` saying "invalid utf8".
fn validate_utf8(b: &[u8]) -> PolarsResult<()> {
    simdutf8::basic::from_utf8(b)
        // Only validity matters; the decoded `&str` is discarded.
        .map(|_| ())
        .map_err(|_| polars_err!(ComputeError: "invalid utf8"))
}
/// Validates the structure of `views` against `buffers` via `validate_view` and additionally
/// requires every payload to pass `validate_utf8`.
pub(super) fn validate_utf8_view(views: &[View], buffers: &[Buffer<u8>]) -> PolarsResult<()> {
validate_view(views, buffers, validate_utf8)
}
/// Checks that the payloads of `views` are valid UTF-8 without checking the views' structure.
///
/// - If `all_buffers` is empty, every view is read as an inline view.
/// - Otherwise inline views (`length <= 12`) are always checked. Views pointing into a buffer
///   are checked only if at least one buffer in `buffers_to_check` is not pure ASCII; when all
///   of them are ASCII, those views are skipped.
///
/// NOTE(review): `buffers_to_check` is presumably the subset of `all_buffers` that has not
/// been validated yet — confirm against the callers.
///
/// # Safety
/// No bounds checks are performed here. The caller must guarantee that
/// - every view has `length <= 12` when `all_buffers` is empty, and
/// - for every view with `length > 12`, `buffer_idx` is a valid index into `all_buffers` and
///   `offset..offset + length` lies within that buffer.
///
/// # Errors
/// Returns a `ComputeError` for the first payload that is not valid UTF-8.
pub(super) unsafe fn validate_utf8_only(
    views: &[View],
    buffers_to_check: &[Buffer<u8>],
    all_buffers: &[Buffer<u8>],
) -> PolarsResult<()> {
    /// Validates the payload stored inside `view` itself.
    ///
    /// # Safety
    /// `view.length` must be at most 12 so that `4..4 + length` stays within the 16 bytes.
    unsafe fn check_inline(view: &View) -> PolarsResult<()> {
        let raw = view.to_le_bytes();
        // SAFETY: guaranteed by the caller, see above.
        validate_utf8(unsafe { raw.get_unchecked_release(4..4 + view.length as usize) })
    }

    // Without any buffers there is nothing but inline views, so no branching is needed.
    if all_buffers.is_empty() {
        // SAFETY: the caller guarantees that all views are inline in this case.
        return views
            .iter()
            .try_for_each(|view| unsafe { check_inline(view) });
    }

    // Any sub-slice of an ASCII buffer is valid UTF-8, so buffer-backed views can be skipped
    // when every buffer to check is ASCII.
    let buffers_are_ascii = buffers_to_check.iter().all(|buf| buf.is_ascii());

    for view in views {
        if view.length <= 12 {
            // SAFETY: the length was just checked.
            unsafe { check_inline(view)? };
        } else if !buffers_are_ascii {
            let first = view.offset as usize;
            let last = first + view.length as usize;
            // SAFETY: the caller guarantees that `buffer_idx` and `first..last` are in bounds.
            let payload = unsafe {
                all_buffers
                    .get_unchecked_release(view.buffer_idx as usize)
                    .as_slice()
                    .get_unchecked_release(first..last)
            };
            validate_utf8(payload)?;
        }
    }
    Ok(())
}