// llm_chain_openai/chatgpt/model.rs
use llm_chain::options::{ModelRef, Opt};
use serde::{Deserialize, Serialize};
use strum_macros::EnumString;
/// The `Model` enum represents the available ChatGPT models that you can use through the OpenAI
/// API.
///
/// These models have different capabilities and performance characteristics, allowing you to choose
/// the one that best suits your needs. See <https://platform.openai.com/docs/models> for more
/// information.
///
/// A `Model` can be parsed from a string (via the derived `FromStr` implementation): every name
/// listed in a variant's `strum(serialize = ...)` attribute is accepted, and any string that
/// matches none of them is captured as [`Model::Other`]. [`Model::default`] yields
/// [`Model::Gpt35Turbo`].
///
/// The enum is `#[non_exhaustive]`, so `match` expressions written outside this crate need a
/// wildcard arm.
///
/// # Example
///
/// ```
/// use llm_chain_openai::chatgpt::Model;
///
/// let turbo_model = Model::Gpt35Turbo;
/// let custom_model = Model::Other("your_custom_model_name".to_string());
///
/// // Aliases such as "gpt4" parse to the same variant as the canonical name.
/// let parsed: Model = "gpt4".parse().unwrap();
/// assert_eq!(parsed, Model::Gpt4);
/// ```
#[derive(Debug, Default, Clone, Serialize, Deserialize, EnumString, PartialEq, Eq)]
#[non_exhaustive]
pub enum Model {
/// A high-performance and versatile model that offers a great balance of speed, quality, and
/// affordability.
///
/// This is the default variant. It is parsed from `gpt-3.5-turbo` as well as the aliases
/// `gpt-35-turbo`, `gpt3.5` and `gpt35`.
#[default]
#[strum(
serialize = "gpt-3.5-turbo",
serialize = "gpt-35-turbo",
serialize = "gpt3.5",
serialize = "gpt35"
)]
Gpt35Turbo,
/// Snapshot of gpt-3.5-turbo from March 1st 2023. Unlike gpt-3.5-turbo, this model will not
/// receive updates, and will be deprecated 3 months after a new version is released.
#[strum(serialize = "gpt-3.5-turbo-0301")]
Gpt35Turbo0301,
/// A high-performance model that offers the best quality, but is slower and more expensive than
/// the [`Model::Gpt35Turbo`] model.
///
/// Parsed from `gpt-4` as well as the alias `gpt4`.
#[strum(serialize = "gpt-4", serialize = "gpt4")]
Gpt4,
/// Snapshot of gpt-4 from March 14th 2023. Unlike gpt-4, this model will not receive updates,
/// and will be deprecated 3 months after a new version is released.
#[strum(serialize = "gpt-4-0314")]
Gpt4_0314,
/// Same capabilities as the base gpt-4 model but with 4x the context length. Will be updated
/// with OpenAI's latest model iteration.
#[strum(serialize = "gpt-4-32k")]
Gpt4_32k,
/// Snapshot of gpt-4-32k from March 14th 2023. Unlike gpt-4-32k, this model will not receive
/// updates, and will be deprecated 3 months after a new version is released.
#[strum(serialize = "gpt-4-32k-0314")]
Gpt4_32k0314,
/// A variant that allows you to specify a custom model name as a string, in case new models
/// are introduced or you have access to specialized models.
///
/// This is also the parsing fallback: a string that matches no known model name is stored
/// here unchanged.
#[strum(default)]
Other(String),
}
impl Model {
/// included for backwards compatibility
#[deprecated(note = "Use `Model::Gpt35Turbo` instead")]
#[allow(non_upper_case_globals)]
pub const ChatGPT3_5Turbo: Model = Model::Gpt35Turbo;
/// included for backwards compatibility
#[deprecated(note = "Use `Model::Gpt4` instead")]
pub const GPT4: Model = Model::Gpt4;
}
/// Formats a [`Model`] as the model name string used for that variant.
///
/// `Display` is implemented instead of `ToString` so that a `Model` can be used directly with
/// `format!`, `write!` and similar macros. `to_string()` remains available through the standard
/// library's blanket `ToString` implementation for `Display` types and returns the same strings
/// as before.
///
/// [`Model::Other`] is rendered as the wrapped custom name, unchanged.
impl std::fmt::Display for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Model::Gpt35Turbo => "gpt-3.5-turbo",
            Model::Gpt35Turbo0301 => "gpt-3.5-turbo-0301",
            Model::Gpt4 => "gpt-4",
            Model::Gpt4_0314 => "gpt-4-0314",
            Model::Gpt4_32k => "gpt-4-32k",
            Model::Gpt4_32k0314 => "gpt-4-32k-0314",
            Model::Other(model) => model.as_str(),
        };
        // `pad` honours width/alignment/precision flags, matching how `str` itself is displayed.
        f.pad(name)
    }
}
/// Conversion from Model to ModelRef
impl From<Model> for ModelRef {
fn from(value: Model) -> Self {
ModelRef::from_model_name(value.to_string())
}
}
/// Conversion from Model to Option
impl From<Model> for Opt {
fn from(value: Model) -> Self {
Opt::Model(value.into())
}
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    /// Canonical model names parse to their variants; unknown names fall back to `Other`.
    #[test]
    fn test_from_str() -> Result<(), Box<dyn std::error::Error>> {
        assert_eq!(Model::from_str("gpt-3.5-turbo")?, Model::Gpt35Turbo);
        assert_eq!(
            Model::from_str("gpt-3.5-turbo-0301")?,
            Model::Gpt35Turbo0301
        );
        assert_eq!(Model::from_str("gpt-4")?, Model::Gpt4);
        assert_eq!(Model::from_str("gpt-4-0314")?, Model::Gpt4_0314);
        assert_eq!(Model::from_str("gpt-4-32k")?, Model::Gpt4_32k);
        assert_eq!(Model::from_str("gpt-4-32k-0314")?, Model::Gpt4_32k0314);
        assert_eq!(
            Model::from_str("custom_model")?,
            Model::Other("custom_model".to_string())
        );
        Ok(())
    }

