// llama_cpp_2/llama_backend.rs
use crate::LLamaCppError;
use llama_cpp_sys_2::ggml_log_level;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::SeqCst;
/// Handle representing an initialized llama backend.
///
/// Values are obtained through [`LlamaBackend::init`] or
/// [`LlamaBackend::init_numa`]; both refuse to create a second handle while
/// one is still alive. Dropping the handle calls `llama_backend_free` and
/// clears the global "initialized" flag.
#[derive(Eq, PartialEq, Debug)]
pub struct LlamaBackend {}
// Process-wide flag recording whether a `LlamaBackend` currently exists.
// Set by `LlamaBackend::mark_init`, cleared again in `Drop for LlamaBackend`.
static LLAMA_BACKEND_INITIALIZED: AtomicBool = AtomicBool::new(false);
impl LlamaBackend {
    /// Atomically claims the global "backend initialized" flag.
    ///
    /// # Errors
    ///
    /// Returns [`LLamaCppError::BackendAlreadyInitialized`] if the flag was
    /// already set, i.e. another `LlamaBackend` is still alive.
    fn mark_init() -> crate::Result<()> {
        match LLAMA_BACKEND_INITIALIZED.compare_exchange(false, true, SeqCst, SeqCst) {
            Ok(_) => Ok(()),
            Err(_) => Err(LLamaCppError::BackendAlreadyInitialized),
        }
    }

    /// Initializes the llama backend and returns the handle that owns it.
    ///
    /// # Errors
    ///
    /// Returns [`LLamaCppError::BackendAlreadyInitialized`] if a
    /// `LlamaBackend` already exists and has not been dropped yet.
    #[tracing::instrument(skip_all)]
    pub fn init() -> crate::Result<LlamaBackend> {
        Self::mark_init()?;
        // SAFETY: `mark_init` succeeded, so this is the only live handle and
        // the backend is initialized exactly once before the matching free in `Drop`.
        unsafe { llama_cpp_sys_2::llama_backend_init() }
        Ok(LlamaBackend {})
    }

    /// Initializes the llama backend and then applies the given NUMA strategy.
    ///
    /// `llama_numa_init` is an optional step on top of `llama_backend_init`,
    /// not a replacement for it, so both are called here. This keeps the
    /// handle paired with the `llama_backend_free` performed on drop.
    ///
    /// # Errors
    ///
    /// Returns [`LLamaCppError::BackendAlreadyInitialized`] if a
    /// `LlamaBackend` already exists and has not been dropped yet.
    #[tracing::instrument(skip_all)]
    pub fn init_numa(strategy: NumaStrategy) -> crate::Result<LlamaBackend> {
        Self::mark_init()?;
        // SAFETY: `mark_init` succeeded, so this is the only live handle; the
        // backend is initialized first and the NUMA strategy is a valid
        // `ggml_numa_strategy` produced by the `From<NumaStrategy>` conversion.
        unsafe {
            llama_cpp_sys_2::llama_backend_init();
            llama_cpp_sys_2::llama_numa_init(llama_cpp_sys_2::ggml_numa_strategy::from(strategy));
        }
        Ok(LlamaBackend {})
    }

