use std::sync::Arc;
use arrow_schema::TimeUnit;
use datafusion_expr::Expr;
use regex::Regex;
use sqlparser::{
ast::{self, Function, Ident, ObjectName, TimezoneInfo},
keywords::ALL_KEYWORDS,
};
use datafusion_common::Result;
use super::{utils::date_part_to_sql, Unparser};
/// Customization points that decide how SQL is rendered for a particular
/// SQL flavor.
///
/// Only [`Dialect::identifier_quote_style`] is required; every other method
/// has a default that concrete dialects override as needed.
pub trait Dialect: Send + Sync {
    /// Returns the character used to quote `_identifier`, or `None` when the
    /// identifier is to be emitted without quotes.
    fn identifier_quote_style(&self, _identifier: &str) -> Option<char>;

    /// Flag reporting whether the dialect supports `NULLS FIRST` in sort
    /// expressions (meaning taken from the name; consumers live outside this
    /// file). Defaults to `true`.
    fn supports_nulls_first_in_sort(&self) -> bool {
        true
    }

    /// Flag reporting whether a timestamp type should be used for `Date64`
    /// values (meaning taken from the name). Defaults to `false`.
    fn use_timestamp_for_date64(&self) -> bool {
        false
    }

    /// Interval rendering style. Defaults to
    /// [`IntervalStyle::PostgresVerbose`].
    fn interval_style(&self) -> IntervalStyle {
        IntervalStyle::PostgresVerbose
    }

    /// SQL type associated with 64-bit floats. Defaults to `DOUBLE`.
    fn float64_ast_dtype(&self) -> ast::DataType {
        ast::DataType::Double
    }

    /// SQL type used for UTF-8 string casts. Defaults to `VARCHAR` without a
    /// length.
    fn utf8_cast_dtype(&self) -> ast::DataType {
        ast::DataType::Varchar(None)
    }

    /// SQL type used for large UTF-8 string casts. Defaults to `TEXT`.
    fn large_utf8_cast_dtype(&self) -> ast::DataType {
        ast::DataType::Text
    }

    /// Style used when extracting a date field. Defaults to
    /// [`DateFieldExtractStyle::DatePart`].
    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
        DateFieldExtractStyle::DatePart
    }

    /// SQL type used for 64-bit integer casts. Defaults to `BIGINT`.
    fn int64_cast_dtype(&self) -> ast::DataType {
        ast::DataType::BigInt(None)
    }

    /// SQL type used for 32-bit integer casts. Defaults to `INTEGER`.
    fn int32_cast_dtype(&self) -> ast::DataType {
        ast::DataType::Integer(None)
    }

    /// SQL type used for timestamp casts.
    ///
    /// The default ignores `_time_unit` and yields a `TIMESTAMP` without
    /// precision, marked `WITH TIME ZONE` only when `tz` is `Some`.
    fn timestamp_cast_dtype(
        &self,
        _time_unit: &TimeUnit,
        tz: &Option<Arc<str>>,
    ) -> ast::DataType {
        let tz_info = if tz.is_some() {
            TimezoneInfo::WithTimeZone
        } else {
            TimezoneInfo::None
        };
        ast::DataType::Timestamp(None, tz_info)
    }

    /// SQL type used for `Date32` casts. Defaults to `DATE`.
    fn date32_cast_dtype(&self) -> ast::DataType {
        ast::DataType::Date
    }

    /// Flag reporting whether column aliases may appear inside a table alias
    /// (meaning taken from the name). Defaults to `true`.
    fn supports_column_alias_in_table_alias(&self) -> bool {
        true
    }

    /// Flag reporting whether derived tables must carry an alias (meaning
    /// taken from the name). Defaults to `false`.
    fn requires_derived_table_alias(&self) -> bool {
        false
    }

    /// Hook that lets a dialect supply its own SQL for the scalar function
    /// `_func_name` called with `_args`.
    ///
    /// Returning `Ok(None)` (the default) means the dialect provides no
    /// override for this function.
    fn scalar_function_to_sql_overrides(
        &self,
        _unparser: &Unparser,
        _func_name: &str,
        _args: &[Expr],
    ) -> Result<Option<ast::Expr>> {
        Ok(None)
    }
}
/// Interval rendering styles a [`Dialect`] can select through
/// `Dialect::interval_style`.
///
/// Derives `Debug`, `PartialEq` and `Eq` in addition to `Clone`/`Copy` so
/// values can be compared and printed (for example in assertions and logs),
/// matching what `DateFieldExtractStyle` already allows for comparison.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum IntervalStyle {
    /// The trait-level default; also chosen explicitly by `PostgreSqlDialect`.
    PostgresVerbose,
    /// Chosen by `CustomDialect::default()`.
    SQLStandard,
    /// Chosen by `MySqlDialect`.
    MySQL,
}
/// Date-field extraction styles a [`Dialect`] can select through
/// `Dialect::date_field_extract_style`.
///
/// Derives `Debug` and `Eq` in addition to the existing
/// `Clone`/`Copy`/`PartialEq` so values can be printed in assertions and logs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DateFieldExtractStyle {
    /// The trait-level default.
    DatePart,
    /// Chosen by `MySqlDialect`.
    Extract,
    /// Chosen by `SqliteDialect`.
    Strftime,
}
pub struct DefaultDialect {}
impl Dialect for DefaultDialect {
fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
let identifier_regex = Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_]*$").unwrap();
let id_upper = identifier.to_uppercase();
if (id_upper != "ID" && ALL_KEYWORDS.contains(&id_upper.as_str()))
|| !identifier_regex.is_match(identifier)
{
Some('"')
} else {
None
}
}
}
/// Dialect for PostgreSQL: identifiers are always double-quoted, 64-bit
/// floats use `DOUBLE PRECISION`, and `round` gets a dedicated rendering.
pub struct PostgreSqlDialect {}

