// llm_chain/chains/sequential.rs
//! A module for implementing a sequential chain of LLM steps.
//!
//! This module provides the `Chain` struct, which represents a sequential chain of steps for Large Language Models (LLMs). Each step in the chain is executed in order, and the output of the previous step is available as input to the next step.
//!
//! The `Chain` struct allows you to:
//! - Create a new chain with a vector of `Step` instances
//! - Execute the chain with a given set of `Parameters` and an `Executor`
//!
//! The `Chain` struct is designed to work with any executor that implements the `Executor` trait, providing flexibility and extensibility.
//!
//! # Example
//!
//! ```ignore
//!
//! // Assuming an executor `executor` that implements the `Executor` trait.
//! let step1 = Step::new(prompt!("Write a summary for this text: {{text}}"));
//! let step2 = Step::new(prompt!("{{text}}\n\nWrite a tweet thread for the above summary"));
//!
//! let chain = Chain::new(vec![step1, step2]);
//!
//! let parameters = parameters!("your input text here");
//!
//! // Execute the chain with the provided parameters and executor.
//! let result = chain.run(parameters, &executor).await;
//! ```
//!
//! This module also provides serialization and deserialization support for the `Chain` struct, allowing you to store and load chains using formats like JSON, YAML, or others.
use serde::{Deserialize, Serialize};
use crate::frame::FormatAndExecuteError;
use crate::output::Output;
use crate::{
frame::Frame, serialization::StorableEntity, step::Step, traits::Executor, Parameters,
};
#[derive(thiserror::Error, Debug)]
/// The `SequentialChainError` enum represents errors that can occur when executing a sequential chain.
///
/// It is the error type returned by `Chain::run`.
pub enum SequentialChainError {
/// A step failed while being formatted or executed.
///
/// Wraps the underlying `FormatAndExecuteError`; the `#[from]` attribute lets `?` convert
/// that error into this variant automatically.
#[error("ExecutorError: {0}")]
FormatAndExecuteError(#[from] FormatAndExecuteError),
/// The chain was run without any steps, so there was nothing to execute.
#[error("The vector of steps was empty")]
NoSteps,
}
/// A sequential chain is a chain where each step is executed in order, with the output of the previous step being available to the next step.
///
/// The type derives `Serialize` and `Deserialize`, so a chain can be written to and read back
/// from any serde-supported format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Chain {
/// The steps of the chain, stored in the order in which they are executed.
steps: Vec<Step>,
}
impl Chain {
    /// Creates a new `Chain` instance with the given sequence of steps.
    ///
    /// # Arguments
    ///
    /// * `steps` - A vector of `Step` objects that define the sequence of steps for the chain.
    ///   An empty vector is accepted here, but running such a chain fails with
    ///   `SequentialChainError::NoSteps`.
    pub fn new(steps: Vec<Step>) -> Chain {
        Chain { steps }
    }

    /// Creates a new `Chain` instance with a single step.
    ///
    /// # Arguments
    ///
    /// * `step` - The `Step` object that defines the single step for the chain.
    pub fn of_one(step: Step) -> Chain {
        Chain { steps: vec![step] }
    }

    /// Executes the chain with the given parameters and executor.
    ///
    /// Every step except the last is formatted and executed in order. The last body extracted
    /// from each such step's output is handed to `Parameters::with_text` to build the parameters
    /// for the following step. If a step yields no body, the default value of the body type is
    /// used instead (presumably an empty string). The final step is then executed with the
    /// accumulated parameters and its output is returned as-is.
    ///
    /// # Arguments
    ///
    /// * `parameters` - A `Parameters` object containing the input parameters for the chain.
    /// * `executor` - A reference to an executor that implements the `Executor` trait.
    ///
    /// # Returns
    ///
    /// * `Ok(Output)` - If the chain executes successfully, the output of the last step is returned.
    ///
    /// # Errors
    ///
    /// * `SequentialChainError::NoSteps` - If the chain contains no steps.
    /// * `SequentialChainError::FormatAndExecuteError` - If formatting or executing any step
    ///   fails, or if an intermediate output cannot be converted to an immediate value.
    pub async fn run<E>(
        &self,
        parameters: Parameters,
        executor: &E,
    ) -> Result<Output, SequentialChainError>
    where
        E: Executor,
    {
        // Splitting off the final step rejects an empty chain and removes the need for
        // index arithmetic and an `unwrap()` further down.
        let (last_step, leading_steps) = self
            .steps
            .split_last()
            .ok_or(SequentialChainError::NoSteps)?;

        let mut current_params = parameters;
        for step in leading_steps {
            let body = Frame::new(executor, step)
                .format_and_execute(&current_params)
                .await?
                .to_immediate()
                .await
                // Wrap the executor error so `?` can lift it into `SequentialChainError`.
                .map_err(FormatAndExecuteError::Execute)?
                .as_content()
                .extract_last_body()
                .cloned()
                .unwrap_or_default();
            current_params = current_params.with_text(body);
        }

        // The last step's output is returned directly rather than being reduced to text.
        Ok(Frame::new(executor, last_step)
            .format_and_execute(&current_params)
            .await?)
    }
}
impl StorableEntity for Chain {
    /// Returns the metadata entries describing this entity type.
    ///
    /// The result holds a single `chain-type` entry whose value names the sequential chain type.
    fn get_metadata() -> Vec<(String, String)> {
        let chain_type = (
            String::from("chain-type"),
            String::from("llm-chain::chains::sequential::Chain"),
        );
        vec![chain_type]
    }
}