use crate::{
options::Options,
parameters,
prompt::{PromptTemplate, StringTemplateError},
tools::{Tool, ToolError},
traits::{Executor, ExecutorError},
Parameters,
};
use std::time::{Duration, Instant};
use thiserror::Error;
/// Few-shot prompt template for the self-ask-with-search agent.
///
/// It contains four worked examples followed by two placeholders:
/// `{{input}}` receives the user query and `{{agent_scratchpad}}` receives the
/// text produced by `build_agent_scratchpad`.
///
/// NOTE(review): the closing cue is spelled "followup" while every example
/// uses "follow up" — presumably inherited from the reference prompt; confirm
/// before normalizing, since this text is sent to the model verbatim.
const PROMPT: &str = "Question: Who lived longer, Muhammad Ali or Alan Turing?
Are follow up questions needed here: Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.
So the final answer is: Muhammad Ali
Question: When was the founder of craigslist born?
Are follow up questions needed here: Yes.
Follow up: Who was the founder of craigslist?
Intermediate answer: Craigslist was founded by Craig Newmark.
Follow up: When was Craig Newmark born?
Intermediate answer: Craig Newmark was born on December 6, 1952.
So the final answer is: December 6, 1952
Question: Who was the maternal grandfather of George Washington?
Are follow up questions needed here: Yes.
Follow up: Who was the mother of George Washington?
Intermediate answer: The mother of George Washington was Mary Ball Washington.
Follow up: Who was the father of Mary Ball Washington?
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
So the final answer is: Joseph Ball
Question: Are both the directors of Jaws and Casino Royale from the same country?
Are follow up questions needed here: Yes.
Follow up: Who is the director of Jaws?
Intermediate answer: The director of Jaws is Steven Spielberg.
Follow up: Where is Steven Spielberg from?
Intermediate answer: The United States.
Follow up: Who is the director of Casino Royale?
Intermediate answer: The director of Casino Royale is Martin Campbell.
Follow up: Where is Martin Campbell from?
Intermediate answer: New Zealand.
So the final answer is: No
Question: {{input}}
Are followup questions needed here:{{agent_scratchpad}}";
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq)]
pub struct AgentAction {
/// Name of the tool to call (the parser in this file always emits
/// `"Intermediate Answer"`).
pub tool: String,
/// Input for the tool as a YAML value; the agent requires it to be a string.
pub tool_input: serde_yaml::Value,
/// Portion of the model output that led to this action; it is replayed in
/// the scratchpad on the next iteration.
pub log: String,
}
/// Terminal result produced when the model emits a final answer.
#[derive(Debug, PartialEq)]
pub struct AgentFinish {
/// Values returned to the caller; the parser in this file stores the
/// trimmed final answer under the `"output"` key.
pub return_values: Parameters,
/// The model output the final answer was extracted from.
pub log: String,
}
/// A completed action paired with what the tool returned for it.
#[derive(Debug)]
pub struct AgentIntermediateStep {
/// The action the model asked for.
pub action: AgentAction,
/// The tool output for that action, stored as a YAML value.
pub observation: serde_yaml::Value,
}
/// Outcome of a single agent iteration: either another tool step was executed
/// or the model produced its final answer.
///
/// `Debug` is derived so the value can be logged or used with
/// `unwrap`/`expect`-style helpers; both payload types already implement it.
#[derive(Debug)]
pub enum AgentIntermediateStepOutput {
    /// A tool was invoked; the step carries the action and its observation.
    Step(AgentIntermediateStep),
    /// The model finished; no further iterations are needed.
    Finish(AgentFinish),
}
/// What the output parser decided the model wants to do next.
#[derive(Debug, PartialEq)]
pub enum AgentDecision {
/// Call a tool with the contained action.
Action(AgentAction),
/// Stop and return the contained result.
Finish(AgentFinish),
}
/// Turns raw model output into an [`AgentDecision`].
pub trait AgentOutputParser {
/// Error returned when the text cannot be interpreted.
type Error;
/// Parses `text` (the model's reply) into the next decision.
fn parse(&self, text: String) -> Result<AgentDecision, Self::Error>;
}
/// Errors that can occur while running the self-ask-with-search agent.
///
/// `T` is the error type of the search tool.
#[derive(Debug, Error)]
pub enum SelfAskWithSearchAgentError<T>
where
T: std::fmt::Debug + std::error::Error + ToolError,
{
/// The parsed tool input was a YAML value other than a string.
#[error("Search tool input yaml was not of type string: {0:?}")]
ToolInputNotString(serde_yaml::Value),
/// The search tool itself failed.
#[error(transparent)]
SearchToolError(T),
/// The executor failed while producing the model output.
#[error(transparent)]
ExecutorError(ExecutorError),
/// The model output contained neither a follow-up nor a final answer.
#[error(transparent)]
ParserError(#[from] ParserError),
/// Converting the tool output into a YAML value failed.
#[error(transparent)]
YamlError(#[from] serde_yaml::Error),
/// Formatting the prompt template failed.
#[error(transparent)]
StringTemplateError(#[from] StringTemplateError),
/// The executor output had no body to extract.
#[error("Model response was empty or contained no choices")]
NoChoicesReturned,
/// The early-stopping limits were hit before the model produced a final answer.
#[error("Max number of iterations or timeout exceeded. Elapsed: {time_elapsed_seconds}s, {iterations_elapsed} iterations")]
RuntimeExceeded {
/// Wall-clock seconds measured after the last completed step.
time_elapsed_seconds: f64,
/// Number of steps completed before giving up.
iterations_elapsed: u32,
},
}
/// Parser for model output in the self-ask-with-search format.
pub struct SelfAskWithSearchAgentOutputParser {
/// Marker that introduces a follow-up question.
followup_prefix: String,
/// Marker that, when present, ends the follow-up question.
intermediate_answer_prefix: String,
/// Markers that introduce the final answer; tried in order.
acceptable_finish_prefixes: Vec<String>,
}
impl SelfAskWithSearchAgentOutputParser {
    /// Builds a parser from the three kinds of markers it looks for.
    ///
    /// * `followup_prefix` — text that introduces a follow-up question.
    /// * `intermediate_answer_prefix` — text that terminates the follow-up
    ///   question when the model keeps writing past it.
    /// * `acceptable_finish_prefixes` — texts that introduce the final answer.
    ///
    /// All arguments are copied into owned strings.
    pub fn new(
        followup_prefix: &str,
        intermediate_answer_prefix: &str,
        acceptable_finish_prefixes: &[&str],
    ) -> Self {
        let finish_prefixes: Vec<String> = acceptable_finish_prefixes
            .iter()
            .copied()
            .map(String::from)
            .collect();
        Self {
            followup_prefix: String::from(followup_prefix),
            intermediate_answer_prefix: String::from(intermediate_answer_prefix),
            acceptable_finish_prefixes: finish_prefixes,
        }
    }
}
impl Default for SelfAskWithSearchAgentOutputParser {
    /// Returns a parser configured with the markers used by this module.
    ///
    /// NOTE(review): the intermediate-answer marker here is capitalised
    /// ("Intermediate Answer:") whereas the prompt examples spell it
    /// "Intermediate answer:" — confirm which spelling is intended.
    fn default() -> Self {
        let followup_prefix = "Follow up:";
        let intermediate_answer_prefix = "Intermediate Answer:";
        let finish_prefixes = [
            "Final answer:",
            "So the final answer is:",
            "So the final answer could be:",
        ];
        Self::new(followup_prefix, intermediate_answer_prefix, &finish_prefixes)
    }
}
/// Returned when the model output contains neither a follow-up question nor
/// any accepted final-answer prefix; carries the unparsed text.
#[derive(Debug, Error)]
#[error("No finish line or follow up question was returned by the model: {0}")]
pub struct ParserError(String);
impl AgentOutputParser for SelfAskWithSearchAgentOutputParser {
    type Error = ParserError;

    /// Interprets the model output as either a follow-up question or a final
    /// answer.
    ///
    /// * If the follow-up prefix is present, the question is the text after it,
    ///   ending at the intermediate-answer prefix when that appears later in
    ///   the text, otherwise at the end of the line. The returned action's
    ///   `log` is the text from the start up to the end of the question.
    /// * Otherwise, the first configured finish prefix found in the text marks
    ///   the final answer, which is everything after it, trimmed.
    ///
    /// # Errors
    ///
    /// Returns [`ParserError`] holding the original text when neither a
    /// follow-up nor a finish prefix is found.
    fn parse(&self, text: String) -> Result<AgentDecision, Self::Error> {
        if let Some(followup_idx) = text.find(&self.followup_prefix) {
            // `find` yields byte offsets, so all slicing below is done on
            // bytes; the offsets are char boundaries because they delimit
            // matched substrings.
            let question_start = followup_idx + self.followup_prefix.len();
            let after_prefix = &text[question_start..];

            // Only look for the intermediate-answer marker *after* the
            // follow-up so an earlier occurrence cannot produce a negative
            // span.
            let (followup_question, log) =
                match after_prefix.find(&self.intermediate_answer_prefix) {
                    Some(offset) => (
                        after_prefix[..offset].trim().to_owned(),
                        text[..question_start + offset].to_owned(),
                    ),
                    None => {
                        let question_line = after_prefix
                            .find('\n')
                            .map_or(after_prefix, |end| &after_prefix[..end]);
                        // The log stops at the first newline at or after the
                        // follow-up prefix, or at the end of the text.
                        let log_end = text[followup_idx..]
                            .find('\n')
                            .map_or(text.len(), |offset| followup_idx + offset);
                        (question_line.trim().to_owned(), text[..log_end].to_owned())
                    }
                };

            Ok(AgentDecision::Action(AgentAction {
                tool: "Intermediate Answer".into(),
                tool_input: followup_question.into(),
                log,
            }))
        } else if let Some((idx, prefix)) = self
            .acceptable_finish_prefixes
            .iter()
            .find_map(|prefix| text.find(prefix).map(|idx| (idx, prefix)))
        {
            let final_answer = text[idx + prefix.len()..].trim().to_owned();
            Ok(AgentDecision::Finish(AgentFinish {
                return_values: parameters!("output" => final_answer.as_str()),
                log: text,
            }))
        } else {
            Err(ParserError(text))
        }
    }
}
/// Limits that stop the agent loop before the model produces a final answer.
///
/// A field set to `None` disables that limit; the default disables both.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EarlyStoppingConfig {
    /// Maximum number of agent iterations, if any.
    pub max_iterations: Option<u32>,
    /// Maximum wall-clock time in seconds, if any.
    pub max_time_elapsed_seconds: Option<f64>,
}
/// Self-ask-with-search agent: repeatedly prompts the executor, runs the
/// search tool on each follow-up question, and stops on a final answer.
///
/// `E` is the executor that produces model output; `T` is the search tool,
/// whose input must be constructible from a `String` and whose output must be
/// convertible into a `String`.
pub struct Agent<E, T>
where
E: Executor,
T: Tool,
T::Input: From<String>,
T::Output: Into<String>,
{
/// Executor used to run the formatted prompt.
executor: E,
/// Tool invoked with every follow-up question.
search_tool: T,
/// Iteration/time limits for `run`.
early_stopping_config: EarlyStoppingConfig,
/// Text placed before each observation in the scratchpad.
observation_prefix: String,
/// Text appended after each observation line in the scratchpad.
llm_prefix: String,
/// Parser applied to every model reply.
output_parser: SelfAskWithSearchAgentOutputParser,
}
impl<E, T> Agent<E, T>
where
E: Executor,
T: Tool,
T::Input: From<String>,
T::Output: Into<String>,
{
pub fn new(executor: E, search_tool: T, early_stopping_config: EarlyStoppingConfig) -> Self {
Self {
executor,
search_tool,
early_stopping_config,
observation_prefix: "Intermediate answer: ".to_string(),
llm_prefix: "".to_string(),
output_parser: SelfAskWithSearchAgentOutputParser::default(),
}
}
fn should_continue(&self, iterations_elapsed: u32, time_elapsed_seconds: f64) -> bool {
match (
self.early_stopping_config.max_iterations,
self.early_stopping_config.max_time_elapsed_seconds,
) {
(None, None) => true,
(None, Some(max_time_elapsed_seconds)) => {
max_time_elapsed_seconds >= time_elapsed_seconds
}
(Some(max_iterations), None) => max_iterations >= iterations_elapsed,
(Some(max_iterations), Some(max_time_elapsed_seconds)) => {
max_iterations >= iterations_elapsed
&& max_time_elapsed_seconds >= time_elapsed_seconds
}
}
}
async fn take_next_step(
&self,
intermediate_steps: &Vec<AgentIntermediateStep>,
query: &str,
) -> Result<AgentIntermediateStepOutput, SelfAskWithSearchAgentError<<T as Tool>::Error>> {
let output = self.plan(intermediate_steps, query).await?;
let decision = self.output_parser.parse(output)?;
match decision {
AgentDecision::Action(action) => {
let observation = self
.search_tool
.invoke_typed(
&action
.tool_input
.as_str()
.ok_or(SelfAskWithSearchAgentError::ToolInputNotString(
action.tool_input.clone(),
))?
.to_string()
.into(),
)
.await
.map_err(SelfAskWithSearchAgentError::SearchToolError)?;
Ok(AgentIntermediateStepOutput::Step(AgentIntermediateStep {
action,
observation: serde_yaml::to_value(Into::<String>::into(observation))?,
}))
}
AgentDecision::Finish(finish) => Ok(AgentIntermediateStepOutput::Finish(finish)),
}
}
pub fn build_agent_scratchpad(
&self,
intermediate_steps: &Vec<AgentIntermediateStep>,
) -> String {
let mut scratchpad = "".to_string();
for intermediate_step in intermediate_steps {
scratchpad += &intermediate_step.action.log;
scratchpad += &format!(
"\n{}{}\n{}",
self.observation_prefix,
intermediate_step.observation.as_str().unwrap_or_default(),
self.llm_prefix
);
}
scratchpad
}
async fn plan(
&self,
intermediate_steps: &Vec<AgentIntermediateStep>,
query: &str,
) -> Result<String, SelfAskWithSearchAgentError<<T as Tool>::Error>> {
let scratchpad = self.build_agent_scratchpad(intermediate_steps);
let template_parameters = parameters!("input" => query, "agent_scratchpad" => scratchpad);
let prompt = PromptTemplate::Text(PROMPT.into()).format(&template_parameters)?;
let plan = self
.executor
.execute(Options::empty(), &prompt)
.await
.map_err(SelfAskWithSearchAgentError::ExecutorError)?;
plan.to_immediate()
.await
.map_err(SelfAskWithSearchAgentError::ExecutorError)?
.as_content()
.extract_last_body()
.cloned()
.ok_or(SelfAskWithSearchAgentError::NoChoicesReturned)
}
pub async fn run(
&self,
query: &str,
) -> Result<
(AgentFinish, Vec<AgentIntermediateStep>),
SelfAskWithSearchAgentError<<T as Tool>::Error>,
> {
let mut intermediate_steps = vec![];
let mut iterations = 0;
let start = Instant::now();
let mut full_duration = Duration::from_nanos(0);
while self.should_continue(iterations, full_duration.as_secs_f64()) {
let decision = self.take_next_step(&intermediate_steps, query).await?;
full_duration = start.elapsed();
iterations += 1;
match decision {
AgentIntermediateStepOutput::Step(step) => intermediate_steps.push(step),
AgentIntermediateStepOutput::Finish(finish) => {
return Ok((finish, intermediate_steps))
}
}
}
Err(SelfAskWithSearchAgentError::RuntimeExceeded {
time_elapsed_seconds: full_duration.as_secs_f64(),
iterations_elapsed: iterations,
})
}
}
// Unit tests for the output parser and for scratchpad construction.
#[cfg(test)]
mod tests {
use async_trait::async_trait;
use thiserror::Error;
use crate::{
agents::self_ask_with_search::{AgentIntermediateStep, EarlyStoppingConfig},
options::Options,
output::Output,
parameters,
prompt::Prompt,
tokens::{TokenCollection, Tokenizer},
tools::{Tool, ToolError},
traits::{Executor, ExecutorError},
};
use super::{
Agent, AgentAction, AgentDecision, AgentFinish, AgentOutputParser,
SelfAskWithSearchAgentOutputParser,
};
// A follow-up on the last line becomes an action whose log is the whole text.
#[test]
fn test_parses_followup() {
let parser = SelfAskWithSearchAgentOutputParser::default();
let text = "
Whatever
Whatever
Follow up: my follow up question abc?";
let decision = parser.parse(text.into()).unwrap();
assert_eq!(
decision,
AgentDecision::Action(AgentAction {
tool: "Intermediate Answer".into(),
tool_input: "my follow up question abc?".into(),
log: text.into()
})
);
}
// A trailing newline after the follow-up is excluded from both question and log.
#[test]
fn test_parses_follow_up_trims_trailing_whitespace() {
let parser = SelfAskWithSearchAgentOutputParser::default();
let text = "
Whatever
Whatever
Follow up: my follow up question abc?
";
let decision = parser.parse(text.into()).unwrap();
assert_eq!(
decision,
AgentDecision::Action(AgentAction {
tool: "Intermediate Answer".into(),
tool_input: "my follow up question abc?".into(),
log: text.trim_end().into()
})
);
}
// A finish prefix yields the trimmed answer under "output"; the log is the full text.
#[test]
fn test_parses_final_answer() {
let parser = SelfAskWithSearchAgentOutputParser::default();
let text = "
Whatever
Whatever
So the final answer is: yes abc!";
let decision = parser.parse(text.into()).unwrap();
assert_eq!(
decision,
AgentDecision::Finish(AgentFinish {
return_values: parameters!("output" => "yes abc!"),
log: text.into()
})
);
}
// Trailing whitespace is stripped from the answer but kept in the log.
#[test]
fn test_parses_final_answer_ignores_trailing_whitespace() {
let parser = SelfAskWithSearchAgentOutputParser::default();
let text = "
Whatever
Whatever
So the final answer is: yes abc!
";
let decision = parser.parse(text.into()).unwrap();
assert_eq!(
decision,
AgentDecision::Finish(AgentFinish {
return_values: parameters!("output" => "yes abc!"),
log: text.into()
})
);
}
// Colons inside the answer itself must not truncate it.
#[test]
fn test_parses_final_answer_with_colons() {
let parser = SelfAskWithSearchAgentOutputParser::default();
let text = "
Whatever
Whatever
So the final answer is: Mad Max: Fury road";
let decision = parser.parse(text.into()).unwrap();
assert_eq!(
decision,
AgentDecision::Finish(AgentFinish {
return_values: parameters!("output" => "Mad Max: Fury road"),
log: text.into()
})
);
}
// Two recorded steps are rendered as alternating follow-up / intermediate-answer lines.
#[test]
fn test_builds_agent_sratchpad() {
// The mocks below exist only to satisfy `Agent`'s type bounds; this test
// calls `build_agent_scratchpad` only, so their `todo!()` bodies never run.
#[derive(Clone)]
struct MockOutput;
#[derive(Debug, Error)]
#[error("Mocked executor error")]
struct MockError;
impl ToolError for MockError {}
impl From<serde_yaml::Error> for MockError {
fn from(_: serde_yaml::Error) -> Self {
Self
}
}
struct MockTokenizer;
impl Tokenizer for MockTokenizer {
fn tokenize_str(
&self,
_: &str,
) -> Result<TokenCollection, crate::tokens::TokenizerError> {
todo!()
}
fn to_string(
&self,
_: TokenCollection,
) -> Result<String, crate::tokens::TokenizerError> {
todo!()
}
}
struct MockExecutor;
#[async_trait]
impl Executor for MockExecutor {
type StepTokenizer<'a> = MockTokenizer;
fn new_with_options(_: Options) -> Result<Self, crate::traits::ExecutorCreationError> {
todo!()
}
async fn execute(
&self,
_: &Options,
_: &crate::prompt::Prompt,
) -> Result<Output, ExecutorError> {
todo!()
}
fn tokens_used(
&self,
_: &Options,
_: &crate::prompt::Prompt,
) -> Result<crate::tokens::TokenCount, crate::tokens::PromptTokensError> {
todo!()
}
fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
todo!()
}
fn max_tokens_allowed(&self, _: &Options) -> i32 {
todo!()
}
fn get_tokenizer(
&self,
_: &Options,
) -> Result<MockTokenizer, crate::tokens::TokenizerError> {
todo!()
}
}
struct MockSearch;
#[async_trait]
impl Tool for MockSearch {
type Input = String;
type Output = String;
type Error = MockError;
async fn invoke_typed(&self, _: &Self::Input) -> Result<Self::Output, Self::Error> {
todo!()
}
fn description(&self) -> crate::tools::ToolDescription {
todo!()
}
}
let mock_executor = MockExecutor;
let mock_search = MockSearch;
let agent = Agent::new(
mock_executor,
mock_search,
EarlyStoppingConfig {
max_iterations: None,
max_time_elapsed_seconds: None,
},
);
// The first log includes the leading "Yes." that answers the prompt's closing cue.
let intermediate_steps = vec![
AgentIntermediateStep {
action: AgentAction {
tool: "Intermediate Answer".into(),
tool_input: "How old was Muhammad Ali when he died?".into(),
log: "Yes.
Follow up: How old was Muhammad Ali when he died?"
.into(),
},
observation: "Muhammad Ali was 74 years old when he died.".into(),
},
AgentIntermediateStep {
action: AgentAction {
tool: "Intermediate Answer".into(),
tool_input: "How old was Alan Turing when he died?".into(),
log: "Follow up: How old was Alan Turing when he died?".into(),
},
observation: "Alan Turing was 41 years old when he died.".into(),
},
];
// Each step contributes its log, then the observation line, then a newline.
let expected_scratchpad = "Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.\n";
let scratchpad = agent.build_agent_scratchpad(&intermediate_steps);
assert_eq!(scratchpad, expected_scratchpad);
}
}