impl Dialect for PostgreSqlDialect {
    /// Every identifier is quoted with `"`, regardless of its content.
    fn identifier_quote_style(&self, _: &str) -> Option<char> {
        Some('"')
    }

    /// Always [`IntervalStyle::PostgresVerbose`].
    fn interval_style(&self) -> IntervalStyle {
        IntervalStyle::PostgresVerbose
    }

    /// 64-bit floats are spelled `DOUBLE PRECISION`.
    fn float64_ast_dtype(&self) -> ast::DataType {
        ast::DataType::DoublePrecision
    }

    /// Overrides only `round`, delegating to
    /// `round_to_sql_enforce_numeric`; any other function name yields
    /// `Ok(None)`.
    fn scalar_function_to_sql_overrides(
        &self,
        unparser: &Unparser,
        func_name: &str,
        args: &[Expr],
    ) -> Result<Option<ast::Expr>> {
        match func_name {
            "round" => self
                .round_to_sql_enforce_numeric(unparser, func_name, args)
                .map(Some),
            _ => Ok(None),
        }
    }
}
impl PostgreSqlDialect {
fn round_to_sql_enforce_numeric(
&self,
unparser: &Unparser,
func_name: &str,
args: &[Expr],
) -> Result<ast::Expr> {
let mut args = unparser.function_args_to_sql(args)?;
if let Some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr))) =
args.first_mut()
{
if let ast::Expr::Cast { data_type, .. } = expr {
*data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None);
} else {
*expr = ast::Expr::Cast {
kind: ast::CastKind::Cast,
expr: Box::new(expr.clone()),
data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None),
format: None,
};
}
}
Ok(ast::Expr::Function(Function {
name: ObjectName(vec![Ident {
value: func_name.to_string(),
quote_style: None,
}]),
args: ast::FunctionArguments::List(ast::FunctionArgumentList {
duplicate_treatment: None,
args,
clauses: vec![],
}),
filter: None,
null_treatment: None,
over: None,
within_group: vec![],
parameters: ast::FunctionArguments::None,
}))
}
}
/// Dialect for MySQL: backtick-quoted identifiers, `CHAR` string casts,
/// `SIGNED` integer casts and `DATETIME` timestamps.
pub struct MySqlDialect {}

impl Dialect for MySqlDialect {
    /// Every identifier is quoted with a backtick.
    fn identifier_quote_style(&self, _: &str) -> Option<char> {
        Some('`')
    }

    /// Reports `false`, unlike the trait default.
    fn supports_nulls_first_in_sort(&self) -> bool {
        false
    }

    /// Always [`IntervalStyle::MySQL`].
    fn interval_style(&self) -> IntervalStyle {
        IntervalStyle::MySQL
    }

    /// UTF-8 strings are cast to `CHAR` without a length.
    fn utf8_cast_dtype(&self) -> ast::DataType {
        ast::DataType::Char(None)
    }

