datafusion_functions/math/
log.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Math function: `log()`.
19
20use std::any::Any;
21use std::sync::Arc;
22
23use super::power::PowerFunc;
24
25use arrow::array::{ArrayRef, AsArray};
26use arrow::datatypes::{DataType, Float32Type, Float64Type};
27use datafusion_common::{
28    exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue,
29};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
33use datafusion_expr::{
34    lit, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF,
35    TypeSignature::*,
36};
37use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
38use datafusion_macros::user_doc;
39
40#[user_doc(
41    doc_section(label = "Math Functions"),
42    description = "Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number.",
43    syntax_example = r#"log(base, numeric_expression)
44log(numeric_expression)"#,
45    standard_argument(name = "base", prefix = "Base numeric"),
46    standard_argument(name = "numeric_expression", prefix = "Numeric")
47)]
48#[derive(Debug)]
49pub struct LogFunc {
50    signature: Signature,
51}
52
53impl Default for LogFunc {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl LogFunc {
60    pub fn new() -> Self {
61        use DataType::*;
62        Self {
63            signature: Signature::one_of(
64                vec![
65                    Exact(vec![Float32]),
66                    Exact(vec![Float64]),
67                    Exact(vec![Float32, Float32]),
68                    Exact(vec![Float64, Float64]),
69                ],
70                Volatility::Immutable,
71            ),
72        }
73    }
74}
75
76impl ScalarUDFImpl for LogFunc {
77    fn as_any(&self) -> &dyn Any {
78        self
79    }
80    fn name(&self) -> &str {
81        "log"
82    }
83
84    fn signature(&self) -> &Signature {
85        &self.signature
86    }
87
88    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
89        match &arg_types[0] {
90            DataType::Float32 => Ok(DataType::Float32),
91            _ => Ok(DataType::Float64),
92        }
93    }
94
95    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
96        let (base_sort_properties, num_sort_properties) = if input.len() == 1 {
97            // log(x) defaults to log(10, x)
98            (SortProperties::Singleton, input[0].sort_properties)
99        } else {
100            (input[0].sort_properties, input[1].sort_properties)
101        };
102        match (num_sort_properties, base_sort_properties) {
103            (first @ SortProperties::Ordered(num), SortProperties::Ordered(base))
104                if num.descending != base.descending
105                    && num.nulls_first == base.nulls_first =>
106            {
107                Ok(first)
108            }
109            (
110                first @ (SortProperties::Ordered(_) | SortProperties::Singleton),
111                SortProperties::Singleton,
112            ) => Ok(first),
113            (SortProperties::Singleton, second @ SortProperties::Ordered(_)) => {
114                Ok(-second)
115            }
116            _ => Ok(SortProperties::Unordered),
117        }
118    }
119
120    // Support overloaded log(base, x) and log(x) which defaults to log(10, x)
121    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
122        let args = ColumnarValue::values_to_arrays(&args.args)?;
123
124        let mut base = ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0)));
125
126        let mut x = &args[0];
127        if args.len() == 2 {
128            x = &args[1];
129            base = ColumnarValue::Array(Arc::clone(&args[0]));
130        }
131        // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base)
132        let arr: ArrayRef = match args[0].data_type() {
133            DataType::Float64 => match base {
134                ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => {
135                    Arc::new(x.as_primitive::<Float64Type>().unary::<_, Float64Type>(
136                        |value: f64| f64::log(value, base as f64),
137                    ))
138                }
139                ColumnarValue::Array(base) => {
140                    let x = x.as_primitive::<Float64Type>();
141                    let base = base.as_primitive::<Float64Type>();
142                    let result = arrow::compute::binary::<_, _, _, Float64Type>(
143                        x,
144                        base,
145                        f64::log,
146                    )?;
147                    Arc::new(result) as _
148                }
149                _ => {
150                    return exec_err!("log function requires a scalar or array for base")
151                }
152            },
153
154            DataType::Float32 => match base {
155                ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Arc::new(
156                    x.as_primitive::<Float32Type>()
157                        .unary::<_, Float32Type>(|value: f32| f32::log(value, base)),
158                ),
159                ColumnarValue::Array(base) => {
160                    let x = x.as_primitive::<Float32Type>();
161                    let base = base.as_primitive::<Float32Type>();
162                    let result = arrow::compute::binary::<_, _, _, Float32Type>(
163                        x,
164                        base,
165                        f32::log,
166                    )?;
167                    Arc::new(result) as _
168                }
169                _ => {
170                    return exec_err!("log function requires a scalar or array for base")
171                }
172            },
173            other => {
174                return exec_err!("Unsupported data type {other:?} for function log")
175            }
176        };
177
178        Ok(ColumnarValue::Array(arr))
179    }
180
181    fn documentation(&self) -> Option<&Documentation> {
182        self.doc()
183    }
184
185    /// Simplify the `log` function by the relevant rules:
186    /// 1. Log(a, 1) ===> 0
187    /// 2. Log(a, Power(a, b)) ===> b
188    /// 3. Log(a, a) ===> 1
189    fn simplify(
190        &self,
191        mut args: Vec<Expr>,
192        info: &dyn SimplifyInfo,
193    ) -> Result<ExprSimplifyResult> {
194        // Args are either
195        // log(number)
196        // log(base, number)
197        let num_args = args.len();
198        if num_args > 2 {
199            return plan_err!("Expected log to have 1 or 2 arguments, got {num_args}");
200        }
201        let number = args.pop().ok_or_else(|| {
202            plan_datafusion_err!("Expected log to have 1 or 2 arguments, got 0")
203        })?;
204        let number_datatype = info.get_data_type(&number)?;
205        // default to base 10
206        let base = if let Some(base) = args.pop() {
207            base
208        } else {
209            lit(ScalarValue::new_ten(&number_datatype)?)
210        };
211
212        match number {
213            Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => {
214                Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero(
215                    &info.get_data_type(&base)?,
216                )?)))
217            }
218            Expr::ScalarFunction(ScalarFunction { func, mut args })
219                if is_pow(&func) && args.len() == 2 && base == args[0] =>
220            {
221                let b = args.pop().unwrap(); // length checked above
222                Ok(ExprSimplifyResult::Simplified(b))
223            }
224            number => {
225                if number == base {
226                    Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
227                        &number_datatype,
228                    )?)))
229                } else {
230                    let args = match num_args {
231                        1 => vec![number],
232                        2 => vec![base, number],
233                        _ => {
234                            return internal_err!(
235                                "Unexpected number of arguments in log::simplify"
236                            )
237                        }
238                    };
239                    Ok(ExprSimplifyResult::Original(args))
240                }
241            }
242        }
243    }
244}
245
246/// Returns true if the function is `PowerFunc`
247fn is_pow(func: &ScalarUDF) -> bool {
248    func.inner().as_any().downcast_ref::<PowerFunc>().is_some()
249}
250
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    use arrow::array::{Float32Array, Float64Array, Int64Array};
    use arrow::compute::SortOptions;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::DFSchema;
    use datafusion_expr::execution_props::ExecutionProps;
    use datafusion_expr::simplify::SimplifyContext;

    /// Wraps `values` as a `Float64` array argument.
    fn f64_array(values: Vec<f64>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(Float64Array::from(values)))
    }

    /// Wraps `values` as a `Float32` array argument.
    fn f32_array(values: Vec<f32>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(Float32Array::from(values)))
    }

    /// Invokes `log` with the given arguments and returns the result array.
    ///
    /// Panics when the invocation fails or produces a scalar, since every
    /// caller expects a successful array result.
    fn invoke_log(
        args: Vec<ColumnarValue>,
        number_rows: usize,
        return_type: &DataType,
    ) -> ArrayRef {
        let args = ScalarFunctionArgs {
            args,
            number_rows,
            return_type,
        };
        let result = LogFunc::new()
            .invoke_with_args(args)
            .expect("failed to invoke function log");
        match result {
            ColumnarValue::Array(arr) => arr,
            ColumnarValue::Scalar(_) => panic!("Expected an array value"),
        }
    }

    /// Asserts that `result` is a `Float64Array` matching `expected` within `1e-10`.
    fn assert_f64_result(result: &ArrayRef, expected: &[f64]) {
        let floats = as_float64_array(result)
            .expect("failed to convert result to a Float64Array");

        assert_eq!(floats.len(), expected.len());
        for (row, &want) in expected.iter().enumerate() {
            let got = floats.value(row);
            assert!(
                (got - want).abs() < 1e-10,
                "row {row}: expected {want}, got {got}"
            );
        }
    }

    /// Asserts that `result` is a `Float32Array` matching `expected` within `tolerance`.
    fn assert_f32_result(result: &ArrayRef, expected: &[f32], tolerance: f32) {
        let floats = as_float32_array(result)
            .expect("failed to convert result to a Float32Array");

        assert_eq!(floats.len(), expected.len());
        for (row, &want) in expected.iter().enumerate() {
            let got = floats.value(row);
            assert!(
                (got - want).abs() < tolerance,
                "row {row}: expected {want}, got {got}"
            );
        }
    }

    /// Mixing a `Float64` base with an `Int64` number currently panics in the
    /// array downcast instead of returning an error.
    #[test]
    #[should_panic]
    fn test_log_invalid_base_type() {
        let args = ScalarFunctionArgs {
            args: vec![
                f64_array(vec![10.0, 100.0, 1000.0, 10000.0]), // base
                ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), // num
            ],
            number_rows: 4,
            return_type: &DataType::Float64,
        };
        let _ = LogFunc::new().invoke_with_args(args);
    }

    /// A non-float number is rejected with an error.
    #[test]
    fn test_log_invalid_value() {
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num
            ],
            number_rows: 1,
            return_type: &DataType::Float64,
        };

        let result = LogFunc::new().invoke_with_args(args);
        result.expect_err("expected error");
    }

    #[test]
    fn test_log_scalar_f32_unary() {
        let result = invoke_log(
            vec![
                ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num
            ],
            1,
            &DataType::Float32,
        );
        assert_f32_result(&result, &[1.0], 1e-10);
    }

    #[test]
    fn test_log_scalar_f64_unary() {
        let result = invoke_log(
            vec![
                ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num
            ],
            1,
            &DataType::Float64,
        );
        assert_f64_result(&result, &[1.0]);
    }

    #[test]
    fn test_log_scalar_f32() {
        let result = invoke_log(
            vec![
                ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // base
                ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num
            ],
            1,
            &DataType::Float32,
        );
        assert_f32_result(&result, &[5.0], 1e-10);
    }

    #[test]
    fn test_log_scalar_f64() {
        let result = invoke_log(
            vec![
                ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // base
                ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num
            ],
            1,
            &DataType::Float64,
        );
        assert_f64_result(&result, &[6.0]);
    }

    #[test]
    fn test_log_f64_unary() {
        let result = invoke_log(
            vec![
                f64_array(vec![10.0, 100.0, 1000.0, 10000.0]), // num
            ],
            4,
            &DataType::Float64,
        );
        assert_f64_result(&result, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_log_f32_unary() {
        let result = invoke_log(
            vec![
                f32_array(vec![10.0, 100.0, 1000.0, 10000.0]), // num
            ],
            4,
            &DataType::Float32,
        );
        assert_f32_result(&result, &[1.0, 2.0, 3.0, 4.0], 1e-10);
    }

    #[test]
    fn test_log_f64() {
        let result = invoke_log(
            vec![
                f64_array(vec![2.0, 2.0, 3.0, 5.0]),    // base
                f64_array(vec![8.0, 4.0, 81.0, 625.0]), // num
            ],
            4,
            &DataType::Float64,
        );
        assert_f64_result(&result, &[3.0, 2.0, 4.0, 4.0]);
    }

    #[test]
    fn test_log_f32() {
        let result = invoke_log(
            vec![
                f32_array(vec![2.0, 2.0, 3.0, 5.0]),    // base
                f32_array(vec![8.0, 4.0, 81.0, 625.0]), // num
            ],
            4,
            &DataType::Float32,
        );
        assert_f32_result(&result, &[3.0, 2.0, 4.0, 4.0], f32::EPSILON);
    }

    /// `simplify` rejects argument counts other than 1 or 2.
    #[test]
    fn test_log_simplify_errors() {
        let props = ExecutionProps::new();
        let schema =
            Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap());
        let context = SimplifyContext::new(&props).with_schema(schema);
        // Expect 0 args to error
        let _ = LogFunc::new().simplify(vec![], &context).unwrap_err();
        // Expect 3 args to error
        let _ = LogFunc::new()
            .simplify(vec![lit(1), lit(2), lit(3)], &context)
            .unwrap_err();
    }

    /// Non-simplifiable `log()` expressions come back unchanged from `simplify`.
    #[test]
    fn test_log_simplify_original() {
        let props = ExecutionProps::new();
        let schema =
            Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap());
        let context = SimplifyContext::new(&props).with_schema(schema);
        // One argument with no simplifications
        let result = LogFunc::new().simplify(vec![lit(2)], &context).unwrap();
        let ExprSimplifyResult::Original(args) = result else {
            panic!("Expected ExprSimplifyResult::Original")
        };
        assert_eq!(args.len(), 1);
        assert_eq!(args[0], lit(2));
        // Two arguments with no simplifications
        let result = LogFunc::new()
            .simplify(vec![lit(2), lit(3)], &context)
            .unwrap();
        let ExprSimplifyResult::Original(args) = result else {
            panic!("Expected ExprSimplifyResult::Original")
        };
        assert_eq!(args.len(), 2);
        assert_eq!(args[0], lit(2));
        assert_eq!(args[1], lit(3));
    }

    #[test]
    fn test_log_output_ordering() {
        // Shorthand for an ordered sort property.
        let ordered = |descending: bool, nulls_first: bool| {
            SortProperties::Ordered(SortOptions {
                descending,
                nulls_first,
            })
        };
        let ascending = ordered(false, true);
        let descending = ordered(true, true);

        // [Unordered, Ascending, Descending, Literal]
        let orders = vec![
            ExprProperties::new_unknown(),
            ExprProperties::new_unknown().with_order(ascending),
            ExprProperties::new_unknown().with_order(descending),
            ExprProperties::new_unknown().with_order(SortProperties::Singleton),
        ];

        let log = LogFunc::new();

        // Test log(num): the ordering of the number is passed through
        for order in orders.iter().cloned() {
            let result = log.output_ordering(&[order.clone()]).unwrap();
            assert_eq!(result, order.sort_properties);
        }

        // Test log(base, num), where `nulls_first` is the same
        let mut results = Vec::with_capacity(orders.len() * orders.len());
        for base_order in orders.iter() {
            for num_order in orders.iter().cloned() {
                let result = log
                    .output_ordering(&[base_order.clone(), num_order])
                    .unwrap();
                results.push(result);
            }
        }
        let expected = vec![
            // base: Unordered
            SortProperties::Unordered,
            SortProperties::Unordered,
            SortProperties::Unordered,
            SortProperties::Unordered,
            // base: Ascending, num: Unordered
            SortProperties::Unordered,
            // base: Ascending, num: Ascending
            SortProperties::Unordered,
            // base: Ascending, num: Descending
            descending,
            // base: Ascending, num: Literal
            descending,
            // base: Descending, num: Unordered
            SortProperties::Unordered,
            // base: Descending, num: Ascending
            ascending,
            // base: Descending, num: Descending
            SortProperties::Unordered,
            // base: Descending, num: Literal
            ascending,
            // base: Literal, num: Unordered
            SortProperties::Unordered,
            // base: Literal, num: Ascending
            ascending,
            // base: Literal, num: Descending
            descending,
            // base: Literal, num: Literal
            SortProperties::Singleton,
        ];
        assert_eq!(results, expected);

        // Test with different `nulls_first`
        let base_order = ExprProperties::new_unknown().with_order(ordered(true, true));
        let num_order = ExprProperties::new_unknown().with_order(ordered(false, false));
        assert_eq!(
            log.output_ordering(&[base_order, num_order]).unwrap(),
            SortProperties::Unordered
        );
    }
}
683        );
684    }
685}