kind_openai/
endpoints.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use reqwest::Method;
use serde::{Deserialize, Serialize};

use crate::{auth, error::OpenAIAPIError, OpenAI, OpenAIResult};

pub mod chat;
pub mod chat_reasoning;
pub mod embeddings;

/// Root URL of the OpenAI REST API. Each request's `path_with_leading_slash()`
/// is appended directly to it, so it deliberately has no trailing slash.
const API_BASE_URL: &str = "https://api.openai.com/v1";

/// Wire-level envelope for an OpenAI response body.
///
/// OpenAI answers either with the success payload itself or with an object
/// holding a single `error` field. `#[serde(untagged)]` tries the variants in
/// declaration order, and that order matters here.
///
/// - `Error` is tried first because it only matches when an `error` key is
///   present (the field is not optional).
/// - `Success(T)` tried first could accept an error body outright whenever `T`
///   has no required fields, since serde ignores unknown keys by default. That
///   would turn an API error into an empty "success".
///
/// NOTE(review): this assumes success payloads never carry a non-null `error`
/// object that also parses as `OpenAIAPIError` — confirm for new endpoints.
#[derive(Deserialize)]
#[serde(untagged)]
enum GenericOpenAIResponse<T> {
    Error(ResponseDeserializableOpenAIAPIError),
    Success(T),
}

/// The `{"error": ...}` wrapper OpenAI places around its error details.
#[derive(Deserialize)]
struct ResponseDeserializableOpenAIAPIError {
    error: OpenAIAPIError,
}

impl<T> From<GenericOpenAIResponse<T>> for OpenAIResult<T> {
    fn from(value: GenericOpenAIResponse<T>) -> Self {
        match value {
            GenericOpenAIResponse::Success(success) => Ok(success),
            GenericOpenAIResponse::Error(error) => Err(crate::OpenAIError::API(error.error)),
        }
    }
}

/// Sends `request` to its OpenAI endpoint and decodes the reply.
///
/// The request is built as follows:
///
/// - URL: `API_BASE_URL` followed by `R::path_with_leading_slash()`.
/// - HTTP method: `R::METHOD`.
/// - Body: `request` serialized as JSON.
///
/// The response body is read as text first, so that a body matching neither
/// the success shape nor the `{"error": ...}` envelope can be returned
/// verbatim inside the error.
///
/// # Errors
///
/// - `OpenAIError::MissingAuthToken` if the auth provider resolves to `None`.
/// - Any transport error from sending the request or reading its body
///   (propagated with `?`).
/// - `OpenAIError::API` when the body is an OpenAI error envelope.
/// - `OpenAIError::Serde(body, err)` when the body cannot be decoded at all.
pub(super) async fn send_request<Auth, R>(
    openai: &OpenAI<Auth>,
    request: &R,
) -> OpenAIResult<R::Response>
where
    Auth: auth::AuthTokenProvider,
    R: OpenAIRequestProvider,
{
    let bearer_token = openai
        .auth
        .resolve()
        .await
        .ok_or(crate::error::OpenAIError::MissingAuthToken)?;

    let url = format!("{API_BASE_URL}{}", R::path_with_leading_slash());

    // take the response text and deserialize by hand so we can log response
    // bodies that don't conform to the same structure
    let response_text = openai
        .client
        .request(R::METHOD, url)
        // `bearer_auth` sends the same `Authorization: Bearer <token>` header as
        // a hand-built one. Current reqwest versions also mark it as sensitive,
        // which keeps the token out of `Debug` output.
        .bearer_auth(bearer_token)
        // TODO: support a way to omit the body during a get request if the time comes
        .json(request)
        .send()
        .await?
        .text()
        .await?;

    match serde_json::from_str::<GenericOpenAIResponse<R::Response>>(&response_text) {
        Ok(response) => response.into(),
        Err(err) => Err(crate::error::OpenAIError::Serde(response_text, err)),
    }
}

mod private {
    /// Supertrait of `OpenAIRequestProvider`.
    ///
    /// This module is private, so code outside this module (and its child
    /// modules) cannot name `Sealed`. It therefore cannot implement
    /// `OpenAIRequestProvider` either.
    pub trait Sealed {}
}

/// Any type that can be sent to the client's `req` method.
///
/// The implementing value itself is serialized as the JSON request body. The
/// trait is sealed through `private::Sealed`, so the set of request types is
/// controlled by this crate.
pub trait OpenAIRequestProvider: Serialize + private::Sealed {
    /// HTTP method this request is sent with.
    const METHOD: Method;

    /// Type that a successful response body is deserialized into.
    type Response: for<'de> Deserialize<'de>;

    /// Endpoint path appended directly to the API base URL.
    ///
    /// It must therefore start with `/`.
    fn path_with_leading_slash() -> String;
}