    /// Large UTF-8 strings are cast to `CHAR` without a length as well.
    fn large_utf8_cast_dtype(&self) -> ast::DataType {
        ast::DataType::Char(None)
    }

    /// Always [`DateFieldExtractStyle::Extract`].
    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
        DateFieldExtractStyle::Extract
    }

    /// 64-bit integers are cast to the custom type name `SIGNED`.
    fn int64_cast_dtype(&self) -> ast::DataType {
        let signed = ObjectName(vec![Ident::new("SIGNED")]);
        ast::DataType::Custom(signed, vec![])
    }

    /// 32-bit integers are cast to the custom type name `SIGNED` too.
    fn int32_cast_dtype(&self) -> ast::DataType {
        let signed = ObjectName(vec![Ident::new("SIGNED")]);
        ast::DataType::Custom(signed, vec![])
    }

    /// Timestamps are cast to `DATETIME` without precision; both the time
    /// unit and the time zone are ignored.
    fn timestamp_cast_dtype(
        &self,
        _time_unit: &TimeUnit,
        _tz: &Option<Arc<str>>,
    ) -> ast::DataType {
        ast::DataType::Datetime(None)
    }

    /// Reports `true`, unlike the trait default.
    fn requires_derived_table_alias(&self) -> bool {
        true
    }

    /// Overrides only `date_part`, delegating to `date_part_to_sql` with this
    /// dialect's extract style; any other function name yields `Ok(None)`.
    fn scalar_function_to_sql_overrides(
        &self,
        unparser: &Unparser,
        func_name: &str,
        args: &[Expr],
    ) -> Result<Option<ast::Expr>> {
        match func_name {
            "date_part" => {
                date_part_to_sql(unparser, self.date_field_extract_style(), args)
            }
            _ => Ok(None),
        }
    }
}
/// Dialect for SQLite: backtick-quoted identifiers, `strftime`-style date
/// field extraction and `TEXT` for `Date32` casts.
pub struct SqliteDialect {}

impl Dialect for SqliteDialect {
    /// Every identifier is quoted with a backtick.
    fn identifier_quote_style(&self, _: &str) -> Option<char> {
        Some('`')
    }