    /// Installs a log callback that discards every message.
    pub fn void_logs(&mut self) {
        /// Log callback with an empty body: all arguments are ignored.
        unsafe extern "C" fn void_log(
            _level: ggml_log_level,
            _text: *const ::std::os::raw::c_char,
            _user_data: *mut ::std::os::raw::c_void,
        ) {
        }

        // SAFETY: `void_log` is a plain function that lives for the whole
        // program and never dereferences its arguments, so passing a null
        // user-data pointer is fine.
        unsafe {
            llama_cpp_sys_2::llama_log_set(Some(void_log), std::ptr::null_mut());
        }
    }
}
/// NUMA strategy that can be passed to [`LlamaBackend::init_numa`].
///
/// Each variant converts to and from the `llama_cpp_sys_2::GGML_NUMA_STRATEGY_*`
/// constant of the same name.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum NumaStrategy {
/// Maps to `GGML_NUMA_STRATEGY_DISABLED`.
DISABLED,
/// Maps to `GGML_NUMA_STRATEGY_DISTRIBUTE`.
DISTRIBUTE,
/// Maps to `GGML_NUMA_STRATEGY_ISOLATE`.
ISOLATE,
/// Maps to `GGML_NUMA_STRATEGY_NUMACTL`.
NUMACTL,
/// Maps to `GGML_NUMA_STRATEGY_MIRROR`.
MIRROR,
/// Maps to `GGML_NUMA_STRATEGY_COUNT`.
// NOTE(review): presumably a "number of strategies" sentinel rather than a
// usable strategy — confirm against the ggml headers.
COUNT,
}
/// Error returned when a raw `ggml_numa_strategy` value does not match any
/// [`NumaStrategy`] variant.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub struct InvalidNumaStrategy(
/// The raw value that could not be converted.
pub llama_cpp_sys_2::ggml_numa_strategy,
);
impl TryFrom<llama_cpp_sys_2::ggml_numa_strategy> for NumaStrategy {
type Error = InvalidNumaStrategy;
fn try_from(value: llama_cpp_sys_2::ggml_numa_strategy) -> Result<Self, Self::Error> {
match value {
llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED => Ok(Self::DISABLED),
llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE => Ok(Self::DISTRIBUTE),
llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE => Ok(Self::ISOLATE),
llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL => Ok(Self::NUMACTL),
llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR => Ok(Self::MIRROR),
llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT => Ok(Self::COUNT),
value => Err(InvalidNumaStrategy(value)),
}
}
}
impl From<NumaStrategy> for llama_cpp_sys_2::ggml_numa_strategy {
    /// Maps each [`NumaStrategy`] variant to the `GGML_NUMA_STRATEGY_*`
    /// constant of the same name.
    fn from(strategy: NumaStrategy) -> Self {
        let raw = match strategy {
            NumaStrategy::DISABLED => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED,
            NumaStrategy::DISTRIBUTE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE,
            NumaStrategy::ISOLATE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE,
            NumaStrategy::NUMACTL => llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL,
            NumaStrategy::MIRROR => llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR,
            NumaStrategy::COUNT => llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT,
        };
        raw
    }
}
impl Drop for LlamaBackend {
    /// Frees the llama backend, then clears the global "initialized" flag so
    /// that a new `LlamaBackend` may be created.
    ///
    /// # Panics
    ///
    /// Panics if the flag was not set, which would mean a `LlamaBackend`
    /// existed without the backend having been marked as initialized.
    fn drop(&mut self) {
        // Free *before* clearing the flag. While the flag is still set,
        // concurrent `init`/`init_numa` calls are rejected, so no other thread
        // can initialize a fresh backend that this call would then tear down.
        //
        // SAFETY: this handle is the only live `LlamaBackend`, so the backend
        // is freed exactly once.
        unsafe {
            llama_cpp_sys_2::llama_backend_free();
        }

        match LLAMA_BACKEND_INITIALIZED.compare_exchange(true, false, SeqCst, SeqCst) {
            Ok(_) => {}
            Err(_) => {
                unreachable!("This should not be reachable as the only ways to obtain a llama backend involve marking the backend as initialized.")
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `NumaStrategy` must survive a round trip through the raw
    /// `ggml_numa_strategy` representation.
    #[test]
    fn numa_from_and_to() {
        let numas = [
            NumaStrategy::DISABLED,
            NumaStrategy::DISTRIBUTE,
            NumaStrategy::ISOLATE,
            NumaStrategy::NUMACTL,
            NumaStrategy::MIRROR,
            NumaStrategy::COUNT,
        ];

        for numa in &numas {
            let from = llama_cpp_sys_2::ggml_numa_strategy::from(*numa);
            let to = NumaStrategy::try_from(from).expect("Failed to convert from and to");
            assert_eq!(*numa, to);
        }
    }

    /// An unknown raw value must be rejected and handed back unchanged in the
    /// error.
    #[test]
    fn check_invalid_numa() {
        let raw = 800;
        let converted = NumaStrategy::try_from(raw);
        // Compare against the input itself, not the error's own payload, so a
        // wrong payload would actually fail the test.
        assert_eq!(converted, Err(InvalidNumaStrategy(raw)));
    }
}