1use datafusion::sql::sqlparser::{
7 ast::{Expr, SelectItem, SetExpr, Statement},
8 dialect::{Dialect, GenericDialect},
9 parser::Parser,
10 tokenizer::{Token, Tokenizer},
11};
12
13use lance_core::{Error, Result};
14use snafu::location;
15#[derive(Debug, Default)]
16struct LanceDialect(GenericDialect);
17
18impl LanceDialect {
19 fn new() -> Self {
20 Self(GenericDialect {})
21 }
22}
23
24impl Dialect for LanceDialect {
25 fn is_identifier_start(&self, ch: char) -> bool {
26 self.0.is_identifier_start(ch)
27 }
28
29 fn is_identifier_part(&self, ch: char) -> bool {
30 self.0.is_identifier_part(ch)
31 }
32
33 fn is_delimited_identifier_start(&self, ch: char) -> bool {
34 ch == '`'
35 }
36}
37
38pub(crate) fn parse_sql_filter(filter: &str) -> Result<Expr> {
40 let sql = format!("SELECT 1 FROM t WHERE {filter}");
41 let statement = parse_statement(&sql)?;
42
43 let selection = if let Statement::Query(query) = &statement {
44 if let SetExpr::Select(s) = query.body.as_ref() {
45 s.selection.as_ref()
46 } else {
47 None
48 }
49 } else {
50 None
51 };
52 let expr = selection
53 .ok_or_else(|| Error::io(format!("Filter is not valid: {filter}"), location!()))?;
54 Ok(expr.clone())
55}
56
57pub(crate) fn parse_sql_expr(expr: &str) -> Result<Expr> {
60 let sql = format!("SELECT {expr} FROM t");
61 let statement = parse_statement(&sql)?;
62
63 let selection = if let Statement::Query(query) = &statement {
64 if let SetExpr::Select(s) = query.body.as_ref() {
65 if let SelectItem::UnnamedExpr(expr) = &s.projection[0] {
66 Some(expr)
67 } else {
68 None
69 }
70 } else {
71 None
72 }
73 } else {
74 None
75 };
76 let expr = selection
77 .ok_or_else(|| Error::io(format!("Expression is not valid: {expr}"), location!()))?;
78 Ok(expr.clone())
79}
80
81fn parse_statement(statement: &str) -> Result<Statement> {
82 let dialect = LanceDialect::new();
83
84 let mut tokenizer = Tokenizer::new(&dialect, statement);
88 let mut tokens = Vec::new();
89 let mut token_iter = tokenizer.tokenize()?.into_iter();
90 let mut prev_token = token_iter.next().unwrap();
91 for next_token in token_iter {
92 if let (Token::Eq, Token::Eq) = (&prev_token, &next_token) {
93 continue; }
95 let token = std::mem::replace(&mut prev_token, next_token);
96 tokens.push(token);
97 }
98 tokens.push(prev_token);
99
100 Ok(Parser::new(&dialect)
101 .with_tokens(tokens)
102 .parse_statement()?)
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::sql::sqlparser::ast::{BinaryOperator, Ident, Value};

    /// Unquoted identifier expression, boxed for use as an operand.
    fn column(name: &str) -> Box<Expr> {
        Box::new(Expr::Identifier(Ident::new(name)))
    }

    /// Identifier quoted with a backtick, the delimiter `LanceDialect` accepts.
    fn backticked(name: &str) -> Ident {
        Ident::with_quote('`', name)
    }

    #[test]
    fn test_double_equal() {
        // `==` must parse as the standard SQL `=` operator.
        let parsed = parse_sql_filter("a == b").unwrap();
        let expected = Expr::BinaryOp {
            left: column("a"),
            op: BinaryOperator::Eq,
            right: column("b"),
        };
        assert_eq!(expected, parsed);
    }

    #[test]
    fn test_like() {
        let parsed = parse_sql_filter("a LIKE 'abc%'").unwrap();
        let pattern = Value::SingleQuotedString("abc%".to_string());
        let expected = Expr::Like {
            negated: false,
            expr: column("a"),
            pattern: Box::new(Expr::Value(pattern)),
            escape_char: None,
            any: false,
        };
        assert_eq!(expected, parsed);
    }

    #[test]
    fn test_quoted_ident() {
        // Backticks admit identifier text such as `:` and the word `CUBE`.
        let parsed = parse_sql_filter("`a:Test_Something` == `CUBE`").unwrap();
        let expected = Expr::BinaryOp {
            left: Box::new(Expr::Identifier(backticked("a:Test_Something"))),
            op: BinaryOperator::Eq,
            right: Box::new(Expr::Identifier(backticked("CUBE"))),
        };
        assert_eq!(expected, parsed);

        // Each segment of a nested field path can be quoted independently.
        let parsed = parse_sql_filter("`outer field`.`inner field` == 1").unwrap();
        let expected = Expr::BinaryOp {
            left: Box::new(Expr::CompoundIdentifier(vec![
                backticked("outer field"),
                backticked("inner field"),
            ])),
            op: BinaryOperator::Eq,
            right: Box::new(Expr::Value(Value::Number("1".to_string(), false))),
        };
        assert_eq!(expected, parsed);
    }
}