gemini_ai/
lib.rs

1use std::collections::HashMap;
2
3use httparse::{Response, EMPTY_HEADER};
4use serde::{Deserialize, Serialize};
5
6// pub mod cloud;
7pub mod content_gen;
8// pub mod file;
9// pub mod error;
10pub mod format;
11// pub mod pulse;
12pub mod schema;
13// pub mod tunemodel;
14
15#[derive(Debug, Clone)]
16pub struct GeminiContentGen<'gemini> {
17    env_variable: &'gemini str,
18    model: &'gemini str,
19    max_len: u64,
20    instruction: &'gemini str,
21    text: &'gemini str,
22    config: Config<'gemini>,
23    memory: MemoryType,
24}
25
/// Requested maximum output length.
///
/// `Custome` is a misspelling of "Custom", but it is public API and kept as-is.
#[derive(Debug)]
pub enum TokenLen {
    /// Use the library's default length.
    Default,
    /// Use the given explicit length.
    Custome(u64),
}
/// Configuration attached to a content-generation request.
#[derive(Debug, Clone)]
pub struct Config<'a> {
    /// Kind of request/response, together with any payload it borrows.
    pub response: Kind<'a>,
}

/// Request/response kind.
///
/// Every variant except `Text` borrows its payload from the caller for `'a`.
#[derive(Debug, Clone)]
pub enum Kind<'a> {
    /// JSON output; the string is presumably the schema or a JSON template — confirm.
    Json(&'a str),
    /// Plain text; carries no payload.
    Text,
    /// Raw audio bytes.
    Audio(&'a Vec<u8>),
    /// Raw audio bytes to transcribe.
    Transcribe(&'a Vec<u8>),
    /// Raw image bytes.
    Image(&'a Vec<u8>),
    /// Raw video bytes.
    Video(&'a Vec<u8>),
    /// Raw PDF bytes.
    Pdf(&'a Vec<u8>),
    /// Raw CSV bytes.
    Csv(&'a Vec<u8>),
    /// A list of string snippets (named `Rag`; presumably retrieval context).
    Rag(&'a [&'a str]),
}
48
49#[derive(Debug)]
50pub struct Gemini<
51    'gemini,
52    EnvState,
53    ModelState,
54    ConfigState,
55    InstructionState,
56    TextState,
57    MaxState,
58    PropertiesState,
59    MemoryState,
60> {
61    env_variable: &'gemini str,
62    model: &'gemini str,
63    instruction: &'gemini str,
64    max_len: u64,
65    text: &'gemini str,
66    memory: MemoryType,
67    config: ConfigBuilder<'gemini, PropertiesState>,
68    envstate: std::marker::PhantomData<EnvState>,
69    modelstate: std::marker::PhantomData<ModelState>,
70    configstate: std::marker::PhantomData<ConfigState>,
71    maxstate: std::marker::PhantomData<MaxState>,
72    instructionstate: std::marker::PhantomData<InstructionState>,
73    textstate: std::marker::PhantomData<TextState>,
74    memorystate: std::marker::PhantomData<MemoryState>,
75}
76
77#[derive(Debug)]
78pub struct ConfigBuilder<'config, PropertiesState> {
79    r#type: Kind<'config>,
80    propertiesstate: std::marker::PhantomData<PropertiesState>,
81}
82
/// A named, typed property, with optional nested child properties.
///
/// The shape suggests a JSON-schema-like description (key + type name +
/// children); NOTE(review): confirm against the `schema` module.
#[derive(Debug)]
pub struct Properties {
    /// Property name.
    pub key: String,
    /// Type name of the property, as a string.
    pub r#type: String,
    /// Child properties, if this property has any.
    pub nested: Option<Vec<Properties>>,
}
89
/// Model to target.
///
/// Each named variant stands for the model whose name it spells. `Custom`
/// carries an arbitrary model identifier string. NOTE(review): the mapping from
/// variants to wire identifiers is not in this file.
// The variant names are public API (e.g. `Models::GEMINI_1_5_FLASH`), so they
// keep their SCREAMING_CASE spelling. The lint is silenced here rather than
// renaming the variants and breaking callers.
#[allow(non_camel_case_types)]
#[derive(Debug)]
pub enum Models<'model> {
    GEMINI_1_5_FLASH,
    GEMINI_1_5_PRO_002,
    GEMINI_1_5_PRO,
    GEMINI_1_5_FLASH_002,
    GEMINI_1_5_FLASH_8B,
    GEMINI_1_0_PRO,
    /// Any other model, given by its identifier string.
    Custom(&'model str),
}
100
// Zero-sized marker types. Judging by their names, they fill the type-state
// parameters of `Gemini` (e.g. `ModelState`); confirm against the builder impls.
//
// Every marker derives `Debug`. `#[derive(Debug)]` on `Gemini`/`ConfigBuilder`
// requires each type parameter to be `Debug`. Previously only the first struct
// of each pair derived it, which made the builder un-printable in any
// "NotPresent" state.

/// Marker: model set.
#[derive(Debug)]
pub struct ModelPresent;
/// Marker: model not set.
#[derive(Debug)]
pub struct ModelNotPresent;

/// Marker: environment variable set.
#[derive(Debug)]
pub struct EnvVariablePresent;
/// Marker: environment variable not set.
#[derive(Debug)]
pub struct EnvVariableNotPresent;

/// Marker: text set.
#[derive(Debug)]
pub struct TextPresent;
/// Marker: text not set.
#[derive(Debug)]
pub struct TextNotPresent;

/// Marker: config set.
#[derive(Debug)]
pub struct ConfigPresent;
/// Marker: config not set.
#[derive(Debug)]
pub struct ConfigNotPresent;

/// Marker: properties set.
#[derive(Debug)]
pub struct PropertiesPresent;
/// Marker: properties not set.
#[derive(Debug)]
pub struct PropertiesNotPresent;

/// Marker type named `Memory`; its role is not visible in this file.
#[derive(Debug)]
pub struct Memory;

/// Marker type named `Default`.
///
/// Caution: declared at the crate root, this name shadows the prelude
/// `Default` trait for code in this file. Spell the trait
/// `std::default::Default` here.
#[derive(Debug)]
pub struct Default;
125
/// Whether a request uses memory, and if so which backend.
#[derive(Debug, Clone)]
pub enum MemoryType {
    /// Memory enabled, using the given backend.
    Memory(Memorys),
    /// Memory disabled.
    NoMemory,
}

/// Available memory backends.
#[derive(Debug, Clone, Copy)]
pub enum Memorys {
    /// File-based backend.
    File,
    /// JSON-based backend.
    Json,
    // TODO: a `Database` backend was sketched here but is not implemented.
}
138
139#[derive(Serialize, Deserialize, Debug)]
140pub struct Candidate {
141    pub content: Content,
142    finishReason: String,
143    avgLogprobs: f64,
144}
145
146#[derive(Serialize, Deserialize, Debug)]
147pub struct Content {
148    pub parts: Vec<Part>,
149}
150
151#[derive(Serialize, Deserialize, Debug)]
152pub struct Part {
153    pub text: String,
154}
155
156#[derive(Serialize, Deserialize, Debug)]
157pub struct UsageMetadata {
158    promptTokenCount: u32,
159    candidatesTokenCount: u32,
160    totalTokenCount: u32,
161}
162
163#[derive(Serialize, Deserialize, Debug)]
164pub struct Responses {
165    pub candidates: Vec<Candidate>,
166    usageMetadata: UsageMetadata,
167    modelVersion: String,
168}
169
170/// Take a HTTP response and decode it into a strongly-typed struct.
171/// Assumes full raw response in the argument.
172pub fn decode_gemini(raw_response: &str) -> Result<Responses, Box<dyn std::error::Error>> {
173    // Convert to bytes for httparse
174    let raw_bytes = raw_response.as_bytes();
175
176    let mut headers_buf = [EMPTY_HEADER; 64]; // Increase if you need more
177    let mut res = Response::new(&mut headers_buf);
178
179    let _ = res.parse(raw_bytes)?;
180
181    let code = res.code.unwrap_or(400); // e.g. 200
182    let reason = res.reason.unwrap_or("");
183    // dbg!("Status: {} {}", code, reason);
184
185    // Find where the headers ended.
186    let parsed_len = res.parse(raw_bytes)?.unwrap();
187    let body_bytes = &raw_bytes[parsed_len..];
188
189    let mut headers_map = HashMap::new();
190    for h in res.headers {
191        let name = h.name.to_lowercase(); // often normalized
192        let value = String::from_utf8_lossy(h.value).to_string();
193        headers_map.insert(name, value);
194    }
195
196    let transfer_encoding = headers_map
197        .get("transfer-encoding")
198        .unwrap_or(&String::new())
199        .to_lowercase();
200
201    let decoded_body = if transfer_encoding.contains("chunked") {
202        let mut decoder = chunked_transfer::Decoder::new(body_bytes);
203        let mut buf = Vec::new();
204        std::io::Read::read_to_end(&mut decoder, &mut buf)?;
205        buf
206    } else {
207        body_bytes.to_vec()
208    };
209
210    let body_str = String::from_utf8_lossy(&decoded_body);
211
212    // TODO: Make error handling less ugly
213    let responses: Responses = serde_json::from_str(&body_str)?;
214    Ok(responses)
215}
216
/// A borrowed `(key, type name)` pair.
pub struct Pair<'a> {
    /// The key.
    pub key: &'a str,
    /// The associated type name.
    pub r#type: &'a str,
}
221
// More zero-sized type-state markers; judging by their names they fill
// `Gemini`'s state parameters (e.g. `InstructionState`, `MaxState`,
// `MemoryState`); confirm against the builder impls. They derive `Debug`
// because `Gemini`'s derived `Debug` requires every state parameter to be
// `Debug`. Without it, a builder in these states cannot be printed.

/// Marker: training data set.
#[derive(Debug)]
pub struct TrainPresent;
/// Marker: training data not set.
#[derive(Debug)]
pub struct TrainNotPresent;

/// Marker: instruction set.
#[derive(Debug)]
pub struct InstructionPresent;
/// Marker: instruction not set.
#[derive(Debug)]
pub struct InstructionNotPresent;

/// Marker: "tell" step done.
#[derive(Debug)]
pub struct TellPresent;
/// Marker: "tell" step not done.
#[derive(Debug)]
pub struct TellNotPresent;

/// Marker: maximum length set.
#[derive(Debug)]
pub struct MaxLenPresent;
/// Marker: maximum length not set.
#[derive(Debug)]
pub struct MaxLenNotPresent;

/// Marker: memory choice made.
#[derive(Debug)]
pub struct MemoryOK;
/// Marker: memory choice not made.
#[derive(Debug)]
pub struct MemoryNot;
236
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test: builds a text request, sends it, and checks that
    /// the raw HTTP response decodes into `Responses` with at least one candidate.
    ///
    /// The test talks to the live service and presumably reads the API key from
    /// `GEMINI_API_KEY`, so it is ignored by default. Run it explicitly with
    /// `cargo test -- --ignored`.
    #[test]
    #[ignore = "calls the live Gemini API; needs network access and GEMINI_API_KEY"]
    fn test_text() {
        let raw = Gemini::new()
            .env("GEMINI_API_KEY")
            .model(Models::GEMINI_1_5_FLASH)
            .no_memory()
            .kind(Kind::Text)
            .instruction("You are an unhelpful assistant")
            .text("What is the capital of Latvia?")
            .max_token(TokenLen::Default)
            .build()
            .output();

        // Surface the actual decode error and the raw response instead of a bare
        // `is_ok()` failure.
        let responses = decode_gemini(&raw).unwrap_or_else(|err| {
            panic!("failed to decode Gemini response: {}\nraw response:\n{}", err, raw)
        });
        assert!(!responses.candidates.is_empty(), "response contained no candidates");
    }
}