    /// Every alias declared through `strum(serialize = ...)` parses to the same variant as the
    /// canonical name.
    #[test]
    fn test_from_str_aliases() -> Result<(), Box<dyn std::error::Error>> {
        for alias in ["gpt-35-turbo", "gpt3.5", "gpt35"].iter() {
            assert_eq!(Model::from_str(alias)?, Model::Gpt35Turbo, "alias {:?}", alias);
        }
        assert_eq!(Model::from_str("gpt4")?, Model::Gpt4);
        Ok(())
    }

    /// The default model is `Gpt35Turbo`.
    #[test]
    fn test_default() {
        assert_eq!(Model::default(), Model::Gpt35Turbo);
    }

    /// Each variant renders to its model name string.
    #[test]
    fn test_to_string() {
        assert_eq!(Model::Gpt35Turbo.to_string(), "gpt-3.5-turbo");
        assert_eq!(Model::Gpt4.to_string(), "gpt-4");
        assert_eq!(Model::Gpt35Turbo0301.to_string(), "gpt-3.5-turbo-0301");
        assert_eq!(Model::Gpt4_0314.to_string(), "gpt-4-0314");
        assert_eq!(Model::Gpt4_32k.to_string(), "gpt-4-32k");
        assert_eq!(Model::Gpt4_32k0314.to_string(), "gpt-4-32k-0314");
        assert_eq!(
            Model::Other("custom_model".to_string()).to_string(),
            "custom_model"
        );
    }

    /// Parsing the string produced by `to_string` yields the original model again.
    #[test]
    fn test_round_trip() -> Result<(), Box<dyn std::error::Error>> {
        let models = [
            Model::Gpt35Turbo,
            Model::Gpt35Turbo0301,
            Model::Gpt4,
            Model::Gpt4_0314,
            Model::Gpt4_32k,
            Model::Gpt4_32k0314,
            Model::Other("custom_model".to_string()),
        ];
        for model in models.iter() {
            assert_eq!(&Model::from_str(&model.to_string())?, model);
        }
        Ok(())
    }

    /// The deprecated aliases render exactly like the variants they stand for.
    #[test]
    #[allow(deprecated)]
    fn test_to_string_deprecated() {
        assert_eq!(Model::ChatGPT3_5Turbo.to_string(), "gpt-3.5-turbo");
        assert_eq!(Model::GPT4.to_string(), "gpt-4");
    }
}