datafusion_physical_expr/expressions/
not.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Not expression
19
20use std::any::Any;
21use std::fmt;
22use std::hash::Hash;
23use std::sync::Arc;
24
25use crate::PhysicalExpr;
26
27use arrow::datatypes::{DataType, Schema};
28use arrow::record_batch::RecordBatch;
29use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue};
30use datafusion_expr::interval_arithmetic::Interval;
31use datafusion_expr::statistics::Distribution::{self, Bernoulli};
32use datafusion_expr::ColumnarValue;
33
34/// Not expression
35#[derive(Debug, Eq)]
36pub struct NotExpr {
37    /// Input expression
38    arg: Arc<dyn PhysicalExpr>,
39}
40
41// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
42impl PartialEq for NotExpr {
43    fn eq(&self, other: &Self) -> bool {
44        self.arg.eq(&other.arg)
45    }
46}
47
48impl Hash for NotExpr {
49    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
50        self.arg.hash(state);
51    }
52}
53
54impl NotExpr {
55    /// Create new not expression
56    pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
57        Self { arg }
58    }
59
60    /// Get the input expression
61    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
62        &self.arg
63    }
64}
65
66impl fmt::Display for NotExpr {
67    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68        write!(f, "NOT {}", self.arg)
69    }
70}
71
72impl PhysicalExpr for NotExpr {
73    /// Return a reference to Any that can be used for downcasting
74    fn as_any(&self) -> &dyn Any {
75        self
76    }
77
78    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
79        Ok(DataType::Boolean)
80    }
81
82    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
83        self.arg.nullable(input_schema)
84    }
85
86    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
87        match self.arg.evaluate(batch)? {
88            ColumnarValue::Array(array) => {
89                let array = as_boolean_array(&array)?;
90                Ok(ColumnarValue::Array(Arc::new(
91                    arrow::compute::kernels::boolean::not(array)?,
92                )))
93            }
94            ColumnarValue::Scalar(scalar) => {
95                if scalar.is_null() {
96                    return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
97                }
98                let bool_value: bool = scalar.try_into()?;
99                Ok(ColumnarValue::Scalar(ScalarValue::from(!bool_value)))
100            }
101        }
102    }
103
104    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
105        vec![&self.arg]
106    }
107
108    fn with_new_children(
109        self: Arc<Self>,
110        children: Vec<Arc<dyn PhysicalExpr>>,
111    ) -> Result<Arc<dyn PhysicalExpr>> {
112        Ok(Arc::new(NotExpr::new(Arc::clone(&children[0]))))
113    }
114
115    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
116        children[0].not()
117    }
118
119    fn propagate_constraints(
120        &self,
121        interval: &Interval,
122        children: &[&Interval],
123    ) -> Result<Option<Vec<Interval>>> {
124        let complemented_interval = interval.not()?;
125
126        Ok(children[0]
127            .intersect(complemented_interval)?
128            .map(|result| vec![result]))
129    }
130
131    fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
132        match children[0] {
133            Bernoulli(b) => {
134                let p_value = b.p_value();
135                if p_value.is_null() {
136                    Ok(children[0].clone())
137                } else {
138                    let one = ScalarValue::new_one(&p_value.data_type())?;
139                    Distribution::new_bernoulli(one.sub_checked(p_value)?)
140                }
141            }
142            _ => internal_err!("NotExpr can only operate on Boolean datatypes"),
143        }
144    }
145
146    fn propagate_statistics(
147        &self,
148        parent: &Distribution,
149        children: &[&Distribution],
150    ) -> Result<Option<Vec<Distribution>>> {
151        match (parent, children[0]) {
152            (Bernoulli(parent), Bernoulli(child)) => {
153                let parent_range = parent.range();
154                let result = if parent_range == Interval::CERTAINLY_TRUE {
155                    if child.range() == Interval::CERTAINLY_TRUE {
156                        None
157                    } else {
158                        Some(vec![Distribution::new_bernoulli(ScalarValue::new_zero(
159                            &child.data_type(),
160                        )?)?])
161                    }
162                } else if parent_range == Interval::CERTAINLY_FALSE {
163                    if child.range() == Interval::CERTAINLY_FALSE {
164                        None
165                    } else {
166                        Some(vec![Distribution::new_bernoulli(ScalarValue::new_one(
167                            &child.data_type(),
168                        )?)?])
169                    }
170                } else {
171                    Some(vec![])
172                };
173                Ok(result)
174            }
175            _ => internal_err!("NotExpr can only operate on Boolean datatypes"),
176        }
177    }
178}
179
180/// Creates a unary expression NOT
181pub fn not(arg: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
182    Ok(Arc::new(NotExpr::new(arg)))
183}
184
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use super::*;
    use crate::expressions::{col, Column};

    use arrow::{array::BooleanArray, datatypes::*};

    /// Negating a nullable boolean column flips every value and keeps nulls
    #[test]
    fn neg_op() -> Result<()> {
        let schema = schema();

        let expr = not(col("a", &schema)?)?;
        assert_eq!(expr.data_type(&schema)?, DataType::Boolean);
        assert!(expr.nullable(&schema)?);

        let column = BooleanArray::from(vec![Some(true), None, Some(false)]);
        let want = BooleanArray::from(vec![Some(false), None, Some(true)]);

        let batch = RecordBatch::try_new(schema, vec![Arc::new(column)])?;

        let evaluated = expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())
            .expect("Failed to convert to array");
        let got =
            as_boolean_array(&evaluated).expect("failed to downcast to BooleanArray");
        assert_eq!(got, &want);

        Ok(())
    }

    /// Checks `evaluate_bounds` for every valid boolean interval
    #[test]
    fn test_evaluate_bounds() -> Result<()> {
        // `Interval::make` turns a `None` boolean bound into `Some(false)` /
        // `Some(true)`, so unbounded inputs are not listed separately.
        // (true, false) is left out as well: its lower bound exceeds its
        // upper bound, so it is not a valid interval.
        let cases = [
            // if the bounds are all booleans (false, true) so is the negation
            ((false, true), (false, true)),
            ((true, true), (false, false)),
            ((false, false), (true, true)),
        ];

        for ((lower, upper), (negated_lower, negated_upper)) in cases {
            assert_evaluate_bounds(
                Interval::make(Some(lower), Some(upper))?,
                Interval::make(Some(negated_lower), Some(negated_upper))?,
            )?;
        }
        Ok(())
    }

    /// Asserts that NOT over column `a` maps `interval` to `expected_interval`
    fn assert_evaluate_bounds(
        interval: Interval,
        expected_interval: Interval,
    ) -> Result<()> {
        let not_expr = not(col("a", &schema())?)?;
        let actual = not_expr.evaluate_bounds(&[&interval])?;
        assert_eq!(actual, expected_interval);
        Ok(())
    }

    /// Checks `evaluate_statistics` against Bernoulli and non-Bernoulli inputs
    #[test]
    fn test_evaluate_statistics() -> Result<()> {
        // Not referenced by the assertions below
        let _schema = &Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
        let a = Arc::new(Column::new("a", 0)) as _;
        let expr = not(a)?;

        // Every distribution here is something other than Bernoulli and is
        // expected to be rejected with an error
        let rejected = [
            // Uniform with non-boolean bounds
            Distribution::new_uniform(Interval::make_unbounded(&DataType::Float64)?)?,
            // Exponential
            Distribution::new_exponential(
                ScalarValue::from(1.0),
                ScalarValue::from(1.0),
                true,
            )?,
            // Gaussian
            Distribution::new_gaussian(ScalarValue::from(1.0), ScalarValue::from(1.0))?,
            // Generic distribution over an unbounded UInt8 range
            Distribution::new_generic(
                ScalarValue::Null,
                ScalarValue::Null,
                ScalarValue::Null,
                Interval::make_unbounded(&DataType::UInt8)?,
            )?,
            // Unknown with non-boolean interval as range
            Distribution::new_generic(
                ScalarValue::Null,
                ScalarValue::Null,
                ScalarValue::Null,
                Interval::make_unbounded(&DataType::Float64)?,
            )?,
        ];
        for distribution in &rejected {
            assert!(expr.evaluate_statistics(&[distribution]).is_err());
        }

        // Bernoulli: the result parameter is one minus the input parameter
        let bernoulli_cases: [(f64, f64); 3] = [(0.0, 1.), (1.0, 0.), (0.25, 0.75)];
        for (p, complement) in bernoulli_cases {
            let input = Distribution::new_bernoulli(ScalarValue::from(p))?;
            let expected = Distribution::new_bernoulli(ScalarValue::from(complement))?;
            assert_eq!(expr.evaluate_statistics(&[&input])?, expected);
        }

        Ok(())
    }

    /// Shared schema with a single nullable boolean column named `a`
    fn schema() -> SchemaRef {
        static SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
            let column_a = Field::new("a", DataType::Boolean, true);
            Arc::new(Schema::new(vec![column_a]))
        });
        Arc::clone(&SCHEMA)
    }
}