use std::collections::HashMap;
use std::sync::Arc;
use std::vec;
use arrow_schema::*;
use datafusion_common::field_not_found;
use datafusion_common::internal_err;
use datafusion_expr::WindowUDF;
use sqlparser::ast::ExactNumberInfo;
use sqlparser::ast::TimezoneInfo;
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias};
use datafusion_common::config::ConfigOptions;
use datafusion_common::{
not_impl_err, plan_err, unqualified_field_not_found, DFSchema, DataFusionError,
Result,
};
use datafusion_common::{OwnedTableReference, TableReference};
use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder};
use datafusion_expr::utils::find_column_exprs;
use datafusion_expr::TableSource;
use datafusion_expr::{col, AggregateUDF, Expr, ScalarUDF};
use crate::utils::make_decimal_type;
pub trait ContextProvider {
fn get_table_provider(&self, name: TableReference) -> Result<Arc<dyn TableSource>>;
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>>;
fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>>;
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType>;
fn options(&self) -> &ConfigOptions;
}
#[derive(Debug)]
pub struct ParserOptions {
pub parse_float_as_decimal: bool,
pub enable_ident_normalization: bool,
}
impl Default for ParserOptions {
fn default() -> Self {
Self {
parse_float_as_decimal: false,
enable_ident_normalization: true,
}
}
}
#[derive(Debug)]
pub struct IdentNormalizer {
normalize: bool,
}
impl Default for IdentNormalizer {
fn default() -> Self {
Self { normalize: true }
}
}
impl IdentNormalizer {
pub fn new(normalize: bool) -> Self {
Self { normalize }
}
pub fn normalize(&self, ident: Ident) -> String {
if self.normalize {
crate::utils::normalize_ident(ident)
} else {
ident.value
}
}
}
#[derive(Debug, Clone)]
pub struct PlannerContext {
prepare_param_data_types: Vec<DataType>,
ctes: HashMap<String, Arc<LogicalPlan>>,
outer_query_schema: Option<DFSchema>,
}
impl Default for PlannerContext {
fn default() -> Self {
Self::new()
}
}
impl PlannerContext {
pub fn new() -> Self {
Self {
prepare_param_data_types: vec![],
ctes: HashMap::new(),
outer_query_schema: None,
}
}
pub fn with_prepare_param_data_types(
mut self,
prepare_param_data_types: Vec<DataType>,
) -> Self {
self.prepare_param_data_types = prepare_param_data_types;
self
}
pub fn outer_query_schema(&self) -> Option<&DFSchema> {
self.outer_query_schema.as_ref()
}
pub fn set_outer_query_schema(
&mut self,
mut schema: Option<DFSchema>,
) -> Option<DFSchema> {
std::mem::swap(&mut self.outer_query_schema, &mut schema);
schema
}
pub fn prepare_param_data_types(&self) -> &[DataType] {
&self.prepare_param_data_types
}
pub fn contains_cte(&self, cte_name: &str) -> bool {
self.ctes.contains_key(cte_name)
}
pub fn insert_cte(&mut self, cte_name: impl Into<String>, plan: LogicalPlan) {
let cte_name = cte_name.into();
self.ctes.insert(cte_name, Arc::new(plan));
}
pub fn get_cte(&self, cte_name: &str) -> Option<&LogicalPlan> {
self.ctes.get(cte_name).map(|cte| cte.as_ref())
}
}
pub struct SqlToRel<'a, S: ContextProvider> {
pub(crate) schema_provider: &'a S,
pub(crate) options: ParserOptions,
pub(crate) normalizer: IdentNormalizer,
}
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub fn new(schema_provider: &'a S) -> Self {
Self::new_with_options(schema_provider, ParserOptions::default())
}
pub fn new_with_options(schema_provider: &'a S, options: ParserOptions) -> Self {
let normalize = options.enable_ident_normalization;
SqlToRel {
schema_provider,
options,
normalizer: IdentNormalizer::new(normalize),
}
}
pub fn build_schema(&self, columns: Vec<SQLColumnDef>) -> Result<Schema> {
let mut fields = Vec::with_capacity(columns.len());
for column in columns {
let data_type = self.convert_simple_data_type(&column.data_type)?;
let not_nullable = column
.options
.iter()
.any(|x| x.option == ColumnOption::NotNull);
fields.push(Field::new(
self.normalizer.normalize(column.name),
data_type,
!not_nullable,
));
}
Ok(Schema::new(fields))
}
pub(crate) fn apply_table_alias(
&self,
plan: LogicalPlan,
alias: TableAlias,
) -> Result<LogicalPlan> {
let plan = self.apply_expr_alias(plan, alias.columns)?;
LogicalPlanBuilder::from(plan)
.alias(self.normalizer.normalize(alias.name))?
.build()
}
pub(crate) fn apply_expr_alias(
&self,
plan: LogicalPlan,
idents: Vec<Ident>,
) -> Result<LogicalPlan> {
if idents.is_empty() {
Ok(plan)
} else if idents.len() != plan.schema().fields().len() {
plan_err!(
"Source table contains {} columns but only {} names given as column alias",
plan.schema().fields().len(),
idents.len()
)
} else {
let fields = plan.schema().fields().clone();
LogicalPlanBuilder::from(plan)
.project(fields.iter().zip(idents.into_iter()).map(|(field, ident)| {
col(field.name()).alias(self.normalizer.normalize(ident))
}))?
.build()
}
}
pub(crate) fn validate_schema_satisfies_exprs(
&self,
schema: &DFSchema,
exprs: &[Expr],
) -> Result<()> {
find_column_exprs(exprs)
.iter()
.try_for_each(|col| match col {
Expr::Column(col) => match &col.relation {
Some(r) => {
schema.field_with_qualified_name(r, &col.name)?;
Ok(())
}
None => {
if !schema.fields_with_unqualified_name(&col.name).is_empty() {
Ok(())
} else {
Err(unqualified_field_not_found(col.name.as_str(), schema))
}
}
}
.map_err(|_: DataFusionError| {
field_not_found(col.relation.clone(), col.name.as_str(), schema)
}),
_ => internal_err!("Not a column"),
})
}
pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
match sql_type {
SQLDataType::Array(Some(inner_sql_type)) => {
let data_type = self.convert_simple_data_type(inner_sql_type)?;
Ok(DataType::List(Arc::new(Field::new(
"field", data_type, true,
))))
}
SQLDataType::Array(None) => {
not_impl_err!("Arrays with unspecified type is not supported")
}
other => self.convert_simple_data_type(other),
}
}
fn convert_simple_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
match sql_type {
SQLDataType::Boolean | SQLDataType::Bool => Ok(DataType::Boolean),
SQLDataType::TinyInt(_) => Ok(DataType::Int8),
SQLDataType::SmallInt(_) | SQLDataType::Int2(_) => Ok(DataType::Int16),
SQLDataType::Int(_) | SQLDataType::Integer(_) | SQLDataType::Int4(_) => Ok(DataType::Int32),
SQLDataType::BigInt(_) | SQLDataType::Int8(_) => Ok(DataType::Int64),
SQLDataType::UnsignedTinyInt(_) => Ok(DataType::UInt8),
SQLDataType::UnsignedSmallInt(_) | SQLDataType::UnsignedInt2(_) => Ok(DataType::UInt16),
SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) | SQLDataType::UnsignedInt4(_) => {
Ok(DataType::UInt32)
}
SQLDataType::UnsignedBigInt(_) | SQLDataType::UnsignedInt8(_) => Ok(DataType::UInt64),
SQLDataType::Float(_) => Ok(DataType::Float32),
SQLDataType::Real | SQLDataType::Float4 => Ok(DataType::Float32),
SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => Ok(DataType::Float64),
SQLDataType::Char(_)
| SQLDataType::Varchar(_)
| SQLDataType::Text
| SQLDataType::String => Ok(DataType::Utf8),
SQLDataType::Timestamp(None, tz_info) => {
let tz = if matches!(tz_info, TimezoneInfo::Tz)
|| matches!(tz_info, TimezoneInfo::WithTimeZone)
{
self.schema_provider.options().execution.time_zone.clone()
} else {
None
};
Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.map(Into::into)))
}
SQLDataType::Date => Ok(DataType::Date32),
SQLDataType::Time(None, tz_info) => {
if matches!(tz_info, TimezoneInfo::None)
|| matches!(tz_info, TimezoneInfo::WithoutTimeZone)
{
Ok(DataType::Time64(TimeUnit::Nanosecond))
} else {
not_impl_err!(
"Unsupported SQL type {sql_type:?}"
)
}
}
SQLDataType::Numeric(exact_number_info)
| SQLDataType::Decimal(exact_number_info) => {
let (precision, scale) = match *exact_number_info {
ExactNumberInfo::None => (None, None),
ExactNumberInfo::Precision(precision) => (Some(precision), None),
ExactNumberInfo::PrecisionAndScale(precision, scale) => {
(Some(precision), Some(scale))
}
};
make_decimal_type(precision, scale)
}
SQLDataType::Bytea => Ok(DataType::Binary),
SQLDataType::Interval => Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
SQLDataType::Nvarchar(_)
| SQLDataType::JSON
| SQLDataType::Uuid
| SQLDataType::Binary(_)
| SQLDataType::Varbinary(_)
| SQLDataType::Blob(_)
| SQLDataType::Datetime(_)
| SQLDataType::Regclass
| SQLDataType::Custom(_, _)
| SQLDataType::Array(_)
| SQLDataType::Enum(_)
| SQLDataType::Set(_)
| SQLDataType::MediumInt(_)
| SQLDataType::UnsignedMediumInt(_)
| SQLDataType::Character(_)
| SQLDataType::CharacterVarying(_)
| SQLDataType::CharVarying(_)
| SQLDataType::CharacterLargeObject(_)
| SQLDataType::CharLargeObject(_)
| SQLDataType::Timestamp(Some(_), _)
| SQLDataType::Time(Some(_), _)
| SQLDataType::Dec(_)
| SQLDataType::BigNumeric(_)
| SQLDataType::BigDecimal(_)
| SQLDataType::Clob(_) => not_impl_err!(
"Unsupported SQL type {sql_type:?}"
),
}
}
pub(crate) fn object_name_to_table_reference(
&self,
object_name: ObjectName,
) -> Result<OwnedTableReference> {
object_name_to_table_reference(
object_name,
self.options.enable_ident_normalization,
)
}
}
pub fn object_name_to_table_reference(
object_name: ObjectName,
enable_normalization: bool,
) -> Result<OwnedTableReference> {
let ObjectName(idents) = object_name;
idents_to_table_reference(idents, enable_normalization)
}
pub(crate) fn idents_to_table_reference(
idents: Vec<Ident>,
enable_normalization: bool,
) -> Result<OwnedTableReference> {
struct IdentTaker(Vec<Ident>);
impl IdentTaker {
fn take(&mut self, enable_normalization: bool) -> String {
let ident = self.0.pop().expect("no more identifiers");
IdentNormalizer::new(enable_normalization).normalize(ident)
}
}
let mut taker = IdentTaker(idents);
match taker.0.len() {
1 => {
let table = taker.take(enable_normalization);
Ok(OwnedTableReference::bare(table))
}
2 => {
let table = taker.take(enable_normalization);
let schema = taker.take(enable_normalization);
Ok(OwnedTableReference::partial(schema, table))
}
3 => {
let table = taker.take(enable_normalization);
let schema = taker.take(enable_normalization);
let catalog = taker.take(enable_normalization);
Ok(OwnedTableReference::full(catalog, schema, table))
}
_ => plan_err!("Unsupported compound identifier '{:?}'", taker.0),
}
}
pub fn object_name_to_qualifier(
sql_table_name: &ObjectName,
enable_normalization: bool,
) -> String {
let columns = vec!["table_name", "table_schema", "table_catalog"].into_iter();
let normalizer = IdentNormalizer::new(enable_normalization);
sql_table_name
.0
.iter()
.rev()
.zip(columns)
.map(|(ident, column_name)| {
format!(
r#"{} = '{}'"#,
column_name,
normalizer.normalize(ident.clone())
)
})
.collect::<Vec<_>>()
.join(" AND ")
}