pyo3_polars/
alloc.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use std::alloc::{GlobalAlloc, Layout, System};
use std::ffi::c_char;

use once_cell::race::OnceRef;
use pyo3::ffi::{PyCapsule_Import, PyErr_Clear, Py_IsInitialized};
use pyo3::Python;

/// [`System`]-backed stand-in for the capsule's `alloc` entry.
///
/// # Safety
/// `size` and `align` must describe a valid [`Layout`] with a non-zero size,
/// exactly as required by [`GlobalAlloc::alloc`].
unsafe extern "C" fn fallback_alloc(size: usize, align: usize) -> *mut u8 {
    // SAFETY: the caller guarantees `size`/`align` form a valid layout.
    let requested = unsafe { Layout::from_size_align_unchecked(size, align) };
    // SAFETY: the request is forwarded unchanged from `GlobalAlloc::alloc`.
    unsafe { System.alloc(requested) }
}

/// [`System`]-backed stand-in for the capsule's `dealloc` entry.
///
/// # Safety
/// `ptr` must have been allocated by [`System`] with the layout described by
/// `size` and `align`, exactly as required by [`GlobalAlloc::dealloc`].
unsafe extern "C" fn fallback_dealloc(ptr: *mut u8, size: usize, align: usize) {
    // SAFETY: the caller guarantees `size`/`align` form the original layout.
    let original = unsafe { Layout::from_size_align_unchecked(size, align) };
    // SAFETY: `ptr` was produced by `System` for `original`.
    unsafe { System.dealloc(ptr, original) }
}

/// [`System`]-backed stand-in for the capsule's `alloc_zeroed` entry.
///
/// # Safety
/// `size` and `align` must describe a valid [`Layout`] with a non-zero size,
/// exactly as required by [`GlobalAlloc::alloc_zeroed`].
unsafe extern "C" fn fallback_alloc_zeroed(size: usize, align: usize) -> *mut u8 {
    // SAFETY: the caller guarantees `size`/`align` form a valid layout.
    let requested = unsafe { Layout::from_size_align_unchecked(size, align) };
    // SAFETY: the request is forwarded unchanged from `GlobalAlloc::alloc_zeroed`.
    unsafe { System.alloc_zeroed(requested) }
}

/// [`System`]-backed stand-in for the capsule's `realloc` entry.
///
/// # Safety
/// `ptr` must have been allocated by [`System`] with the layout described by
/// `size` and `align`, and `new_size` must satisfy the requirements of
/// [`GlobalAlloc::realloc`] (non-zero, no overflow when rounded to `align`).
unsafe extern "C" fn fallback_realloc(
    ptr: *mut u8,
    size: usize,
    align: usize,
    new_size: usize,
) -> *mut u8 {
    // SAFETY: the caller guarantees `size`/`align` form the current layout.
    let current = unsafe { Layout::from_size_align_unchecked(size, align) };
    // SAFETY: `ptr` was produced by `System` for `current`; `new_size` is
    // forwarded unchanged from `GlobalAlloc::realloc`.
    unsafe { System.realloc(ptr, current, new_size) }
}

/// Table of C-ABI allocation entry points.
///
/// A pointer returned by `PyCapsule_Import` for `polars.polars._allocator` is
/// reinterpreted as `*const AllocatorCapsule`, so this struct must match the
/// layout of the table Polars exports: `#[repr(C)]`, the field order and the
/// function signatures are a binary contract and must not be changed.
///
/// Every entry takes the `size`/`align` pair of a `Layout` split into two
/// `usize` arguments, mirroring the corresponding `GlobalAlloc` method.
#[repr(C)]
struct AllocatorCapsule {
    /// `alloc(size, align)`, mirrors `GlobalAlloc::alloc`.
    alloc: unsafe extern "C" fn(usize, usize) -> *mut u8,
    /// `dealloc(ptr, size, align)`, mirrors `GlobalAlloc::dealloc`.
    dealloc: unsafe extern "C" fn(*mut u8, usize, usize),
    /// `alloc_zeroed(size, align)`, mirrors `GlobalAlloc::alloc_zeroed`.
    alloc_zeroed: unsafe extern "C" fn(usize, usize) -> *mut u8,
    /// `realloc(ptr, old_size, align, new_size)`, mirrors `GlobalAlloc::realloc`.
    realloc: unsafe extern "C" fn(*mut u8, usize, usize, usize) -> *mut u8,
}

