llm_chain/chains/map_reduce.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
//! The `map_reduce` module contains the `Chain` struct, which represents a map-reduce chain.
//!
//! A map-reduce chain is a combination of two steps - a `map` step and a `reduce` step.
//! The `map` step processes each input document and the `reduce` step combines the results
//! of the `map` step into a single output.
//!
//! The `Chain` struct is generic over the type of the `Step` and provides a convenient way
//! to execute map-reduce operations using a provided `Executor`.
use crate::traits::ExecutorError;
use crate::{
frame::Frame, output::Output, prompt::Data, serialization::StorableEntity, step::Step, tokens,
tokens::PromptTokensError, traits::Executor, Parameters,
};
use futures::future::join_all;
use futures::future::FutureExt;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
/// The `MapReduceChainError` enum represents errors that can occur when executing a map-reduce chain.
#[derive(Error, Debug)]
pub enum MapReduceChainError {
/// An error relating to the operation of the Executor.
#[error("FormatAndExecuteError: {0}")]
FormatAndExecuteError(#[from] crate::frame::FormatAndExecuteError),
/// An error relating to tokenizing the inputs.
#[error("TokenizeError: {0}")]
TokenizeError(#[from] crate::tokens::PromptTokensError),
#[error("The vector of input documents was empty")]
InputEmpty,
#[error("Error templating: {0}")]
StringTemplate(#[from] crate::prompt::StringTemplateError),
}
/// The `Chain` struct represents a map-reduce chain, consisting of a `map` step and a `reduce` step.
///
/// The struct is generic over the type of the `Step` and provides methods for constructing and
/// executing the chain using a given `Executor`.
#[derive(Serialize, Deserialize)]
pub struct Chain {
map: Step,
reduce: Step,
}
impl Chain {
/// Constructs a new `Chain` with the given `map` and `reduce` steps.
///
/// The `new` function takes two instances of `Step` and returns a new `Chain` instance.
pub fn new(map: Step, reduce: Step) -> Chain {
Chain { map, reduce }
}
/// Executes the map-reduce chain using the provided `Executor`.
///
/// The `run` function takes a vector of input documents, a base set of parameters, and a reference
/// to an `Executor`. It processes the input documents using the `map` step and the `reduce` step,
/// and returns the result as an `Option<E::Output>`.
///
/// The function is asynchronous and must be awaited.
pub async fn run<E: Executor>(
&self,
documents: Vec<Parameters>,
base_parameters: Parameters,
executor: &E,
) -> Result<Output, MapReduceChainError> {
if documents.is_empty() {
return Err(MapReduceChainError::InputEmpty);
}
let map_frame = Frame::new(executor, &self.map);
let reduce_frame = Frame::new(executor, &self.reduce);
let chunked_docs = self.chunk_documents(
documents.clone(),
base_parameters.clone(),
executor,
&self.map,
)?;
// Execute the `map` step for each document, combining the base parameters with each document's parameters.
let chunked_docs_with_base_parameters: Vec<_> = chunked_docs
.iter()
.map(|doc| base_parameters.combine(doc))
.collect();
let mapped_documents: Vec<_> = join_all(
chunked_docs_with_base_parameters
.iter()
.map(|doc| map_frame.format_and_execute(doc))
.collect::<Vec<_>>(),
)
.await;
let mapped_documents = mapped_documents
.into_iter()
.collect::<Result<Vec<Output>, _>>()?;
let mapped_documents: Vec<Result<Data<String>, ExecutorError>> = join_all(
mapped_documents
.into_iter()
.map(|x| x.to_immediate().map(|x| x.map(|y| y.as_content())))
.collect::<Vec<_>>(),
)
.await;
let mapped_documents: Vec<Data<String>> = mapped_documents
.into_iter()
.collect::<Result<Vec<Data<String>>, ExecutorError>>()
.map_err(|e| {
MapReduceChainError::FormatAndExecuteError(
crate::frame::FormatAndExecuteError::Execute(e),
)
})?;
let mut documents = self
.combine_documents_up_to(executor, mapped_documents, &base_parameters)
.await?;
if documents.is_empty() {
return Err(MapReduceChainError::InputEmpty);
}
loop {
let tasks: Vec<_> = documents
.iter()
.map(|doc| base_parameters.with_text(doc))
.collect();
let futures = tasks.iter().map(|p| reduce_frame.format_and_execute(p));
let new_docs = join_all(futures).await;
let new_docs = new_docs.into_iter().collect::<Result<Vec<_>, _>>()?;
let new_docs = join_all(
new_docs
.into_iter()
.map(|x| x.to_immediate().map(|x| x.map(|y| y.as_content()))),
)
.await;
let new_docs = new_docs
.into_iter()
.collect::<Result<Vec<Data<String>>, ExecutorError>>()
.map_err(|e| {
MapReduceChainError::FormatAndExecuteError(
crate::frame::FormatAndExecuteError::Execute(e),
)
})?;
let n_new_docs = new_docs.len();
if n_new_docs == 1 {
return Ok(Output::new_immediate(new_docs[0].clone()));
}
documents = self
.combine_documents_up_to(executor, new_docs, &base_parameters)
.await?;
}
}
async fn combine_documents_up_to<E: Executor>(
&self,
executor: &E,
mut v: Vec<Data<String>>,
parameters: &Parameters,
) -> Result<Vec<String>, MapReduceChainError> {
let mut new_outputs = Vec::new();
while let Some(current) = v.pop() {
let mut current_doc = current.extract_last_body().cloned().unwrap_or_default();
while let Some(next) = v.last() {
let Some(next_doc_content) = next.extract_last_body() else {
continue;
};
let mut new_doc = current_doc.clone();
new_doc.push('\n');
new_doc.push_str(next_doc_content);
let params = parameters.with_text(new_doc.clone());
let prompt = self.reduce.format(¶ms)?;
let count = executor.tokens_used(self.reduce.options(), &prompt)?;
if count.has_tokens_remaining() {
current_doc = new_doc;
v.pop();
} else {
break;
}
}
new_outputs.push(current_doc);
}
Ok(new_outputs)
}
fn chunk_documents<'a, E>(
&self,
v: Vec<Parameters>,
base_parameters: Parameters,
executor: &E,
step: &Step,
) -> Result<Vec<Parameters>, PromptTokensError>
where
E: Executor + 'a,
{
let data: Result<Vec<_>, _> = v
.iter()
.map(|x| {
<E as tokens::ExecutorTokenCountExt>::split_to_fit(
executor,
step,
x,
&base_parameters,
None,
)
})
.collect();
let data = data?.iter().flatten().cloned().collect();
Ok(data)
}
}
/// Implements the `StorableEntity` trait for the `Chain` struct.
///
/// This implementation provides a method for extracting metadata from a `Chain` instance, in order to identify it
impl StorableEntity for Chain {
/// Returns metadata about the Chain instance.
///
/// The metadata is returned as a vector of tuples, where each tuple contains a key-value pair
/// representing a metadata field and its value.
///
/// This method also extracts metadata from the Step instances associated with the Chain.
fn get_metadata() -> Vec<(String, String)> {
let base = vec![(
"chain-type".to_string(),
"llm-chain::chains::map_reduce::Chain".to_string(),
)];
base
}
}