use std::collections::HashMap;
use std::sync::Arc;
use std::vec;
use arrow_schema::*;
use datafusion_common::{
field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError,
};
use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo};
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias};
use sqlparser::ast::{TimezoneInfo, Value};
use datafusion_common::TableReference;
use datafusion_common::{
not_impl_err, plan_err, unqualified_field_not_found, DFSchema, DataFusionError,
Result,
};
use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder};
use datafusion_expr::utils::find_column_exprs;
use datafusion_expr::{col, Expr};
use crate::utils::{make_decimal_type, value_to_string};
pub use datafusion_expr::planner::ContextProvider;
/// Configuration flags that control how [`SqlToRel`] turns SQL into plans.
#[derive(Debug)]
pub struct ParserOptions {
    /// Presumably controls whether float literals are planned as decimals; it is not
    /// read in this module — TODO confirm against the expression planner.
    pub parse_float_as_decimal: bool,
    /// When `true`, identifiers are passed through the normalizing [`IdentNormalizer`].
    pub enable_ident_normalization: bool,
    /// When `false`, a `VARCHAR(n)` column type is rejected with a planning error.
    pub support_varchar_with_length: bool,
    /// When `true`, option values are ASCII-lowercased by [`ValueNormalizer`].
    pub enable_options_value_normalization: bool,
}

impl Default for ParserOptions {
    /// Every flag is enabled except `parse_float_as_decimal`.
    fn default() -> Self {
        Self {
            enable_ident_normalization: true,
            enable_options_value_normalization: true,
            support_varchar_with_length: true,
            parse_float_as_decimal: false,
        }
    }
}
/// Turns SQL identifiers into the plain strings used as names in plans.
#[derive(Debug)]
pub struct IdentNormalizer {
    /// `true`: run identifiers through `crate::utils::normalize_ident`;
    /// `false`: keep the identifier's raw `value` untouched.
    normalize: bool,
}

impl Default for IdentNormalizer {
    /// Normalization is enabled by default.
    fn default() -> Self {
        Self::new(true)
    }
}

impl IdentNormalizer {
    /// Creates a normalizer; `normalize` selects whether normalization is applied.
    pub fn new(normalize: bool) -> Self {
        Self { normalize }
    }

    /// Converts `ident` to a `String`, normalizing it only when enabled.
    pub fn normalize(&self, ident: Ident) -> String {
        if !self.normalize {
            // Normalization disabled: use the identifier text exactly as written.
            return ident.value;
        }
        crate::utils::normalize_ident(ident)
    }
}
/// Turns SQL option values into strings, optionally ASCII-lowercasing them.
#[derive(Debug)]
pub struct ValueNormalizer {
    /// `true`: lowercase ASCII letters of the rendered value; `false`: keep it as is.
    normalize: bool,
}

impl Default for ValueNormalizer {
    /// Normalization is enabled by default.
    fn default() -> Self {
        Self::new(true)
    }
}

impl ValueNormalizer {
    /// Creates a normalizer; `normalize` selects whether lowercasing is applied.
    pub fn new(normalize: bool) -> Self {
        Self { normalize }
    }

    /// Renders `value` via `value_to_string`, lowercasing ASCII when enabled.
    ///
    /// Returns `None` whenever `value_to_string` returns `None` (presumably for
    /// value kinds it cannot render — see that helper).
    pub fn normalize(&self, value: Value) -> Option<String> {
        let rendered = value_to_string(&value)?;
        Some(if self.normalize {
            rendered.to_ascii_lowercase()
        } else {
            rendered
        })
    }
}
/// Mutable state carried while planning a single statement.
///
/// Holds the data types declared for prepared-statement parameters, the CTEs
/// (common table expressions) currently registered by name, and an optional outer
/// query schema.
#[derive(Debug, Clone)]
pub struct PlannerContext {
    /// Parameter data types; shared behind an `Arc` so cloning the context is cheap.
    prepare_param_data_types: Arc<Vec<DataType>>,
    /// Registered CTE plans, keyed by CTE name.
    ctes: HashMap<String, Arc<LogicalPlan>>,
    /// Schema of the enclosing query, if any (presumably used to resolve outer
    /// references from subqueries — confirm at call sites).
    outer_query_schema: Option<DFSchemaRef>,
}

