// Source path: datafusion_common/utils/memory.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! This module provides a function to estimate the memory size of a hash table prior to
//! allocation.
19
20use crate::{DataFusionError, Result};
21use std::mem::size_of;
22
23/// Estimates the memory size required for a hash table prior to allocation.
24///
25/// # Parameters
26/// - `num_elements`: The number of elements expected in the hash table.
27/// - `fixed_size`: A fixed overhead size associated with the collection
28///    (e.g., HashSet or HashTable).
29/// - `T`: The type of elements stored in the hash table.
30///
31/// # Details
32/// This function calculates the estimated memory size by considering:
33/// - An overestimation of buckets to keep approximately 1/8 of them empty.
34/// - The total memory size is computed as:
35///   - The size of each entry (`T`) multiplied by the estimated number of
36///     buckets.
37///   - One byte overhead for each bucket.
38///   - The fixed size overhead of the collection.
39/// - If the estimation overflows, we return a [`DataFusionError`]
40///
41/// # Examples
42/// ---
43///
44/// ## From within a struct
45///
46/// ```rust
47/// # use datafusion_common::utils::memory::estimate_memory_size;
48/// # use datafusion_common::Result;
49///
50/// struct MyStruct<T> {
51///     values: Vec<T>,
52///     other_data: usize,
53/// }
54///
55/// impl<T> MyStruct<T> {
56///     fn size(&self) -> Result<usize> {
57///         let num_elements = self.values.len();
58///         let fixed_size = std::mem::size_of_val(self) +
59///           std::mem::size_of_val(&self.values);
60///
61///         estimate_memory_size::<T>(num_elements, fixed_size)
62///     }
63/// }
64/// ```
65/// ---
66/// ## With a simple collection
67///
68/// ```rust
69/// # use datafusion_common::utils::memory::estimate_memory_size;
70/// # use std::collections::HashMap;
71///
72/// let num_rows = 100;
73/// let fixed_size = std::mem::size_of::<HashMap<u64, u64>>();
74/// let estimated_hashtable_size =
75///   estimate_memory_size::<(u64, u64)>(num_rows,fixed_size)
76///     .expect("Size estimation failed");
77/// ```
78pub fn estimate_memory_size<T>(num_elements: usize, fixed_size: usize) -> Result<usize> {
79    // For the majority of cases hashbrown overestimates the bucket quantity
80    // to keep ~1/8 of them empty. We take this factor into account by
81    // multiplying the number of elements with a fixed ratio of 8/7 (~1.14).
82    // This formula leads to over-allocation for small tables (< 8 elements)
83    // but should be fine overall.
84    num_elements
85        .checked_mul(8)
86        .and_then(|overestimate| {
87            let estimated_buckets = (overestimate / 7).next_power_of_two();
88            // + size of entry * number of buckets
89            // + 1 byte for each bucket
90            // + fixed size of collection (HashSet/HashTable)
91            size_of::<T>()
92                .checked_mul(estimated_buckets)?
93                .checked_add(estimated_buckets)?
94                .checked_add(fixed_size)
95        })
96        .ok_or_else(|| {
97            DataFusionError::Execution(
98                "usize overflow while estimating the number of buckets".to_string(),
99            )
100        })
101}
102
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, mem::size_of};

    use super::estimate_memory_size;

    /// Checks the estimate for small and medium element counts.
    #[test]
    fn test_estimate_memory() {
        // The collection overhead depends on the target and on the standard
        // library's `HashSet` layout (e.g. 48 bytes on common 64-bit targets),
        // so the expected totals are derived from it instead of hard-coded.
        let fixed_size = size_of::<HashSet<u32>>();

        // estimated buckets: 16 = (8 * 8 / 7).next_power_of_two()
        let num_elements = 8;
        // size (bytes): 16 * 4 (entries) + 16 (one byte per bucket) + fixed_size
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
        assert_eq!(estimated, 16 * 4 + 16 + fixed_size);

        // estimated buckets: 64 = (40 * 8 / 7).next_power_of_two()
        let num_elements = 40;
        // size (bytes): 64 * 4 (entries) + 64 (one byte per bucket) + fixed_size
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
        assert_eq!(estimated, 64 * 4 + 64 + fixed_size);
    }

    /// An empty table is still estimated with a single bucket.
    #[test]
    fn test_estimate_memory_zero_elements() {
        // estimated buckets: 1 = (0 * 8 / 7).next_power_of_two()
        let estimated = estimate_memory_size::<u64>(0, 0).unwrap();
        // size (bytes): 1 * 8 (entry) + 1 (one byte per bucket) + 0
        assert_eq!(estimated, 9);
    }

    /// Overflow at any step of the estimation is reported as an error.
    #[test]
    fn test_estimate_memory_overflow() {
        // Overflow while scaling the number of elements.
        let num_elements = usize::MAX;
        let fixed_size = size_of::<HashSet<u32>>();
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size);

        assert!(estimated.is_err());

        // Overflow while adding the fixed collection overhead.
        let estimated = estimate_memory_size::<u8>(1, usize::MAX);

        assert!(estimated.is_err());
    }
}