llama_cpp_2/model/params/
kv_overrides.rsuse crate::model::params::LlamaModelParams;
use std::ffi::{CStr, CString};
use std::fmt::Debug;
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParamOverrideValue {
Bool(bool),
Float(f64),
Int(i64),
Str([std::os::raw::c_char; 128]),
}
impl ParamOverrideValue {
pub(crate) fn tag(&self) -> llama_cpp_sys_2::llama_model_kv_override_type {
match self {
ParamOverrideValue::Bool(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL,
ParamOverrideValue::Float(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
ParamOverrideValue::Int(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT,
ParamOverrideValue::Str(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_STR,
}
}
pub(crate) fn value(&self) -> llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
match self {
ParamOverrideValue::Bool(value) => {
llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
}
ParamOverrideValue::Float(value) => {
llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
}
ParamOverrideValue::Int(value) => {
llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
}
ParamOverrideValue::Str(c_string) => {
llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
}
}
}
}
impl From<&llama_cpp_sys_2::llama_model_kv_override> for ParamOverrideValue {
fn from(
llama_cpp_sys_2::llama_model_kv_override {
key: _,
tag,
__bindgen_anon_1,
}: &llama_cpp_sys_2::llama_model_kv_override,
) -> Self {
match *tag {
llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT => {
ParamOverrideValue::Int(unsafe { __bindgen_anon_1.val_i64 })
}
llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
ParamOverrideValue::Float(unsafe { __bindgen_anon_1.val_f64 })
}
llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
ParamOverrideValue::Bool(unsafe { __bindgen_anon_1.val_bool })
}
llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_STR => {
ParamOverrideValue::Str(unsafe { __bindgen_anon_1.val_str })
}
_ => unreachable!("Unknown tag of {tag}"),
}
}
}
#[derive(Debug)]
pub struct KvOverrides<'a> {
model_params: &'a LlamaModelParams,
}
impl KvOverrides<'_> {
pub(super) fn new(model_params: &LlamaModelParams) -> KvOverrides {
KvOverrides { model_params }
}
}
impl<'a> IntoIterator for KvOverrides<'a> {
type Item = (CString, ParamOverrideValue);
type IntoIter = KvOverrideValueIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
KvOverrideValueIterator {
model_params: self.model_params,
current: 0,
}
}
}
#[derive(Debug)]
pub struct KvOverrideValueIterator<'a> {
model_params: &'a LlamaModelParams,
current: usize,
}
impl Iterator for KvOverrideValueIterator<'_> {
type Item = (CString, ParamOverrideValue);
fn next(&mut self) -> Option<Self::Item> {
let overrides = self.model_params.params.kv_overrides;
if overrides.is_null() {
return None;
}
let current = unsafe { *overrides.add(self.current) };
if current.key[0] == 0 {
return None;
}
let value = ParamOverrideValue::from(¤t);
let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() };
self.current += 1;
Some((key, value))
}
}