    /// Always [`DateFieldExtractStyle::Strftime`].
    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
        DateFieldExtractStyle::Strftime
    }

    /// `Date32` values are cast to `TEXT`.
    fn date32_cast_dtype(&self) -> ast::DataType {
        ast::DataType::Text
    }

    /// Reports `false`, unlike the trait default.
    fn supports_column_alias_in_table_alias(&self) -> bool {
        false
    }

    /// Overrides only `date_part`, delegating to `date_part_to_sql` with this
    /// dialect's extract style; any other function name yields `Ok(None)`.
    fn scalar_function_to_sql_overrides(
        &self,
        unparser: &Unparser,
        func_name: &str,
        args: &[Expr],
    ) -> Result<Option<ast::Expr>> {
        match func_name {
            "date_part" => {
                date_part_to_sql(unparser, self.date_field_extract_style(), args)
            }
            _ => Ok(None),
        }
    }
}
/// Dialect whose answers are plain stored values rather than hard-coded
/// logic. Each field is returned by the [`Dialect`] method of the same name,
/// except the two timestamp fields, which are selected by the presence of a
/// time zone.
pub struct CustomDialect {
    /// Quote character returned for every identifier (`None` = unquoted).
    identifier_quote_style: Option<char>,
    /// Value returned by `supports_nulls_first_in_sort`.
    supports_nulls_first_in_sort: bool,
    /// Value returned by `use_timestamp_for_date64`.
    use_timestamp_for_date64: bool,
    /// Value returned by `interval_style`.
    interval_style: IntervalStyle,
    /// Value returned by `float64_ast_dtype`.
    float64_ast_dtype: ast::DataType,
    /// Value returned by `utf8_cast_dtype`.
    utf8_cast_dtype: ast::DataType,
    /// Value returned by `large_utf8_cast_dtype`.
    large_utf8_cast_dtype: ast::DataType,
    /// Value returned by `date_field_extract_style`; also passed to
    /// `date_part_to_sql` when overriding `date_part`.
    date_field_extract_style: DateFieldExtractStyle,
    /// Value returned by `int64_cast_dtype`.
    int64_cast_dtype: ast::DataType,
    /// Value returned by `int32_cast_dtype`.
    int32_cast_dtype: ast::DataType,
    /// Returned by `timestamp_cast_dtype` when no time zone is given.
    timestamp_cast_dtype: ast::DataType,
    /// Returned by `timestamp_cast_dtype` when a time zone is given.
    timestamp_tz_cast_dtype: ast::DataType,
    /// Value returned by `date32_cast_dtype`.
    date32_cast_dtype: ast::DataType,
    /// Value returned by `supports_column_alias_in_table_alias`.
    supports_column_alias_in_table_alias: bool,
    /// Value returned by `requires_derived_table_alias`.
    requires_derived_table_alias: bool,
}
impl Default for CustomDialect {
    /// Builds a `CustomDialect` with no identifier quoting and the settings
    /// listed below.
    ///
    /// NOTE(review): `interval_style` is `SQLStandard` here, whereas
    /// `CustomDialectBuilder::new` starts from `PostgresVerbose` — confirm
    /// whether the difference is intentional.
    fn default() -> Self {
        Self {
            // Identifier handling.
            identifier_quote_style: None,

            // Feature flags.
            supports_nulls_first_in_sort: true,
            use_timestamp_for_date64: false,
            supports_column_alias_in_table_alias: true,
            requires_derived_table_alias: false,

            // Rendering styles.
            interval_style: IntervalStyle::SQLStandard,
            date_field_extract_style: DateFieldExtractStyle::DatePart,

            // Numeric and string types.
            float64_ast_dtype: ast::DataType::Double,
            int64_cast_dtype: ast::DataType::BigInt(None),
            int32_cast_dtype: ast::DataType::Integer(None),
            utf8_cast_dtype: ast::DataType::Varchar(None),
            large_utf8_cast_dtype: ast::DataType::Text,

            // Temporal types.
            date32_cast_dtype: ast::DataType::Date,
            timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
            timestamp_tz_cast_dtype: ast::DataType::Timestamp(
                None,
                TimezoneInfo::WithTimeZone,
            ),
        }
    }
}
impl CustomDialect {
    /// Creates a dialect that uses `identifier_quote_style` for quoting and
    /// the `Default` value for every other setting.
    #[deprecated(note = "please use `CustomDialectBuilder` instead")]
    pub fn new(identifier_quote_style: Option<char>) -> Self {
        Self {
            identifier_quote_style,
            ..Self::default()
        }
    }
}
impl Dialect for CustomDialect {
    /// Returns the stored quote character for every identifier.
    fn identifier_quote_style(&self, _: &str) -> Option<char> {
        self.identifier_quote_style
    }

    /// Returns the stored flag.
    fn supports_nulls_first_in_sort(&self) -> bool {
        self.supports_nulls_first_in_sort
    }

    /// Returns the stored flag.
    fn use_timestamp_for_date64(&self) -> bool {
        self.use_timestamp_for_date64
    }

    /// Returns the stored flag.
    fn supports_column_alias_in_table_alias(&self) -> bool {
        self.supports_column_alias_in_table_alias
    }

    /// Returns the stored flag.
    fn requires_derived_table_alias(&self) -> bool {
        self.requires_derived_table_alias
    }

    /// Returns the stored interval style.
    fn interval_style(&self) -> IntervalStyle {
        self.interval_style
    }