impl Default for PlannerContext {
    fn default() -> Self {
        Self::new()
    }
}

impl PlannerContext {
    /// Creates an empty context: no parameter types, no CTEs, no outer schema.
    pub fn new() -> Self {
        Self {
            prepare_param_data_types: Arc::default(),
            ctes: HashMap::default(),
            outer_query_schema: None,
        }
    }

    /// Returns this context with its prepared-statement parameter types replaced.
    pub fn with_prepare_param_data_types(
        self,
        prepare_param_data_types: Vec<DataType>,
    ) -> Self {
        Self {
            prepare_param_data_types: Arc::new(prepare_param_data_types),
            ..self
        }
    }

    /// Borrows the outer query schema, if one is set.
    pub fn outer_query_schema(&self) -> Option<&DFSchema> {
        self.outer_query_schema.as_deref()
    }

    /// Installs `schema` as the outer query schema and hands back the previous one,
    /// so callers can restore it afterwards.
    pub fn set_outer_query_schema(
        &mut self,
        schema: Option<DFSchemaRef>,
    ) -> Option<DFSchemaRef> {
        std::mem::replace(&mut self.outer_query_schema, schema)
    }

    /// The prepared-statement parameter types, in declaration order.
    pub fn prepare_param_data_types(&self) -> &[DataType] {
        self.prepare_param_data_types.as_slice()
    }

    /// Whether a CTE named `cte_name` is registered.
    pub fn contains_cte(&self, cte_name: &str) -> bool {
        self.ctes.contains_key(cte_name)
    }

    /// Registers `plan` under `cte_name`, replacing any CTE with the same name.
    pub fn insert_cte(&mut self, cte_name: impl Into<String>, plan: LogicalPlan) {
        self.ctes.insert(cte_name.into(), Arc::new(plan));
    }

    /// Looks up the plan of the CTE named `cte_name`.
    pub fn get_cte(&self, cte_name: &str) -> Option<&LogicalPlan> {
        self.ctes.get(cte_name).map(|plan| &**plan)
    }