/// Allocator table used when the Polars capsule cannot be obtained.
///
/// Every entry forwards to `std::alloc::System`. Initializers are listed in
/// the same order as the fields of `AllocatorCapsule`.
static FALLBACK_ALLOCATOR_CAPSULE: AllocatorCapsule = AllocatorCapsule {
    alloc: fallback_alloc,
    dealloc: fallback_dealloc,
    alloc_zeroed: fallback_alloc_zeroed,
    realloc: fallback_realloc,
};

/// Fully qualified name of the Polars allocator capsule, as passed to
/// `PyCapsule_Import`.
///
/// Kept as raw bytes with an explicit trailing NUL because the C API expects a
/// NUL-terminated C string.
static ALLOCATOR_CAPSULE_NAME: &[u8] = concat!("polars.polars._allocator", "\0").as_bytes();

/// A memory allocator that relays allocations to the allocator used by Polars.
///
/// You can use it as the global memory allocator:
///
/// ```rust
/// use pyo3_polars::PolarsAllocator;
///
/// #[global_allocator]
/// static ALLOC: PolarsAllocator = PolarsAllocator::new();
/// ```
///
/// If the allocator capsule (`polars.polars._allocator`) is not available,
/// this allocator falls back to [`std::alloc::System`].
///
/// The choice between the two is made on the first call to any allocation
/// method and is then cached for the lifetime of the allocator. In particular,
/// if the Python interpreter is not yet initialized at that moment, the capsule
/// is never looked up and [`std::alloc::System`] is used from then on.
// The wrapped `OnceRef` stores the resolved allocator table once it is known.
pub struct PolarsAllocator(OnceRef<'static, AllocatorCapsule>);

impl PolarsAllocator {
    fn get_allocator(&self) -> &'static AllocatorCapsule {
        // Do not allocate in this function,
        // otherwise it will cause infinite recursion.
        self.0.get_or_init(|| {
            let r = (unsafe { Py_IsInitialized() } != 0)
                .then(|| {
                    Python::with_gil(|_| unsafe {
                        (PyCapsule_Import(ALLOCATOR_CAPSULE_NAME.as_ptr() as *const c_char, 0)
                            as *const AllocatorCapsule)
                            .as_ref()
                    })
                })
                .flatten();
            #[cfg(debug_assertions)]
            if r.is_none() {
                // Do not use eprintln; it may alloc.
                let msg = b"failed to get allocator capsule\n";
                // Message length type is platform-dependent.
                let msg_len = msg.len().try_into().unwrap();
                unsafe { libc::write(2, msg.as_ptr() as *const libc::c_void, msg_len) };
            }
            r.unwrap_or(&FALLBACK_ALLOCATOR_CAPSULE)
        })
    }

    /// Create a `PolarsAllocator`.
    pub const fn new() -> Self {
        PolarsAllocator(OnceRef::new())
    }
}

impl Default for PolarsAllocator {
    fn default() -> Self {
        Self::new()
    }
}

// SAFETY: each method forwards the caller's request, unchanged, to one
// allocator table that `get_allocator` resolves once and then reuses. Memory
// is therefore always freed or resized by the same table that produced it.
// That table is either `FALLBACK_ALLOCATOR_CAPSULE` (backed by `System`) or
// the Polars capsule, which is assumed to honor the `GlobalAlloc` contract.
unsafe impl GlobalAlloc for PolarsAllocator {
    #[inline]
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        let capsule = self.get_allocator();
        (capsule.alloc)(layout.size(), layout.align())
    }

    #[inline]
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        let capsule = self.get_allocator();
        (capsule.dealloc)(ptr, layout.size(), layout.align());
    }

    #[inline]
    unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
        let capsule = self.get_allocator();
        (capsule.alloc_zeroed)(layout.size(), layout.align())
    }

    #[inline]
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        let capsule = self.get_allocator();
        (capsule.realloc)(ptr, layout.size(), layout.align(), new_size)
    }
}