    /// Returns the stored date-field extraction style.
    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
        self.date_field_extract_style
    }

    /// Returns a copy of the stored 64-bit float type.
    fn float64_ast_dtype(&self) -> ast::DataType {
        self.float64_ast_dtype.clone()
    }

    /// Returns a copy of the stored UTF-8 cast type.
    fn utf8_cast_dtype(&self) -> ast::DataType {
        self.utf8_cast_dtype.clone()
    }

    /// Returns a copy of the stored large UTF-8 cast type.
    fn large_utf8_cast_dtype(&self) -> ast::DataType {
        self.large_utf8_cast_dtype.clone()
    }

    /// Returns a copy of the stored 64-bit integer cast type.
    fn int64_cast_dtype(&self) -> ast::DataType {
        self.int64_cast_dtype.clone()
    }

    /// Returns a copy of the stored 32-bit integer cast type.
    fn int32_cast_dtype(&self) -> ast::DataType {
        self.int32_cast_dtype.clone()
    }

    /// Returns a copy of the time-zone-aware timestamp type when `tz` is
    /// `Some`, and of the plain timestamp type otherwise. The time unit is
    /// ignored.
    fn timestamp_cast_dtype(
        &self,
        _time_unit: &TimeUnit,
        tz: &Option<Arc<str>>,
    ) -> ast::DataType {
        match tz {
            Some(_) => self.timestamp_tz_cast_dtype.clone(),
            None => self.timestamp_cast_dtype.clone(),
        }
    }

    /// Returns a copy of the stored `Date32` cast type.
    fn date32_cast_dtype(&self) -> ast::DataType {
        self.date32_cast_dtype.clone()
    }

    /// Overrides only `date_part`, delegating to `date_part_to_sql` with the
    /// configured extract style; any other function name yields `Ok(None)`.
    fn scalar_function_to_sql_overrides(
        &self,
        unparser: &Unparser,
        func_name: &str,
        args: &[Expr],
    ) -> Result<Option<ast::Expr>> {
        match func_name {
            "date_part" => {
                date_part_to_sql(unparser, self.date_field_extract_style(), args)
            }
            _ => Ok(None),
        }
    }
}
/// Builder for [`CustomDialect`].
///
/// Holds one pending value per `CustomDialect` field; `build` copies each of
/// them into the field of the same name.
pub struct CustomDialectBuilder {
    /// Pending quote character (`None` = identifiers are not quoted).
    identifier_quote_style: Option<char>,
    /// Pending value for `CustomDialect::supports_nulls_first_in_sort`.
    supports_nulls_first_in_sort: bool,
    /// Pending value for `CustomDialect::use_timestamp_for_date64`.
    use_timestamp_for_date64: bool,
    /// Pending interval style.
    interval_style: IntervalStyle,
    /// Pending 64-bit float type.
    float64_ast_dtype: ast::DataType,
    /// Pending UTF-8 cast type.
    utf8_cast_dtype: ast::DataType,
    /// Pending large UTF-8 cast type.
    large_utf8_cast_dtype: ast::DataType,
    /// Pending date-field extraction style.
    date_field_extract_style: DateFieldExtractStyle,
    /// Pending 64-bit integer cast type.
    int64_cast_dtype: ast::DataType,
    /// Pending 32-bit integer cast type.
    int32_cast_dtype: ast::DataType,
    /// Pending timestamp cast type (set together with the time-zone variant).
    timestamp_cast_dtype: ast::DataType,
    /// Pending time-zone-aware timestamp cast type.
    timestamp_tz_cast_dtype: ast::DataType,
    /// Pending `Date32` cast type.
    date32_cast_dtype: ast::DataType,
    /// Pending value for `CustomDialect::supports_column_alias_in_table_alias`.
    supports_column_alias_in_table_alias: bool,
    /// Pending value for `CustomDialect::requires_derived_table_alias`.
    requires_derived_table_alias: bool,
}
impl Default for CustomDialectBuilder {
fn default() -> Self {
Self::new()
}
}
impl CustomDialectBuilder {
    /// Creates a builder with no identifier quoting, the `PostgresVerbose`
    /// interval style, the `DatePart` extraction style and the type choices
    /// listed in the body.
    pub fn new() -> Self {
        Self {
            identifier_quote_style: None,
            supports_nulls_first_in_sort: true,
            use_timestamp_for_date64: false,
            interval_style: IntervalStyle::PostgresVerbose,
            float64_ast_dtype: ast::DataType::Double,
            utf8_cast_dtype: ast::DataType::Varchar(None),
            large_utf8_cast_dtype: ast::DataType::Text,
            date_field_extract_style: DateFieldExtractStyle::DatePart,
            int64_cast_dtype: ast::DataType::BigInt(None),
            int32_cast_dtype: ast::DataType::Integer(None),
            timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
            timestamp_tz_cast_dtype: ast::DataType::Timestamp(
                None,
                TimezoneInfo::WithTimeZone,
            ),
            date32_cast_dtype: ast::DataType::Date,
            supports_column_alias_in_table_alias: true,
            requires_derived_table_alias: false,
        }
    }

