datafusion_common/
pyarrow.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Conversions between PyArrow and DataFusion types

20use arrow::array::{Array, ArrayData};
21use arrow::pyarrow::{FromPyArrow, ToPyArrow};
22use pyo3::exceptions::PyException;
23use pyo3::prelude::PyErr;
24use pyo3::types::{PyAnyMethods, PyList};
25use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyObject, PyResult, Python};
26
27use crate::{DataFusionError, ScalarValue};
28
29impl From<DataFusionError> for PyErr {
30    fn from(err: DataFusionError) -> PyErr {
31        PyException::new_err(err.to_string())
32    }
33}
34
35impl FromPyArrow for ScalarValue {
36    fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
37        let py = value.py();
38        let typ = value.getattr("type")?;
39        let val = value.call_method0("as_py")?;
40
41        // construct pyarrow array from the python value and pyarrow type
42        let factory = py.import("pyarrow")?.getattr("array")?;
43        let args = PyList::new(py, [val])?;
44        let array = factory.call1((args, typ))?;
45
46        // convert the pyarrow array to rust array using C data interface
47        let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
48        let scalar = ScalarValue::try_from_array(&array, 0)?;
49
50        Ok(scalar)
51    }
52}
53
54impl ToPyArrow for ScalarValue {
55    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
56        let array = self.to_array()?;
57        // convert to pyarrow array using C data interface
58        let pyarray = array.to_data().to_pyarrow(py)?;
59        let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?;
60
61        Ok(pyscalar)
62    }
63}
64
65impl<'source> FromPyObject<'source> for ScalarValue {
66    fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult<Self> {
67        Self::from_pyarrow_bound(value)
68    }
69}
70
71impl<'source> IntoPyObject<'source> for ScalarValue {
72    type Target = PyAny;
73
74    type Output = Bound<'source, Self::Target>;
75
76    type Error = PyErr;
77
78    fn into_pyobject(self, py: Python<'source>) -> Result<Self::Output, Self::Error> {
79        let array = self.to_array()?;
80        // convert to pyarrow array using C data interface
81        let pyarray = array.to_data().to_pyarrow(py)?;
82        let pyarray_bound = pyarray.bind(py);
83        pyarray_bound.call_method1("__getitem__", (0,))
84    }
85}
86
#[cfg(test)]
mod tests {
    use pyo3::ffi::c_str;
    use pyo3::prepare_freethreaded_python;
    use pyo3::py_run;
    use pyo3::types::PyDict;

    use super::*;

    /// Prepares the embedded Python interpreter and checks that `pyarrow`
    /// can be imported.
    ///
    /// # Panics
    ///
    /// Panics if `import pyarrow` fails. The message reports
    /// `sys.executable`, `sys.path` and platform-specific hints so the
    /// environment can be fixed.
    fn init_python() {
        prepare_freethreaded_python();
        Python::with_gil(|py| {
            if py.run(c_str!("import pyarrow"), None, None).is_err() {
                // Gather interpreter details so the panic explains *which*
                // Python installation PyO3 picked up.
                let locals = PyDict::new(py);
                py.run(
                    c_str!(
                        "import sys; executable = sys.executable; python_path = sys.path"
                    ),
                    None,
                    Some(&locals),
                )
                .expect("Couldn't get python info");
                let executable = locals.get_item("executable").unwrap();
                let executable: String = executable.extract().unwrap();

                let python_path = locals.get_item("python_path").unwrap();
                let python_path: Vec<String> = python_path.extract().unwrap();

                panic!("pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\
                         HINT: try `pip install pyarrow`\n\
                         NOTE: On Mac OS, you must compile against a Framework Python \
                         (default in python.org installers and brew, but not pyenv)\n\
                         NOTE: On Mac OS, PYO3 might point to incorrect Python library \
                         path when using virtual environments. Try \
                         `export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n")
            }
        })
    }

    /// A scalar converted Rust -> pyarrow -> Rust must compare equal to the
    /// original. This covers both valid values and typed nulls.
    #[test]
    fn test_roundtrip() {
        init_python();

        let example_scalars = [
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Int32(Some(23)),
            ScalarValue::Float64(Some(12.34)),
            ScalarValue::from("Hello!"),
            ScalarValue::Date32(Some(1234)),
            // Typed nulls: `as_py()` yields `None`, so only the pyarrow
            // `type` can restore the original variant on the way back.
            ScalarValue::Int32(None),
            ScalarValue::Utf8(None),
        ];

        Python::with_gil(|py| {
            for scalar in &example_scalars {
                let py_scalar = scalar.to_pyarrow(py).unwrap();
                let result = ScalarValue::from_pyarrow_bound(py_scalar.bind(py)).unwrap();
                assert_eq!(scalar, &result, "roundtrip mismatch for {scalar:?}");
            }
        });
    }

    /// `IntoPyObject` must yield a pyarrow scalar whose `as_py()` equals the
    /// original Rust value.
    #[test]
    fn test_py_scalar() -> PyResult<()> {
        init_python();

        Python::with_gil(|py| -> PyResult<()> {
            let scalar_float = ScalarValue::Float64(Some(12.34));
            let py_float = scalar_float.into_pyobject(py)?.call_method0("as_py")?;
            py_run!(py, py_float, "assert py_float == 12.34");

            let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string()));
            let py_string = scalar_string.into_pyobject(py)?.call_method0("as_py")?;
            py_run!(py, py_string, "assert py_string == 'Hello!'");

            Ok(())
        })
    }
}