lance_index/vector/
hnsw.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! HNSW graph implementation.
//!
//! Hierarchical Navigable Small World (HNSW) graphs for approximate
//! nearest-neighbor vector search.

use arrow_schema::{DataType, Field};
use deepsize::DeepSizeOf;
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use self::builder::HnswBuildParams;
use super::graph::{OrderedFloat, OrderedNode};
use super::storage::{DistCalculator, VectorStore};

pub mod builder;
pub mod index;

pub use builder::HNSW;
pub use index::HNSWIndex;

/// Index type name for HNSW.
///
/// NOTE(review): presumably recorded in index metadata to identify the index
/// type — confirm against callers.
const HNSW_TYPE: &str = "HNSW";
/// Column name of `VECTOR_ID_FIELD`.
const VECTOR_ID_COL: &str = "__vector_id";
/// Column name of `POINTER_FIELD`.
const POINTER_COL: &str = "__pointer";

lazy_static::lazy_static! {
    /// Nullable `UInt32` field named `__pointer`.
    ///
    /// NOTE(review): what the stored value points at is defined by the graph
    /// (de)serialization code, which is not visible here — confirm before
    /// relying on a specific meaning.
    pub static ref POINTER_FIELD: Field = Field::new(POINTER_COL, DataType::UInt32, true);

    /// Nullable `UInt32` field named `__vector_id`: the id of a vector in the
    /// [`VectorStore`].
    pub static ref VECTOR_ID_FIELD: Field = Field::new(VECTOR_ID_COL, DataType::UInt32, true);
}

/// Metadata describing an HNSW graph.
///
/// Derives serde `Serialize`/`Deserialize`, so renaming or reordering fields
/// changes the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)]
pub struct HnswMetadata {
    /// Id of the graph's entry node.
    /// NOTE(review): meaning inferred from the name — confirm against the builder.
    pub entry_point: u32,
    /// Build parameters of the graph.
    pub params: HnswBuildParams,
    /// Per-level offsets; `Default` initializes `params.max_level` zeros.
    /// NOTE(review): presumably the start offset of each level in the stored
    /// graph — confirm against the writer/reader.
    pub level_offsets: Vec<usize>,
}

impl Default for HnswMetadata {
    /// Builds metadata from `HnswBuildParams::default()`, with entry point `0`
    /// and one zeroed offset per level (`params.max_level` entries).
    fn default() -> Self {
        let params = HnswBuildParams::default();
        Self {
            entry_point: 0,
            // Read `max_level` before `params` is moved into the struct below.
            level_offsets: vec![0; params.max_level as usize],
            params,
        }
    }
}

/// Select up to `k` neighbors from `candidates` with the heuristic of
/// Algorithm 4 in the HNSW paper (Malkov & Yashunin).
///
/// Candidates are visited in ascending `dist` order. A candidate `u` is kept
/// only if `u.dist` is strictly smaller than its distance to every neighbor
/// already selected. The paper's optional "extend candidates" and "keep pruned
/// connections" steps are not performed, so fewer than `k` nodes may be
/// returned.
///
/// # Arguments
/// * `storage` - vector store used to compute candidate-to-candidate distances.
/// * `candidates` - candidate nodes; `dist` is expected to be the distance to
///   the query point.
/// * `k` - maximum number of neighbors to return.
///
/// # Returns
/// If `candidates.len() <= k`, a copy of `candidates` in input order, with no
/// pruning. Otherwise the selected nodes, in ascending `dist` order.
///
/// # NOTE
/// Because of the short-circuit above, the results are not ordered in general.
///
/// # Panics
/// Panics if `partial_cmp` returns `None` for two candidate distances
/// (presumably NaN — depends on `OrderedFloat`'s `PartialOrd`).
fn select_neighbors_heuristic(
    storage: &impl VectorStore,
    candidates: &[OrderedNode],
    k: usize,
) -> Vec<OrderedNode> {
    if candidates.len() <= k {
        return candidates.iter().cloned().collect_vec();
    }
    let mut candidates = candidates.to_vec();
    candidates.sort_unstable_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());

    let mut results: Vec<OrderedNode> = Vec::with_capacity(k);
    for u in candidates.iter() {
        if results.len() >= k {
            break;
        }
        // The nearest remaining candidate is always accepted while nothing has
        // been selected, so only build a distance calculator (which may be
        // costly for some stores) when there is something to compare against.
        let keep = results.is_empty() || {
            let dist_cal = storage.dist_calculator_from_id(u.id);
            results
                .iter()
                .all(|v| u.dist < OrderedFloat(dist_cal.distance(v.id)))
        };
        if keep {
            results.push(u.clone());
        }
    }
    results
}