llama_cpp_2/context/session.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
//! utilities for working with session files
use crate::context::LlamaContext;
use crate::token::LlamaToken;
use std::ffi::{CString, NulError};
use std::path::{Path, PathBuf};
/// Failed to save a Session file
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSessionError {
/// llama.cpp failed to save the session file
#[error("Failed to save session file")]
FailedToSave,
/// null byte in string
#[error("null byte in string {0}")]
NullError(#[from] NulError),
/// failed to convert path to str
#[error("failed to convert path {0} to str")]
PathToStrError(PathBuf),
}
/// Failed to load a Session file
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LoadSessionError {
/// llama.cpp failed to load the session file
#[error("Failed to load session file")]
FailedToLoad,
/// null byte in string
#[error("null byte in string {0}")]
NullError(#[from] NulError),
/// failed to convert path to str
#[error("failed to convert path {0} to str")]
PathToStrError(PathBuf),
/// Insufficient max length
#[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
InsufficientMaxLength {
/// The length of the session file
n_out: usize,
/// The maximum length
max_tokens: usize,
},
}
impl LlamaContext<'_> {
    /// Save the current session to a file.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to save to.
    /// * `tokens` - The tokens to associate the session with. This should be a prefix of a sequence of tokens that the context has processed, so that the relevant KV caches are already filled.
    ///
    /// # Errors
    ///
    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save the session file.
    pub fn save_session_file(
        &self,
        path_session: impl AsRef<Path>,
        tokens: &[LlamaToken],
    ) -> Result<(), SaveSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
        let cstr = CString::new(path)?;
        // SAFETY: `cstr` is NUL-terminated and lives until the end of this function, and the
        // pointer/length pair describes the `tokens` slice exactly. The cast relies on
        // `LlamaToken` being `repr(transparent)` over `llama_token` (the same assumption
        // `load_session_file` makes).
        let saved = unsafe {
            llama_cpp_sys_2::llama_save_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                tokens.len(),
            )
        };
        if saved {
            Ok(())
        } else {
            Err(SaveSessionError::FailedToSave)
        }
    }

    /// Load a session file into the current context.
    ///
    /// You still need to pass the returned tokens to the context for inference to work. What this function buys you is that the KV caches are already filled with the relevant data.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to load from. It must be a session file from a compatible context, otherwise the function will error.
    /// * `max_tokens` - The maximum token length of the loaded session. If the session was saved with a longer length, the function will error.
    ///
    /// # Errors
    ///
    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load the session file. (e.g. the file does not exist, is not a session file, etc.)
    pub fn load_session_file(
        &mut self,
        path_session: impl AsRef<Path>,
        max_tokens: usize,
    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
        let path = path_session.as_ref();
        // `ok_or_else` so the `PathBuf` is only allocated when the conversion actually fails.
        let path = path
            .to_str()
            .ok_or_else(|| LoadSessionError::PathToStrError(path.to_path_buf()))?;
        let cstr = CString::new(path)?;
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out: usize = 0;
        // SAFETY: cast is valid as LlamaToken is repr(transparent)
        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();
        // SAFETY: `cstr` is NUL-terminated and outlives the call; `tokens_out` points to an
        // allocation with room for `max_tokens` tokens, which is exactly the capacity passed
        // to llama.cpp; `n_out` is a valid, writable `usize`.
        let load_session_success = unsafe {
            llama_cpp_sys_2::llama_load_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens_out,
                max_tokens,
                &mut n_out,
            )
        };
        if !load_session_success {
            return Err(LoadSessionError::FailedToLoad);
        }
        // Defensive check: presumably llama.cpp reports failure rather than exceeding the
        // capacity it was given, but this guarantees the `set_len` below stays in bounds.
        if n_out > max_tokens {
            return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
        }
        // SAFETY: we checked that n_out <= max_tokens (== the capacity) and llama.cpp
        // promises that n_out tokens have been written into the buffer.
        unsafe {
            tokens.set_len(n_out);
        }
        Ok(tokens)
    }

    /// Returns the maximum size in bytes of the state (rng, logits, embedding
    /// and `kv_cache`) - will often be smaller after compacting tokens
    #[must_use]
    pub fn get_state_size(&self) -> usize {
        // SAFETY: only the context pointer owned by `self` is passed.
        unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) }
    }

    /// Copies the state to the specified destination address.
    ///
    /// Returns the number of bytes copied
    ///
    /// # Safety
    ///
    /// `dest` must be valid for writes of at least [`Self::get_state_size`] bytes (the
    /// maximum state size); llama.cpp is given no length and cannot bounds-check the write.
    pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
        // SAFETY: upheld by the caller per the contract above.
        unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) }
    }

    /// Set the state reading from the specified address
    /// Returns the number of bytes read
    ///
    /// # Safety
    ///
    /// Only `src.as_ptr()` is forwarded to llama.cpp — the slice length is NOT — so
    /// llama.cpp cannot bounds-check the read. The caller must ensure `src` contains a
    /// complete state buffer, such as one produced by [`Self::copy_state_data`] on a
    /// compatible context; otherwise llama.cpp may read past the end of the slice.
    /// NOTE(review): the exact compatibility requirements are not established here — help wanted.
    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
        // SAFETY: upheld by the caller per the contract above.
        unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
    }
}