lance_datafusion/
sql.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! SQL Parser utility

use datafusion::sql::sqlparser::{
    ast::{Expr, SelectItem, SetExpr, Statement},
    dialect::{Dialect, GenericDialect},
    parser::Parser,
    tokenizer::{Token, Tokenizer},
};

use lance_core::{Error, Result};
use snafu::{location, Location};
/// SQL dialect used by the parsing helpers in this module.
///
/// A newtype over [`GenericDialect`]: identifier character rules are delegated to the
/// wrapped dialect, while a backtick is the character that opens a delimited (quoted)
/// identifier.
#[derive(Debug, Default)]
struct LanceDialect(GenericDialect);

impl LanceDialect {
    fn new() -> Self {
        Self(GenericDialect {})
    }
}

impl Dialect for LanceDialect {
    /// Whether `ch` may begin an unquoted identifier; defers to the wrapped dialect.
    fn is_identifier_start(&self, ch: char) -> bool {
        let Self(generic) = self;
        generic.is_identifier_start(ch)
    }

    /// Whether `ch` may continue an unquoted identifier; defers to the wrapped dialect.
    fn is_identifier_part(&self, ch: char) -> bool {
        let Self(generic) = self;
        generic.is_identifier_part(ch)
    }

    /// Only a backtick opens a delimited identifier in this dialect.
    fn is_delimited_identifier_start(&self, ch: char) -> bool {
        matches!(ch, '`')
    }
}

/// Parse sql filter to Expression.
pub(crate) fn parse_sql_filter(filter: &str) -> Result<Expr> {
    let sql = format!("SELECT 1 FROM t WHERE {filter}");
    let statement = parse_statement(&sql)?;

    let selection = if let Statement::Query(query) = &statement {
        if let SetExpr::Select(s) = query.body.as_ref() {
            s.selection.as_ref()
        } else {
            None
        }
    } else {
        None
    };
    let expr = selection
        .ok_or_else(|| Error::io(format!("Filter is not valid: {filter}"), location!()))?;
    Ok(expr.clone())
}

/// Parse a SQL expression to Expression. This is more lenient than parse_sql_filter
/// as it can be used for projection expressions as well.
pub(crate) fn parse_sql_expr(expr: &str) -> Result<Expr> {
    let sql = format!("SELECT {expr} FROM t");
    let statement = parse_statement(&sql)?;

    let selection = if let Statement::Query(query) = &statement {
        if let SetExpr::Select(s) = query.body.as_ref() {
            if let SelectItem::UnnamedExpr(expr) = &s.projection[0] {
                Some(expr)
            } else {
                None
            }
        } else {
            None
        }
    } else {
        None
    };
    let expr = selection
        .ok_or_else(|| Error::io(format!("Expression is not valid: {expr}"), location!()))?;
    Ok(expr.clone())
}

fn parse_statement(statement: &str) -> Result<Statement> {
    let dialect = LanceDialect::new();

    // Hack to allow == as equals
    // This is used to parse PyArrow expressions from strings.
    // See: https://github.com/sqlparser-rs/sqlparser-rs/pull/815#issuecomment-1450714278
    let mut tokenizer = Tokenizer::new(&dialect, statement);
    let mut tokens = Vec::new();
    let mut token_iter = tokenizer.tokenize()?.into_iter();
    let mut prev_token = token_iter.next().unwrap();
    for next_token in token_iter {
        if let (Token::Eq, Token::Eq) = (&prev_token, &next_token) {
            continue; // skip second equals
        }
        let token = std::mem::replace(&mut prev_token, next_token);
        tokens.push(token);
    }
    tokens.push(prev_token);

    Ok(Parser::new(&dialect)
        .with_tokens(tokens)
        .parse_statement()?)
}

#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::sql::sqlparser::ast::{BinaryOperator, Ident, Value};

    /// Builds the AST for `left = right`.
    fn equals(left: Expr, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op: BinaryOperator::Eq,
            right: Box::new(right),
        }
    }

    /// Builds a backtick-quoted identifier.
    fn backticked(name: &str) -> Ident {
        Ident::with_quote('`', name)
    }

    #[test]
    fn test_double_equal() {
        let expected = equals(
            Expr::Identifier(Ident::new("a")),
            Expr::Identifier(Ident::new("b")),
        );
        let actual = parse_sql_filter("a == b").unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_like() {
        let expected = Expr::Like {
            negated: false,
            expr: Box::new(Expr::Identifier(Ident::new("a"))),
            pattern: Box::new(Expr::Value(Value::SingleQuotedString("abc%".to_string()))),
            escape_char: None,
        };
        let actual = parse_sql_filter("a LIKE 'abc%'").unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_quoted_ident() {
        // CUBE is a SQL keyword, so it must be quoted.
        let expected = equals(
            Expr::Identifier(backticked("a:Test_Something")),
            Expr::Identifier(backticked("CUBE")),
        );
        let actual = parse_sql_filter("`a:Test_Something` == `CUBE`").unwrap();
        assert_eq!(expected, actual);

        // Quoted names containing spaces, used as a nested field path.
        let expected = equals(
            Expr::CompoundIdentifier(vec![
                backticked("outer field"),
                backticked("inner field"),
            ]),
            Expr::Value(Value::Number("1".to_string(), false)),
        );
        let actual = parse_sql_filter("`outer field`.`inner field` == 1").unwrap();
        assert_eq!(expected, actual);
    }
}