kind_openai/endpoints/chat/
standard.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
use std::collections::HashMap;

use bon::Builder;
use chat_completion_builder::IsComplete;
use kind_openai_schema::OpenAISchema;
use reqwest::Method;
use serde::{Deserialize, Serialize};

use crate::{endpoints::OpenAIRequestProvider, OpenAIResult, Usage};

use super::{
    structured::{ChatCompletionRequestResponseFormat, StructuredChatCompletion},
    FinishReason, Message, Model, UnifiedChatCompletionResponseMessage,
};

/// A standard chat completion request. The response will be a string in any shape and will not
/// be parsed.
///
/// Values are created through the derived builder: the start function is named `model` and
/// takes the model, and the finishing function is named `unstructured`.
#[derive(Serialize, Builder)]
// `state_mod(vis = "pub")` exposes the builder's state module, whose `IsComplete` trait is
// used as a bound elsewhere in this file.
#[builder(start_fn = model, finish_fn = unstructured, state_mod(vis = "pub"))]
pub struct ChatCompletion<'a> {
    // Supplied as the argument of the builder's start function instead of through a setter.
    #[builder(start_fn)]
    model: Model,
    // The messages sent with the request.
    messages: Vec<Message<'a>>,
    // The remaining fields are optional request parameters. No `skip_serializing_if` is
    // configured, so an unset option is serialized as `null` rather than omitted.
    temperature: Option<f32>,
    top_p: Option<f32>,
    store: Option<bool>,
    metadata: Option<HashMap<String, String>>,
    // NOTE(review): presumably token id -> bias, matching what `logit_bias!` builds — confirm
    // against the OpenAI API reference.
    logit_bias: Option<HashMap<i32, i32>>,
}

/// Describes how a [`ChatCompletion`] is sent: a `POST` to `/chat/completions` whose body
/// deserializes into a [`ChatCompletionResponse`].
impl OpenAIRequestProvider for ChatCompletion<'_> {
    const METHOD: Method = Method::POST;

    type Response = ChatCompletionResponse;

    /// Returns the endpoint path, including its leading slash.
    fn path_with_leading_slash() -> String {
        String::from("/chat/completions")
    }
}

// Marks `ChatCompletion` with the `Sealed` trait from the `private` module two levels up.
// NOTE(review): presumably this is what allows it to implement `OpenAIRequestProvider`
// while keeping that trait closed to outside implementors — confirm against the trait.
impl super::super::private::Sealed for ChatCompletion<'_> {}

// This is a neat trick that lets a completed builder be "upgraded". Because of the
// `finish_fn` specification, a caller can either resolve and build immediately with
// `.unstructured()`, or call `.structured()` and provide a schema. Doing it this way lets us
// represent the `ChatCompletionRequest` nicely without having to specify the generic type.
impl<'a, State: IsComplete> ChatCompletionBuilder<'a, State> {
    /// Turns a complete builder into a structured chat completion request whose response
    /// format is the JSON schema produced by `SS`.
    ///
    /// Unless the return type can be inferred, name the schema type explicitly:
    /// `.structured::<MySchemaType>()`.
    pub fn structured<SS: OpenAISchema>(self) -> StructuredChatCompletion<'a, SS> {
        // Finish the builder first; the plain request becomes the base of the structured one.
        let base_request = self.unstructured();
        let response_format = ChatCompletionRequestResponseFormat::JsonSchema(SS::openai_schema());

        StructuredChatCompletion {
            base_request,
            response_format,
            _phantom: std::marker::PhantomData,
        }
    }
}

/// A response from a chat completion request.
#[derive(Deserialize)]
pub struct ChatCompletionResponse {
    // Every choice carried by the response; `take_first_choice` consumes this list.
    choices: Vec<ChatCompletionResponseChoice>,
    // Usage record of the response, handed out by reference through `usage()`.
    usage: Usage,
}

impl ChatCompletionResponse {
    /// Consumes the response and returns its first choice, or `None` when the response
    /// holds no choices.
    pub fn take_first_choice(self) -> Option<ChatCompletionResponseChoice> {
        let mut remaining = self.choices.into_iter();
        remaining.next()
    }

    /// Borrows the usage tokens of the response.
    pub fn usage(&self) -> &Usage {
        let Self { usage, .. } = self;
        usage
    }
}

/// A response choice from a chat completion request.
#[derive(Deserialize)]
pub struct ChatCompletionResponseChoice {
    // Returned by value from `finish_reason()`.
    finish_reason: FinishReason,
    // Returned by `index()`.
    // NOTE(review): presumably this choice's position in the response's choice list — confirm.
    index: i32,
    // Raw message; only reachable through `message()`, which converts it to the unified
    // message type first.
    message: ChatCompletionResponseMessage,
}

impl ChatCompletionResponseChoice {
    /// Consumes the choice and yields its message as a result that may contain a refusal.
    pub fn message(self) -> OpenAIResult<String> {
        // Go through the unified message type, which owns the conversion into a result.
        let unified: UnifiedChatCompletionResponseMessage<String> = self.message.into();
        unified.into()
    }

    /// Returns the finish reason recorded for this choice.
    pub fn finish_reason(&self) -> FinishReason {
        let Self { finish_reason, .. } = *self;
        finish_reason
    }

    /// Returns the `index` recorded for this choice.
    pub fn index(&self) -> i32 {
        let Self { index, .. } = *self;
        index
    }
}

// leave private, messages should only be interacted with through the unified message type.
// Raw shape of a choice's `message` as it is deserialized from the response.
#[derive(Deserialize)]
struct ChatCompletionResponseMessage {
    // Mandatory for this type: a missing or `null` `content` makes deserialization fail.
    // NOTE(review): the OpenAI API reference documents `content` as nullable — confirm
    // whether `null` can arrive here (e.g. alongside a refusal) and needs handling.
    content: String,
    // Optional: an absent or `null` `refusal` deserializes to `None`.
    refusal: Option<String>,
}

impl From<ChatCompletionResponseMessage> for UnifiedChatCompletionResponseMessage<String> {
    fn from(value: ChatCompletionResponseMessage) -> Self {
        UnifiedChatCompletionResponseMessage {
            content: value.content,
            refusal: value.refusal,
        }
    }
}

/// Builds a `HashMap<i32, i32>` from `key: value` pairs, e.g.
/// `logit_bias!(1234: -100, 5678: 10)`.
///
/// Pairs are separated by commas and a trailing comma is allowed. Each key must be a single
/// token (a literal or an identifier) and each value an expression; both are converted with
/// `as i32`, so out-of-range integers wrap silently. When a key is repeated, the last pair
/// wins. Invoked without arguments, the macro yields an empty map of the same type.
#[macro_export]
macro_rules! logit_bias {
    () => {
        // Spelled out so that the empty form has a concrete type on its own, matching the
        // non-empty arm.
        ::std::collections::HashMap::<i32, i32>::new()
    };

    ($($key:tt : $value:expr),+ $(,)?) => {{
        // Absolute `::std` path: paths in an exported macro resolve at the call site, where a
        // local item named `std` would otherwise capture a bare `std::...`.
        let mut map = ::std::collections::HashMap::<i32, i32>::new();
        $(
            map.insert($key as i32, $value as i32);
        )+
        map
    }};
}