    /// Unregisters the CTE named `cte_name`; a missing name is ignored.
    pub(super) fn remove_cte(&mut self, cte_name: &str) {
        self.ctes.remove(cte_name);
    }
}
/// Converts `sqlparser` AST nodes into DataFusion logical plans and expressions.
///
/// Session information is obtained from the borrowed context provider `S`.
pub struct SqlToRel<'a, S: ContextProvider> {
    /// Provider consulted during planning (e.g. for the session time zone).
    pub(crate) context_provider: &'a S,
    /// Planner configuration flags.
    pub(crate) options: ParserOptions,
    /// Identifier normalizer, configured from `options.enable_ident_normalization`.
    pub(crate) ident_normalizer: IdentNormalizer,
    /// Option-value normalizer, configured from
    /// `options.enable_options_value_normalization`.
    pub(crate) value_normalizer: ValueNormalizer,
}
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
    /// Creates a planner backed by `context_provider`, using
    /// [`ParserOptions::default`].
    pub fn new(context_provider: &'a S) -> Self {
        Self::new_with_options(context_provider, ParserOptions::default())
    }

    /// Creates a planner backed by `context_provider` with explicit `options`.
    ///
    /// The identifier normalizer follows `options.enable_ident_normalization` and the
    /// option-value normalizer follows `options.enable_options_value_normalization`.
    pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self {
        let ident_normalize = options.enable_ident_normalization;
        let options_value_normalize = options.enable_options_value_normalization;

        SqlToRel {
            context_provider,
            options,
            ident_normalizer: IdentNormalizer::new(ident_normalize),
            value_normalizer: ValueNormalizer::new(options_value_normalize),
        }
    }

    /// Builds an Arrow [`Schema`] from SQL column definitions.
    ///
    /// Column names go through the identifier normalizer, SQL types are converted
    /// with [`Self::convert_data_type`], and a column is nullable unless one of its
    /// options is `NOT NULL`.
    ///
    /// # Errors
    /// Returns an error if any column's SQL type cannot be converted.
    pub fn build_schema(&self, columns: Vec<SQLColumnDef>) -> Result<Schema> {
        let mut fields = Vec::with_capacity(columns.len());

        for column in columns {
            let data_type = self.convert_data_type(&column.data_type)?;
            let not_nullable = column
                .options
                .iter()
                .any(|x| x.option == ColumnOption::NotNull);
            fields.push(Field::new(
                self.ident_normalizer.normalize(column.name),
                data_type,
                !not_nullable,
            ));
        }

        Ok(Schema::new(fields))
    }

    /// Plans the `DEFAULT` expression of each column in `columns` that declares one.
    ///
    /// Defaults are planned against an empty schema, so a column reference inside a
    /// `DEFAULT` expression cannot resolve. The resulting `FieldNotFound` error is
    /// rewrapped as a planning error saying column references are not allowed there.
    /// Only the first `DEFAULT` option of a column is considered.
    ///
    /// Returns `(normalized column name, planned default)` pairs in column order.
    pub(super) fn build_column_defaults(
        &self,
        // A slice instead of `&Vec`: existing `&Vec<SQLColumnDef>` arguments still
        // coerce, and other contiguous storage is accepted too.
        columns: &[SQLColumnDef],
        planner_context: &mut PlannerContext,
    ) -> Result<Vec<(String, Expr)>> {
        let mut column_defaults = vec![];
        // Empty schema: any column reference in a DEFAULT expression fails to resolve.
        let empty_schema = DFSchema::empty();
        let error_desc = |e: DataFusionError| match e {
            DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }, _) => {
                plan_datafusion_err!(
                    "Column reference is not allowed in the DEFAULT expression : {}",
                    e
                )
            }
            _ => e,
        };

        for column in columns {
            if let Some(default_sql_expr) =
                column.options.iter().find_map(|o| match &o.option {
                    ColumnOption::Default(expr) => Some(expr),
                    _ => None,
                })
            {
                let default_expr = self
                    .sql_to_expr(default_sql_expr.clone(), &empty_schema, planner_context)
                    .map_err(error_desc)?;
                column_defaults.push((
                    self.ident_normalizer.normalize(column.name.clone()),
                    default_expr,
                ));
            }
        }

        Ok(column_defaults)
    }

    /// Applies a table alias (`AS name(col, ...)`) to `plan`.
    ///
    /// The output columns are first renamed with `alias.columns` (a no-op when the
    /// list is empty). The result is then qualified with the normalized alias name.
    pub(crate) fn apply_table_alias(
        &self,
        plan: LogicalPlan,
        alias: TableAlias,
    ) -> Result<LogicalPlan> {
        let plan = self.apply_expr_alias(plan, alias.columns)?;

        LogicalPlanBuilder::from(plan)
            .alias(TableReference::bare(
                self.ident_normalizer.normalize(alias.name),
            ))?
            .build()
    }

    /// Renames the output columns of `plan` positionally to `idents`.
    ///
    /// An empty `idents` returns `plan` unchanged. Otherwise the number of names must
    /// equal the number of output columns, and a projection aliasing each column to
    /// its normalized name is added.
    ///
    /// # Errors
    /// Returns a planning error when the name count does not match the column count.
    pub(crate) fn apply_expr_alias(
        &self,
        plan: LogicalPlan,
        idents: Vec<Ident>,
    ) -> Result<LogicalPlan> {
        if idents.is_empty() {
            Ok(plan)
        } else if idents.len() != plan.schema().fields().len() {
            plan_err!(
                "Source table contains {} columns but only {} names given as column alias",
                plan.schema().fields().len(),
                idents.len()
            )
        } else {
            // Reference every input column by its exact (qualifier, field) pair.
            // Going through `col(field.name())` would re-parse the name as a
            // possibly-qualified identifier, so a field named "a.b" would be looked
            // up as column `b` of relation `a`. It would also drop the qualifier,
            // making same-named columns from different relations ambiguous.
            let exprs: Vec<Expr> = plan
                .schema()
                .iter()
                .zip(idents)
                .map(|((qualifier, field), ident)| {
                    col((qualifier, field)).alias(self.ident_normalizer.normalize(ident))
                })
                .collect();

            LogicalPlanBuilder::from(plan).project(exprs)?.build()
        }
    }

    /// Checks that every column referenced in `exprs` can be found in `schema`.
    ///
    /// A qualified column must match a qualified field. An unqualified column must
    /// match at least one field by name. Any lookup failure is reported as a
    /// `field_not_found` error for that column.
    pub(crate) fn validate_schema_satisfies_exprs(
        &self,
        schema: &DFSchema,
        exprs: &[Expr],
    ) -> Result<()> {
        find_column_exprs(exprs)
            .iter()
            .try_for_each(|col| match col {
                Expr::Column(col) => match &col.relation {
                    Some(r) => {
                        schema.field_with_qualified_name(r, &col.name)?;
                        Ok(())
                    }
                    None => {
                        if !schema.fields_with_unqualified_name(&col.name).is_empty() {
                            Ok(())
                        } else {
                            Err(unqualified_field_not_found(col.name.as_str(), schema))
                        }
                    }
                }
                .map_err(|_: DataFusionError| {
                    field_not_found(col.relation.clone(), col.name.as_str(), schema)
                }),
                _ => internal_err!("Not a column"),
            })
    }

    /// Converts a SQL data type into an Arrow [`DataType`].
    ///
    /// Array types are converted recursively into lists whose items are nullable.
    /// Arrays without an element type are rejected. All other types are handled by
    /// `convert_simple_data_type`.
    pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
        match sql_type {
            SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type))
            | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type, _)) => {
                let inner_data_type = self.convert_data_type(inner_sql_type)?;
                Ok(DataType::new_list(inner_data_type, true))
            }
            SQLDataType::Array(ArrayElemTypeDef::None) => {
                not_impl_err!("Arrays with unspecified type is not supported")
            }
            other => self.convert_simple_data_type(other),
        }
    }

    /// Converts a non-array SQL data type into an Arrow [`DataType`].
    ///
    /// `TIMESTAMP WITH TIME ZONE` (or `TIMESTAMPTZ`) picks up the session time zone
    /// from the context provider's execution options. Struct fields without a name
    /// are named `c{index}`. Any type not listed explicitly is reported as
    /// unsupported.
    fn convert_simple_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
        match sql_type {
            SQLDataType::Boolean | SQLDataType::Bool => Ok(DataType::Boolean),
            SQLDataType::TinyInt(_) => Ok(DataType::Int8),
            SQLDataType::SmallInt(_) | SQLDataType::Int2(_) => Ok(DataType::Int16),
            SQLDataType::Int(_) | SQLDataType::Integer(_) | SQLDataType::Int4(_) => {
                Ok(DataType::Int32)
            }
            SQLDataType::BigInt(_) | SQLDataType::Int8(_) => Ok(DataType::Int64),
            SQLDataType::UnsignedTinyInt(_) => Ok(DataType::UInt8),
            SQLDataType::UnsignedSmallInt(_) | SQLDataType::UnsignedInt2(_) => {
                Ok(DataType::UInt16)
            }
            SQLDataType::UnsignedInt(_)
            | SQLDataType::UnsignedInteger(_)
            | SQLDataType::UnsignedInt4(_) => Ok(DataType::UInt32),
            SQLDataType::Varchar(length) => {
                match (length, self.options.support_varchar_with_length) {
                    (Some(_), false) => plan_err!("does not support Varchar with length, please set `support_varchar_with_length` to be true"),
                    _ => Ok(DataType::Utf8),
                }
            }
            SQLDataType::UnsignedBigInt(_) | SQLDataType::UnsignedInt8(_) => {
                Ok(DataType::UInt64)
            }
            SQLDataType::Float(_) => Ok(DataType::Float32),
            SQLDataType::Real | SQLDataType::Float4 => Ok(DataType::Float32),
            SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => {
                Ok(DataType::Float64)
            }
            SQLDataType::Char(_) | SQLDataType::Text | SQLDataType::String(_) => {
                Ok(DataType::Utf8)
            }
            SQLDataType::Timestamp(None, tz_info) => {
                // Only zone-aware timestamps take the session time zone.
                let tz = if matches!(tz_info, TimezoneInfo::Tz | TimezoneInfo::WithTimeZone)
                {
                    self.context_provider.options().execution.time_zone.clone()
                } else {
                    None
                };
                Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.map(Into::into)))
            }
            SQLDataType::Date => Ok(DataType::Date32),
            SQLDataType::Time(None, tz_info) => {
                if matches!(tz_info, TimezoneInfo::None | TimezoneInfo::WithoutTimeZone) {
                    Ok(DataType::Time64(TimeUnit::Nanosecond))
                } else {
                    not_impl_err!("Unsupported SQL type {sql_type:?}")
                }
            }
            SQLDataType::Numeric(exact_number_info)
            | SQLDataType::Decimal(exact_number_info) => {
                let (precision, scale) = match *exact_number_info {
                    ExactNumberInfo::None => (None, None),
                    ExactNumberInfo::Precision(precision) => (Some(precision), None),
                    ExactNumberInfo::PrecisionAndScale(precision, scale) => {
                        (Some(precision), Some(scale))
                    }
                };
                make_decimal_type(precision, scale)
            }
            SQLDataType::Bytea => Ok(DataType::Binary),
            SQLDataType::Interval => Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
            SQLDataType::Struct(fields) => {
                let fields = fields
                    .iter()
                    .enumerate()
                    .map(|(idx, field)| {
                        let data_type = self.convert_data_type(&field.field_type)?;
                        // Anonymous struct fields are named positionally: c0, c1, ...
                        let field_name = field
                            .field_name
                            .clone()
                            .unwrap_or_else(|| Ident::new(format!("c{idx}")));
                        Ok(Arc::new(Field::new(
                            self.ident_normalizer.normalize(field_name),
                            data_type,
                            true,
                        )))
                    })
                    .collect::<Result<Vec<_>>>()?;
                Ok(DataType::Struct(Fields::from(fields)))
            }
            SQLDataType::Nvarchar(_)
            | SQLDataType::JSON
            | SQLDataType::Uuid
            | SQLDataType::Binary(_)
            | SQLDataType::Varbinary(_)
            | SQLDataType::Blob(_)
            | SQLDataType::Datetime(_)
            | SQLDataType::Regclass
            | SQLDataType::Custom(_, _)
            | SQLDataType::Array(_)
            | SQLDataType::Enum(_)
            | SQLDataType::Set(_)
            | SQLDataType::MediumInt(_)
            | SQLDataType::UnsignedMediumInt(_)
            | SQLDataType::Character(_)
            | SQLDataType::CharacterVarying(_)
            | SQLDataType::CharVarying(_)
            | SQLDataType::CharacterLargeObject(_)
            | SQLDataType::CharLargeObject(_)
            | SQLDataType::Timestamp(Some(_), _)
            | SQLDataType::Time(Some(_), _)
            | SQLDataType::Dec(_)
            | SQLDataType::BigNumeric(_)
            | SQLDataType::BigDecimal(_)
            | SQLDataType::Clob(_)
            | SQLDataType::Bytes(_)
            | SQLDataType::Int64
            | SQLDataType::Float64
            | SQLDataType::JSONB
            | SQLDataType::Unspecified
            | SQLDataType::Int16
            | SQLDataType::Int32
            | SQLDataType::Int128
            | SQLDataType::Int256
            | SQLDataType::UInt8
            | SQLDataType::UInt16
            | SQLDataType::UInt32
            | SQLDataType::UInt64
            | SQLDataType::UInt128
            | SQLDataType::UInt256
            | SQLDataType::Float32
            | SQLDataType::Date32
            | SQLDataType::Datetime64(_, _)
            | SQLDataType::FixedString(_)
            | SQLDataType::Map(_, _)
            | SQLDataType::Tuple(_)
            | SQLDataType::Nested(_)
            | SQLDataType::Union(_)
            | SQLDataType::Nullable(_)
            | SQLDataType::LowCardinality(_) => {
                not_impl_err!("Unsupported SQL type {sql_type:?}")
            }
        }
    }

    /// Converts `object_name` into a [`TableReference`], normalizing its parts
    /// according to `options.enable_ident_normalization`.
    pub(crate) fn object_name_to_table_reference(
        &self,
        object_name: ObjectName,
    ) -> Result<TableReference> {
        object_name_to_table_reference(
            object_name,
            self.options.enable_ident_normalization,
        )
    }
}
/// Converts a SQL object name (`table`, `schema.table` or `catalog.schema.table`)
/// into a [`TableReference`].
///
/// Each part is normalized when `enable_normalization` is `true`. Names with any
/// other number of parts produce a planning error.
pub fn object_name_to_table_reference(
    object_name: ObjectName,
    enable_normalization: bool,
) -> Result<TableReference> {
    idents_to_table_reference(object_name.0, enable_normalization)
}
/// Builds a [`TableReference`] from one to three identifier parts.
///
/// The parts are read in written order: `[table]`, `[schema, table]` or
/// `[catalog, schema, table]`. Each part goes through an [`IdentNormalizer`]
/// configured with `enable_normalization`.
///
/// # Errors
/// Returns a planning error for zero parts or more than three parts.
pub(crate) fn idents_to_table_reference(
    idents: Vec<Ident>,
    enable_normalization: bool,
) -> Result<TableReference> {
    let part_count = idents.len();
    if !(1..=3).contains(&part_count) {
        return plan_err!("Unsupported compound identifier '{:?}'", idents);
    }

    let normalizer = IdentNormalizer::new(enable_normalization);
    let mut parts = idents.into_iter().map(|ident| normalizer.normalize(ident));
    // The length check above guarantees enough parts for every arm below.
    let mut next_part = || parts.next().expect("no more identifiers");

    let reference = match part_count {
        1 => TableReference::bare(next_part()),
        2 => {
            let schema = next_part();
            let table = next_part();
            TableReference::partial(schema, table)
        }
        _ => {
            let catalog = next_part();
            let schema = next_part();
            let table = next_part();
            TableReference::full(catalog, schema, table)
        }
    };
    Ok(reference)
}
/// Builds a SQL predicate that matches `sql_table_name` against
/// `information_schema`-style naming columns.
///
/// Name parts are paired from the right:
/// - the last part is compared with `table_name`;
/// - the one before it with `table_schema`;
/// - the one before that with `table_catalog`.
///
/// Any further leading parts are ignored. Each part is normalized according to
/// `enable_normalization` and emitted as a single-quoted SQL string literal.
/// Embedded single quotes are doubled (`'` becomes `''`), so an identifier such as
/// `"o'brien"` can neither end the literal early nor inject SQL into the predicate.
///
/// The comparisons are joined with ` AND `. An empty name yields an empty string.
pub fn object_name_to_qualifier(
    sql_table_name: &ObjectName,
    enable_normalization: bool,
) -> String {
    let columns = ["table_name", "table_schema", "table_catalog"];
    let normalizer = IdentNormalizer::new(enable_normalization);
    sql_table_name
        .0
        .iter()
        .rev()
        .zip(columns)
        .map(|(ident, column_name)| {
            // Escape per standard SQL string-literal rules before quoting.
            let value = normalizer.normalize(ident.clone()).replace('\'', "''");
            format!("{column_name} = '{value}'")
        })
        .collect::<Vec<_>>()
        .join(" AND ")
}