// kind_openai/endpoints/chat/standard.rs
use bon::Builder;
use chat_completion_builder::IsComplete;
use kind_openai_schema::OpenAISchema;
use reqwest::Method;
use serde::{Deserialize, Serialize};

use crate::{endpoints::OpenAIRequestProvider, OpenAIResult, Usage};

use super::{
    structured::{ChatCompletionRequestResponseFormat, StructuredChatCompletion},
    FinishReason, Message, Model, UnifiedChatCompletionResponseMessage,
};

/// A standard chat completion request. The response will be a string in any shape and will not
/// be parsed.
#[derive(Serialize, Builder)]
#[builder(start_fn = model, finish_fn = unstructured)]
pub struct ChatCompletion<'a> {
    #[builder(start_fn)]
    model: Model,
    messages: Vec<Message<'a>>,
    temperature: Option<f32>,
    top_p: Option<f32>,
}

/// Wires `ChatCompletion` into the generic request machinery: an HTTP POST
/// against the chat completions endpoint, deserialized into
/// [`ChatCompletionResponse`].
impl OpenAIRequestProvider for ChatCompletion<'_> {
    type Response = ChatCompletionResponse;

    const METHOD: Method = Method::POST;

    fn path_with_leading_slash() -> String {
        String::from("/chat/completions")
    }
}

// Opt this type into the crate's private `Sealed` marker trait (sealed-trait
// pattern), so only types inside this crate can act as request providers.
impl super::super::private::Sealed for ChatCompletion<'_> {}

// this is a neat trick where we can take a completed builder and allow it to be "upgraded".
// because of the `finish_fn` specification, we can either resolve and build immediately with
// `.unstructured()`, or we can call `.structured()` and provide a schema. doing it this way
// enables us to nicely represent the `ChatCompletionRequest` without having to specify the
// generic type.
impl<'a, S> ChatCompletionBuilder<'a, S>
where
    // `IsComplete` is bon's type-state marker: this impl only exists for builders
    // whose required fields have all been set, so `.unstructured()` below cannot fail.
    S: IsComplete,
{
    /// Upgrades a chat completion request to a structured chat completion request.
    /// Unless the return type can be inferred, you probably want to call this like so:
    /// `.structured::<MySchemadType>();`
    pub fn structured<SS>(self) -> StructuredChatCompletion<'a, SS>
    where
        SS: OpenAISchema,
    {
        StructuredChatCompletion {
            // finish the builder first; the plain request becomes the base payload
            base_request: self.unstructured(),
            // constrain the response to the JSON schema generated for `SS`
            response_format: ChatCompletionRequestResponseFormat::JsonSchema(SS::openai_schema()),
            _phantom: std::marker::PhantomData,
        }
    }
}

/// A response from a chat completion request.
#[derive(Deserialize)]
pub struct ChatCompletionResponse {
    // candidate completions returned by the API; accessed via `take_first_choice`
    choices: Vec<ChatCompletionResponseChoice>,
    // token accounting reported by the API for this request
    usage: Usage,
}

impl ChatCompletionResponse {
    /// Consumes the response and returns its first choice, or `None` when the
    /// API returned no choices at all.
    pub fn take_first_choice(mut self) -> Option<ChatCompletionResponseChoice> {
        if self.choices.is_empty() {
            None
        } else {
            // swap_remove(0) still yields the first element; order afterwards
            // is irrelevant because `self` is consumed.
            Some(self.choices.swap_remove(0))
        }
    }

    /// Token usage reported for this completion.
    pub fn usage(&self) -> &Usage {
        &self.usage
    }
}

/// A response choice from a chat completion request.
#[derive(Deserialize)]
pub struct ChatCompletionResponseChoice {
    // why generation stopped for this choice (see `FinishReason`)
    finish_reason: FinishReason,
    // position of this choice in the response's `choices` array
    index: i32,
    // raw message payload; exposed only through the unified message type
    message: ChatCompletionResponseMessage,
}

impl ChatCompletionResponseChoice {
    /// Reason the model stopped generating this choice.
    pub fn finish_reason(&self) -> FinishReason {
        self.finish_reason
    }

    /// Index of this choice within the response's list of choices.
    pub fn index(&self) -> i32 {
        self.index
    }

    /// Consumes the choice and yields its message content; a refusal from the
    /// model surfaces as the error side of the result.
    pub fn message(self) -> OpenAIResult<String> {
        let unified: UnifiedChatCompletionResponseMessage<String> = self.message.into();
        unified.into()
    }
}

// leave private, messages should only be interacted with through the unified message type.
#[derive(Deserialize)]
struct ChatCompletionResponseMessage {
    // the model's textual reply
    content: String,
    // set when the model declined to answer — becomes the refusal on the unified type
    refusal: Option<String>,
}

impl From<ChatCompletionResponseMessage> for UnifiedChatCompletionResponseMessage<String> {
    fn from(value: ChatCompletionResponseMessage) -> Self {
        UnifiedChatCompletionResponseMessage {
            content: value.content,
            refusal: value.refusal,
        }
    }
}