    /// Consumes the builder and produces a [`CustomDialect`] carrying every
    /// configured value unchanged.
    pub fn build(self) -> CustomDialect {
        let Self {
            identifier_quote_style,
            supports_nulls_first_in_sort,
            use_timestamp_for_date64,
            interval_style,
            float64_ast_dtype,
            utf8_cast_dtype,
            large_utf8_cast_dtype,
            date_field_extract_style,
            int64_cast_dtype,
            int32_cast_dtype,
            timestamp_cast_dtype,
            timestamp_tz_cast_dtype,
            date32_cast_dtype,
            supports_column_alias_in_table_alias,
            requires_derived_table_alias,
        } = self;

        CustomDialect {
            identifier_quote_style,
            supports_nulls_first_in_sort,
            use_timestamp_for_date64,
            interval_style,
            float64_ast_dtype,
            utf8_cast_dtype,
            large_utf8_cast_dtype,
            date_field_extract_style,
            int64_cast_dtype,
            int32_cast_dtype,
            timestamp_cast_dtype,
            timestamp_tz_cast_dtype,
            date32_cast_dtype,
            supports_column_alias_in_table_alias,
            requires_derived_table_alias,
        }
    }

    /// Sets the character used to quote every identifier.
    pub fn with_identifier_quote_style(self, identifier_quote_style: char) -> Self {
        Self {
            identifier_quote_style: Some(identifier_quote_style),
            ..self
        }
    }

    /// Sets the value reported by `supports_nulls_first_in_sort`.
    pub fn with_supports_nulls_first_in_sort(
        self,
        supports_nulls_first_in_sort: bool,
    ) -> Self {
        Self {
            supports_nulls_first_in_sort,
            ..self
        }
    }

    /// Sets the value reported by `use_timestamp_for_date64`.
    pub fn with_use_timestamp_for_date64(self, use_timestamp_for_date64: bool) -> Self {
        Self {
            use_timestamp_for_date64,
            ..self
        }
    }

    /// Sets the interval style.
    pub fn with_interval_style(self, interval_style: IntervalStyle) -> Self {
        Self {
            interval_style,
            ..self
        }
    }

    /// Sets the 64-bit float type.
    pub fn with_float64_ast_dtype(self, float64_ast_dtype: ast::DataType) -> Self {
        Self {
            float64_ast_dtype,
            ..self
        }
    }

    /// Sets the UTF-8 cast type.
    pub fn with_utf8_cast_dtype(self, utf8_cast_dtype: ast::DataType) -> Self {
        Self {
            utf8_cast_dtype,
            ..self
        }
    }

    /// Sets the large UTF-8 cast type.
    pub fn with_large_utf8_cast_dtype(
        self,
        large_utf8_cast_dtype: ast::DataType,
    ) -> Self {
        Self {
            large_utf8_cast_dtype,
            ..self
        }
    }

    /// Sets the date-field extraction style.
    pub fn with_date_field_extract_style(
        self,
        date_field_extract_style: DateFieldExtractStyle,
    ) -> Self {
        Self {
            date_field_extract_style,
            ..self
        }
    }

    /// Sets the 64-bit integer cast type.
    pub fn with_int64_cast_dtype(self, int64_cast_dtype: ast::DataType) -> Self {
        Self {
            int64_cast_dtype,
            ..self
        }
    }

    /// Sets the 32-bit integer cast type.
    pub fn with_int32_cast_dtype(self, int32_cast_dtype: ast::DataType) -> Self {
        Self {
            int32_cast_dtype,
            ..self
        }
    }

    /// Sets both timestamp cast types at once: the first is used when no time
    /// zone is present, the second when one is.
    pub fn with_timestamp_cast_dtype(
        self,
        timestamp_cast_dtype: ast::DataType,
        timestamp_tz_cast_dtype: ast::DataType,
    ) -> Self {
        Self {
            timestamp_cast_dtype,
            timestamp_tz_cast_dtype,
            ..self
        }
    }

    /// Sets the `Date32` cast type.
    pub fn with_date32_cast_dtype(self, date32_cast_dtype: ast::DataType) -> Self {
        Self {
            date32_cast_dtype,
            ..self
        }
    }

    /// Sets the value reported by `supports_column_alias_in_table_alias`.
    pub fn with_supports_column_alias_in_table_alias(
        self,
        supports_column_alias_in_table_alias: bool,
    ) -> Self {
        Self {
            supports_column_alias_in_table_alias,
            ..self
        }
    }

    /// Sets the value reported by `requires_derived_table_alias`.
    pub fn with_requires_derived_table_alias(
        self,
        requires_derived_table_alias: bool,
    ) -> Self {
        Self {
            requires_derived_table_alias,
            ..self
        }
    }
}