lance_datafusion/
sql.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! SQL Parser utility
5
6use datafusion::sql::sqlparser::{
7    ast::{Expr, SelectItem, SetExpr, Statement},
8    dialect::{Dialect, GenericDialect},
9    parser::Parser,
10    tokenizer::{Token, Tokenizer},
11};
12
13use lance_core::{Error, Result};
14use snafu::location;
15#[derive(Debug, Default)]
16struct LanceDialect(GenericDialect);
17
18impl LanceDialect {
19    fn new() -> Self {
20        Self(GenericDialect {})
21    }
22}
23
24impl Dialect for LanceDialect {
25    fn is_identifier_start(&self, ch: char) -> bool {
26        self.0.is_identifier_start(ch)
27    }
28
29    fn is_identifier_part(&self, ch: char) -> bool {
30        self.0.is_identifier_part(ch)
31    }
32
33    fn is_delimited_identifier_start(&self, ch: char) -> bool {
34        ch == '`'
35    }
36}
37
38/// Parse sql filter to Expression.
39pub(crate) fn parse_sql_filter(filter: &str) -> Result<Expr> {
40    let sql = format!("SELECT 1 FROM t WHERE {filter}");
41    let statement = parse_statement(&sql)?;
42
43    let selection = if let Statement::Query(query) = &statement {
44        if let SetExpr::Select(s) = query.body.as_ref() {
45            s.selection.as_ref()
46        } else {
47            None
48        }
49    } else {
50        None
51    };
52    let expr = selection
53        .ok_or_else(|| Error::io(format!("Filter is not valid: {filter}"), location!()))?;
54    Ok(expr.clone())
55}
56
57/// Parse a SQL expression to Expression. This is more lenient than parse_sql_filter
58/// as it can be used for projection expressions as well.
59pub(crate) fn parse_sql_expr(expr: &str) -> Result<Expr> {
60    let sql = format!("SELECT {expr} FROM t");
61    let statement = parse_statement(&sql)?;
62
63    let selection = if let Statement::Query(query) = &statement {
64        if let SetExpr::Select(s) = query.body.as_ref() {
65            if let SelectItem::UnnamedExpr(expr) = &s.projection[0] {
66                Some(expr)
67            } else {
68                None
69            }
70        } else {
71            None
72        }
73    } else {
74        None
75    };
76    let expr = selection
77        .ok_or_else(|| Error::io(format!("Expression is not valid: {expr}"), location!()))?;
78    Ok(expr.clone())
79}
80
81fn parse_statement(statement: &str) -> Result<Statement> {
82    let dialect = LanceDialect::new();
83
84    // Hack to allow == as equals
85    // This is used to parse PyArrow expressions from strings.
86    // See: https://github.com/sqlparser-rs/sqlparser-rs/pull/815#issuecomment-1450714278
87    let mut tokenizer = Tokenizer::new(&dialect, statement);
88    let mut tokens = Vec::new();
89    let mut token_iter = tokenizer.tokenize()?.into_iter();
90    let mut prev_token = token_iter.next().unwrap();
91    for next_token in token_iter {
92        if let (Token::Eq, Token::Eq) = (&prev_token, &next_token) {
93            continue; // skip second equals
94        }
95        let token = std::mem::replace(&mut prev_token, next_token);
96        tokens.push(token);
97    }
98    tokens.push(prev_token);
99
100    Ok(Parser::new(&dialect)
101        .with_tokens(tokens)
102        .parse_statement()?)
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::sql::sqlparser::ast::{BinaryOperator, Ident, Value};

    /// Builds the `left = right` comparison node the parser is expected to produce.
    fn equals(left: Expr, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op: BinaryOperator::Eq,
            right: Box::new(right),
        }
    }

    /// Builds a backtick-quoted identifier.
    fn backticked(name: &str) -> Ident {
        Ident::with_quote('`', name)
    }

    /// `==` is accepted and parsed as a plain equality.
    #[test]
    fn test_double_equal() {
        let expected = equals(
            Expr::Identifier(Ident::new("a")),
            Expr::Identifier(Ident::new("b")),
        );
        let expr = parse_sql_filter("a == b").unwrap();
        assert_eq!(expected, expr);
    }

    /// A `LIKE` filter parses into `Expr::Like` with a single-quoted string pattern.
    #[test]
    fn test_like() {
        let expected = Expr::Like {
            negated: false,
            expr: Box::new(Expr::Identifier(Ident::new("a"))),
            pattern: Box::new(Expr::Value(Value::SingleQuotedString("abc%".to_string()))),
            escape_char: None,
            any: false,
        };
        let expr = parse_sql_filter("a LIKE 'abc%'").unwrap();
        assert_eq!(expected, expr);
    }

    /// Backtick-quoted identifiers are kept as quoted identifiers, including nested fields.
    #[test]
    fn test_quoted_ident() {
        // CUBE is a SQL keyword, so it must be quoted.
        let expected = equals(
            Expr::Identifier(backticked("a:Test_Something")),
            Expr::Identifier(backticked("CUBE")),
        );
        let expr = parse_sql_filter("`a:Test_Something` == `CUBE`").unwrap();
        assert_eq!(expected, expr);

        // Quoted path segments may contain spaces.
        let expected = equals(
            Expr::CompoundIdentifier(vec![
                backticked("outer field"),
                backticked("inner field"),
            ]),
            Expr::Value(Value::Number("1".to_string(), false)),
        );
        let expr = parse_sql_filter("`outer field`.`inner field` == 1").unwrap();
        assert_eq!(expected, expr);
    }
}