// kind_openai/endpoints/chat/structured.rs
use kind_openai_schema::{GeneratedOpenAISchema, OpenAISchema};
use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize};
use crate::{endpoints::OpenAIRequestProvider, OpenAIResult, Usage};
use super::{standard::ChatCompletion, FinishReason, UnifiedChatCompletionResponseMessage};
/// A chat completion request that also carries a `response_format`; its response is parsed
/// into `S` (see [`StructuredChatCompletionResponse`]).
///
/// Serializes as the fields of the wrapped [`ChatCompletion`] request, flattened into the same
/// JSON object, followed by a `response_format` field. `S` exists only at the type level.
/// NOTE(review): the schema in `response_format` is presumably generated from `S` by the
/// constructor elsewhere in this module; that is not visible here, so confirm.
#[derive(Serialize)]
pub struct StructuredChatCompletion<'a, S> {
    // Flattened so the base request's fields sit at the top level of the request body.
    #[serde(flatten)]
    pub(super) base_request: ChatCompletion<'a>,
    pub(super) response_format: ChatCompletionRequestResponseFormat,
    // Type-level marker tying the request to its response type `S`; never serialized.
    #[serde(skip)]
    pub(super) _phantom: std::marker::PhantomData<S>,
}
/// The `response_format` value sent with a structured chat completion request.
///
/// Adjacently tagged with snake_case variant names, so `JsonSchema(schema)` serializes as
/// `{"type": "json_schema", "json_schema": <schema>}`.
#[derive(Serialize)]
#[serde(tag = "type", content = "json_schema", rename_all = "snake_case")]
pub(super) enum ChatCompletionRequestResponseFormat {
    /// Sends the wrapped generated schema as the `json_schema` payload.
    JsonSchema(GeneratedOpenAISchema),
}
impl<S> OpenAIRequestProvider for StructuredChatCompletion<'_, S>
where
S: OpenAISchema + for<'de> Deserialize<'de>,
{
type Response = StructuredChatCompletionResponse<S>;
const METHOD: reqwest::Method = reqwest::Method::POST;
fn path_with_leading_slash() -> String {
"/chat/completions".to_string()
}
}
// Marks this request type as a member of the crate-private `Sealed` trait.
// NOTE(review): presumably `Sealed` is a supertrait of `OpenAIRequestProvider`, used to keep
// downstream crates from implementing it; confirm in `private`.
impl<S> super::super::private::Sealed for StructuredChatCompletion<'_, S> {}
/// Response body of a structured chat completion.
///
/// Every choice's message `content` is decoded from its JSON-string form into `S` while this
/// struct is being deserialized.
#[derive(Deserialize)]
// Overrides serde's default `S: Deserialize<'de>` bound. The nested choice and message types
// require `S: DeserializeOwned` because their content parser does.
#[serde(bound(deserialize = "S: DeserializeOwned"))]
pub struct StructuredChatCompletionResponse<S> {
    choices: Vec<StructuredChatCompletionResponseChoice<S>>,
    usage: Usage,
}
impl<S> StructuredChatCompletionResponse<S> {
    /// Consumes the response and yields its first choice.
    ///
    /// Returns `None` when the `choices` array is empty. Any further choices are discarded.
    pub fn take_first_choice(self) -> Option<StructuredChatCompletionResponseChoice<S>> {
        let Self { choices, .. } = self;
        let mut remaining = choices.into_iter();
        remaining.next()
    }

    /// The `usage` block reported with this response, returned by value.
    pub fn usage(&self) -> Usage {
        let Self { usage, .. } = self;
        *usage
    }
}
/// One entry of a structured response's `choices` array.
#[derive(Deserialize)]
// Same bound as the message type, whose content parser needs `S: DeserializeOwned`.
#[serde(bound(deserialize = "S: DeserializeOwned"))]
pub struct StructuredChatCompletionResponseChoice<S> {
    finish_reason: FinishReason,
    index: i32,
    // Content is parsed into `S` eagerly, during deserialization of the response.
    message: StructuredChatCompletionResponseMessage<S>,
}
impl<S> StructuredChatCompletionResponseChoice<S> {
    /// Consumes the choice and turns its message into an [`OpenAIResult`].
    ///
    /// The message is first converted into a [`UnifiedChatCompletionResponseMessage`]. How
    /// `content` and `refusal` map to `Ok`/`Err` is decided by that type's conversion into
    /// `OpenAIResult<S>`.
    pub fn message(self) -> OpenAIResult<S> {
        let unified: UnifiedChatCompletionResponseMessage<S> = self.message.into();
        unified.into()
    }

    /// The `finish_reason` reported for this choice.
    pub fn finish_reason(&self) -> FinishReason {
        let Self { finish_reason, .. } = self;
        *finish_reason
    }

    /// The `index` reported for this choice.
    pub fn index(&self) -> i32 {
        let Self { index, .. } = self;
        *index
    }
}
/// The message of a structured choice, as it appears on the wire.
///
/// `content` arrives as a JSON-encoded string and is parsed a second time into `S` by
/// `de_from_str`.
///
/// NOTE(review): `content` is a required `S` parsed from a string. If the API sends
/// `content: null` or omits it (as may happen when the model fills `refusal`), deserializing
/// the whole response fails, and `refusal` is never surfaced to the caller. Confirm against
/// the API's refusal payload shape. Fixing this needs `content` to become optional together
/// with a matching change to `UnifiedChatCompletionResponseMessage`.
#[derive(Deserialize)]
#[serde(bound(deserialize = "S: DeserializeOwned"))]
struct StructuredChatCompletionResponseMessage<S> {
    // Decoded string -> JSON -> `S`; see `de_from_str`.
    #[serde(deserialize_with = "de_from_str")]
    content: S,
    refusal: Option<String>,
}
fn de_from_str<'de, D, S>(deserializer: D) -> Result<S, D::Error>
where
D: Deserializer<'de>,
S: DeserializeOwned,
{
let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(serde::de::Error::custom)
}
impl<S> From<StructuredChatCompletionResponseMessage<S>>
for UnifiedChatCompletionResponseMessage<S>
{
fn from(value: StructuredChatCompletionResponseMessage<S>) -> Self {
UnifiedChatCompletionResponseMessage {
content: value.content,
refusal: value.refusal,
}
}
}