// llm_chain/tokens.rs

//! # Tokens Module
//!
//! This module provides utilities for managing tokens in Large Language Models (LLMs),
//! primarily focusing on measuring the sizes of prompts. This is useful for ensuring that
//! prompts stay within the context window size supported by a given model.

use crate::step::Step;
use crate::{traits, Parameters};
use serde::{Deserialize, Serialize};
use std::cmp::max;
use thiserror::Error;

/// Custom error type for handling prompt token-related errors.
#[derive(Clone, Debug, Error)]
pub enum PromptTokensError {
    /// Indicates that prompt tokens are not accessible for the given step.
    #[error("The prompt tokens are not accessible for this type of step.")]
    NotAvailable,
    /// Indicates that the prompt tokens could not be computed.
    #[error("The prompt tokens could not be computed.")]
    UnableToCompute,
    /// Indicates that the prompt tokens could not be computed because formatting the prompt failed.
    #[error("Formatting prompt failed: {0}")]
    PromptFormatFailed(#[from] crate::prompt::StringTemplateError),
    #[error("Tokenizer error: {0}")]
    TokenizerError(#[from] crate::tokens::TokenizerError),
}

/// An extension trait for the `Executor` trait that provides additional methods for working
/// with token counts.
pub trait ExecutorTokenCountExt: traits::Executor {
    /// Splits a `Parameters` object into multiple smaller `Parameters` objects that fit within
    /// the context window size supported by the given model.
    ///
    /// # Arguments
    /// * `step` - The step that will process the parameters; it determines which tokenizer and text splitter are used.
    /// * `doc` - The `Parameters` object to split into multiple smaller ones.
    /// * `base_parameters` - Parameters included with every chunk; their token cost is measured first and reserved from the budget.
    /// * `chunk_overlap` - The number of tokens each chunk should share with its neighbors. Defaults to 0.
    ///
    /// # Errors
    ///
    /// Returns a `PromptTokensError` if there is an issue computing the tokens.
    fn split_to_fit(
        &self,
        step: &Step,
        doc: &Parameters,
        base_parameters: &Parameters,
        chunk_overlap: Option<usize>,
    ) -> Result<Vec<Parameters>, PromptTokensError> {
        let splitter = self
            .get_tokenizer(step.options())
            .map_err(|_e| PromptTokensError::UnableToCompute)?;

        let text = doc.get_text().ok_or(PromptTokensError::UnableToCompute)?;

        let prompt = step.format(&base_parameters.combine(&Parameters::new_with_text("")))?;
        let token_count = self.tokens_used(step.options(), &prompt)?;
        let chunk_overlap = chunk_overlap.unwrap_or(0);

        let split_params = splitter
            .split_text(
                &text,
                token_count.tokens_remaining() as usize,
                chunk_overlap,
            )
            .map_err(|_e| PromptTokensError::UnableToCompute)?
            .into_iter()
            .map(Parameters::new_with_text)
            .collect();
        Ok(split_params)
    }
}

/// Blanket implementation of `ExecutorTokenCountExt` for all types implementing `Executor`.
impl<E: traits::Executor> ExecutorTokenCountExt for E {}

/// Struct representing token count information, including the maximum tokens allowed and the
/// total number of tokens used.
pub struct TokenCount {
    /// The maximum number of tokens allowed.
    max_tokens: i32,
    /// The total number of tokens used.
    tokens_used: i32,
}
impl TokenCount {
    /// Creates a new `TokenCount` instance with the given maximum tokens and tokens used.
    ///
    /// # Arguments
    ///
    /// * `max_tokens` - The maximum number of tokens allowed.
    /// * `tokens_used` - The total number of tokens used.
    pub fn new(max_tokens: i32, tokens_used: i32) -> Self {
        Self {
            max_tokens,
            tokens_used,
        }
    }

    /// Returns the number of tokens that could be added to the context window.
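    ///
    /// # Examples
    ///
    /// ```
    /// use llm_chain::tokens::TokenCount;
    /// // A 100-token window with 50 tokens already used leaves room for 50 more.
    /// let token_count = TokenCount::new(100, 50);
    /// assert_eq!(token_count.tokens_remaining(), 50);
    /// ```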
    pub fn tokens_remaining(&self) -> i32 {
        self.max_tokens - self.tokens_used
    }

    /// Returns true if there is still room in the context window.
    pub fn has_tokens_remaining(&self) -> bool {
        self.has_room_for(1)
    }

    /// Returns true if there is room for the given number of tokens.
    ///
    /// # Arguments
    ///
    /// * `tokens` - The number of tokens to check if there is room for.
    ///
    /// # Examples
    ///
    /// ```
    /// use llm_chain::tokens::TokenCount;
    /// let token_count = TokenCount::new(100, 50);
    /// assert!(token_count.has_room_for(49));
    /// ```
    pub fn has_room_for(&self, tokens: i32) -> bool {
        self.tokens_remaining() >= tokens
    }
}

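/// Error type covering the ways tokenization or detokenization can fail.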
#[derive(Error, Debug, Clone)]
pub enum TokenizerError {
    #[error("Error tokenizing input text")]
    TokenizationError,
    #[error("Error stringifying tokens to text")]
    ToStringError,
    #[error("Error creating tokenizer")]
    TokenizerCreationError,
    #[error("Token Collection type mismatch")]
    TokenCollectionTypeMismatch,
}

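/// A tokenizer that can convert text into tokens and tokens back into text.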
pub trait Tokenizer {
    /// Tokenizes a string.
    ///
    /// # Parameters
    ///
    /// * `doc`: The string to tokenize.
    ///
    /// # Returns
    ///
    /// A `Result` containing a `TokenCollection` of tokens, or an error if tokenization fails.
    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError>;

    /// Converts a collection of tokens back into a string.
    ///
    /// # Parameters
    ///
    /// * `tokens`: The `TokenCollection` to convert.
    ///
    /// # Returns
    ///
    /// A `Result` containing the string, or an error if conversion fails.
    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError>;

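    /// Splits `doc` into chunks of at most `max_tokens_per_chunk` tokens, with
    /// consecutive chunks overlapping by `chunk_overlap` tokens. The default
    /// implementation tokenizes the whole document and then stringifies each
    /// sliding window of tokens.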
    fn split_text(
        &self,
        doc: &str,
        max_tokens_per_chunk: usize,
        chunk_overlap: usize,
    ) -> Result<Vec<String>, TokenizerError> {
        let tokens = self.tokenize_str(doc)?;
        let step_size = max(
            max_tokens_per_chunk.checked_sub(chunk_overlap).unwrap_or(1),
            1,
        );

        debug_assert_ne!(step_size, 0);

        (0..tokens.len())
            .step_by(step_size)
            .map(|start_idx| {
                let end_idx = usize::min(start_idx + max_tokens_per_chunk, tokens.len());
                self.to_string(tokens.slice(start_idx, end_idx))
            })
            .collect()
    }
}
/// Represents a single token.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(transparent)]
pub struct Token(TokenImpl);

#[derive(Serialize, Deserialize, Clone, Debug)]
enum TokenImpl {
    I32(i32),
    USize(usize),
}

impl From<i32> for Token {
    fn from(value: i32) -> Self {
        Token(TokenImpl::I32(value))
    }
}

impl From<usize> for Token {
    fn from(value: usize) -> Self {
        Token(TokenImpl::USize(value))
    }
}

impl Token {
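    /// Returns the underlying value as an `i32`, or `None` if the token holds a `usize`.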
    pub fn to_i32(&self) -> Option<i32> {
        match &self.0 {
            TokenImpl::I32(x) => Some(*x),
            _ => None,
        }
    }

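    /// Returns the underlying value as a `usize`, or `None` if the token holds an `i32`.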
    pub fn to_usize(&self) -> Option<usize> {
        match &self.0 {
            TokenImpl::USize(x) => Some(*x),
            _ => None,
        }
    }
}

/// A type-safe, enum-backed collection of tokens.
///
/// `TokenCollection` can hold a collection of `i32` or `usize` tokens,
/// ensuring type safety and efficient storage.
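///
/// # Examples
///
/// ```
/// use llm_chain::tokens::TokenCollection;
/// let tokens: TokenCollection = vec![1i32, 2, 3, 4].into();
/// assert_eq!(tokens.len(), 4);
/// assert_eq!(tokens.slice(1, 3).len(), 2);
/// ```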
#[derive(Debug)]
pub struct TokenCollection(TokenCollectionImpl);

/// The internal enum representation of `TokenCollection`.
///
/// This enum holds the actual data for a `TokenCollection` instance,
/// allowing us to differentiate between the two types of collections
/// (`i32` and `usize`) in a type-safe manner.
#[derive(Debug)]
enum TokenCollectionImpl {
    /// A token collection of `i32`
    I32(Vec<i32>),
    /// A token collection of `usize`. Avoid where possible, since the size of `usize` is platform-dependent, but some libraries use it.
    Usize(Vec<usize>),
}

impl TokenCollection {
    /// Converts the `TokenCollection` into a vector of `i32`, if it contains
    /// `i32` values. Returns a `TokenCollectionTypeMismatch` error otherwise.
    pub fn as_i32(self) -> Result<Vec<i32>, TokenizerError> {
        match self.0 {
            TokenCollectionImpl::I32(v) => Ok(v),
            _ => Err(TokenizerError::TokenCollectionTypeMismatch),
        }
    }

    /// Converts the `TokenCollection` into a vector of `usize`, if it contains
    /// `usize` values. Returns a `TokenCollectionTypeMismatch` error otherwise.
    pub fn as_usize(self) -> Result<Vec<usize>, TokenizerError> {
        match self.0 {
            TokenCollectionImpl::Usize(v) => Ok(v),
            _ => Err(TokenizerError::TokenCollectionTypeMismatch),
        }
    }

    /// Returns the number of tokens in the token collection
    pub fn len(&self) -> usize {
        match &self.0 {
            TokenCollectionImpl::I32(x) => x.len(),
            TokenCollectionImpl::Usize(x) => x.len(),
        }
    }

    /// Returns true if the TokenCollection is empty
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns a copy of the tokens between `start` (inclusive) and `end` (exclusive).
    /// Panics if the range is out of bounds.
    pub fn slice(&self, start: usize, end: usize) -> Self {
        match &self.0 {
            TokenCollectionImpl::I32(v) => Vec::from(&v[start..end]).into(),
            TokenCollectionImpl::Usize(v) => Vec::from(&v[start..end]).into(),
        }
    }
}

/// Enables the conversion from a vector of `i32` into a `TokenCollection`.
impl From<Vec<i32>> for TokenCollection {
    fn from(v: Vec<i32>) -> Self {
        TokenCollection(TokenCollectionImpl::I32(v))
    }
}

/// Enables the conversion from a vector of `usize` into a `TokenCollection`.
impl From<Vec<usize>> for TokenCollection {
    fn from(v: Vec<usize>) -> Self {
        TokenCollection(TokenCollectionImpl::Usize(v))
    }
}
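
// A minimal sketch of a `Tokenizer` implementation, used only to exercise the
// default `split_text` chunking logic in tests. The character-level scheme
// below is illustrative, not representative of real tokenizers, which are
// provided by the executor backends.
#[cfg(test)]
mod tests {
    use super::*;

    /// Toy tokenizer: every Unicode scalar value is one token.
    struct CharTokenizer;

    impl Tokenizer for CharTokenizer {
        fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
            Ok(doc.chars().map(|c| c as i32).collect::<Vec<i32>>().into())
        }

        fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
            tokens
                .as_i32()?
                .into_iter()
                .map(|t| char::from_u32(t as u32).ok_or(TokenizerError::ToStringError))
                .collect()
        }
    }

    #[test]
    fn split_text_produces_overlapping_chunks() {
        let tokenizer = CharTokenizer;
        // 10 tokens, chunk size 4, overlap 2 => step size 2.
        let chunks = tokenizer.split_text("abcdefghij", 4, 2).unwrap();
        assert_eq!(chunks[0], "abcd");
        assert_eq!(chunks[1], "cdef");
        // No chunk exceeds `max_tokens_per_chunk`.
        assert!(chunks.iter().all(|c| c.chars().count() <= 4));
    }
}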