1use std::collections::HashMap;
2
3use httparse::{Response, EMPTY_HEADER};
4use serde::{Deserialize, Serialize};
5
6pub mod content_gen;
8pub mod format;
11pub mod schema;
/// A fully configured Gemini request: borrowed inputs, token limit, response config and
/// memory mode.
// NOTE(review): presumably the value produced by the `Gemini` builder's `build()` step and
// consumed by `output()` — confirm against the builder implementation (not visible here).
#[derive(Debug, Clone)]
pub struct GeminiContentGen<'gemini> {
    // Presumably the *name* of the environment variable holding the API key (the test
    // passes "GEMINI_API_KEY"), not the key itself — TODO confirm.
    env_variable: &'gemini str,
    // Model identifier as a string.
    model: &'gemini str,
    // Presumably the maximum number of output tokens (cf. `TokenLen`) — confirm.
    max_len: u64,
    // System instruction text.
    instruction: &'gemini str,
    // User prompt text.
    text: &'gemini str,
    // Response configuration; see `Config` / `Kind`.
    config: Config<'gemini>,
    // Memory mode for the request; see `MemoryType`.
    memory: MemoryType,
}
25
/// Output-token limit setting for a request.
///
/// Derives `Copy`/`PartialEq` so callers can pass it by value and compare settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenLen {
    /// Leave the limit at its default; the concrete value is decided by the code that
    /// consumes this setting.
    Default,
    /// Use an explicit limit. (The variant's spelling is kept for API compatibility.)
    Custome(u64),
}

/// Request configuration wrapping the selected [`Kind`] of request/response.
#[derive(Debug, Clone)]
pub struct Config<'config> {
    /// The selected kind, together with any borrowed payload it carries.
    pub response: Kind<'config>,
}
35
/// The kind of request being made, carrying any borrowed payload it needs.
// NOTE(review): the exact meaning of each payload is defined by the code that consumes
// `Kind`, which is not visible here. `&[u8]` would be the more idiomatic payload type than
// `&Vec<u8>`, but changing it would break callers that pattern-match on these variants.
#[derive(Debug, Clone)]
pub enum Kind<'response> {
    // JSON output; the string is presumably a response schema or format spec — confirm.
    Json(&'response str),
    // Plain text, no payload.
    Text,
    // Raw audio bytes.
    Audio(&'response Vec<u8>),
    // Raw audio bytes to transcribe (presumably audio — confirm).
    Transcribe(&'response Vec<u8>),
    // Raw image bytes.
    Image(&'response Vec<u8>),
    // Raw video bytes.
    Video(&'response Vec<u8>),
    // Raw PDF bytes.
    Pdf(&'response Vec<u8>),
    // Raw CSV bytes.
    Csv(&'response Vec<u8>),
    // Retrieval-augmented request over a borrowed list of strings (presumably
    // documents/snippets — verify against the consumer).
    Rag(&'response [&'response str]),
}
48
/// Type-state builder for a Gemini request.
///
/// Every type parameter after the lifetime is a zero-sized marker held only through
/// `PhantomData` (or, for `PropertiesState`, through [`ConfigBuilder`]); presumably these
/// are the `*Present` / `*NotPresent` unit structs of this module, recording at the type
/// level which settings have been supplied.
///
/// Note: `#[derive(Debug)]` adds a `Debug` bound on *every* type parameter, so a `Gemini`
/// value is only `Debug` when all of its current state markers implement `Debug`.
#[derive(Debug)]
pub struct Gemini<
    'gemini,
    EnvState,
    ModelState,
    ConfigState,
    InstructionState,
    TextState,
    MaxState,
    PropertiesState,
    MemoryState,
> {
    // Presumably the name of the API-key environment variable — confirm.
    env_variable: &'gemini str,
    model: &'gemini str,
    instruction: &'gemini str,
    // Presumably the output-token limit — confirm.
    max_len: u64,
    text: &'gemini str,
    memory: MemoryType,
    // Response configuration, carrying the `PropertiesState` marker.
    config: ConfigBuilder<'gemini, PropertiesState>,
    // Type-level state only; these fields hold no data at runtime.
    envstate: std::marker::PhantomData<EnvState>,
    modelstate: std::marker::PhantomData<ModelState>,
    configstate: std::marker::PhantomData<ConfigState>,
    maxstate: std::marker::PhantomData<MaxState>,
    instructionstate: std::marker::PhantomData<InstructionState>,
    textstate: std::marker::PhantomData<TextState>,
    memorystate: std::marker::PhantomData<MemoryState>,
}
76
/// Builder state for a request's configuration: the selected [`Kind`] plus a
/// `PropertiesState` type-state marker (presumably `PropertiesPresent` /
/// `PropertiesNotPresent` — confirm).
#[derive(Debug)]
pub struct ConfigBuilder<'config, PropertiesState> {
    // The selected request/response kind.
    r#type: Kind<'config>,
    // Type-level state only; holds no data at runtime.
    propertiesstate: std::marker::PhantomData<PropertiesState>,
}
82
/// A named, typed property that may contain nested properties.
///
/// Derives `Clone`/`PartialEq`/`Eq` so property trees can be copied and compared (e.g.
/// in tests) instead of only printed.
// NOTE(review): presumably describes a JSON-schema property (`key` = field name,
// `r#type` = schema type name) — confirm against the `schema` module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Properties {
    /// Property name.
    pub key: String,
    /// Property type name.
    pub r#type: String,
    /// Child properties, if any.
    pub nested: Option<Vec<Properties>>,
}
89
/// Gemini models selectable for a request.
///
/// `Custom` carries an arbitrary model identifier for models without a dedicated variant.
/// The string actually sent to the API for each variant is decided elsewhere.
// The upper-case variant names mirror the model names; renaming them would break every
// caller, so the naming lint is silenced here instead.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Models<'model> {
    GEMINI_1_5_FLASH,
    GEMINI_1_5_PRO_002,
    GEMINI_1_5_PRO,
    GEMINI_1_5_FLASH_002,
    GEMINI_1_5_FLASH_8B,
    GEMINI_1_0_PRO,
    /// Any other model, identified by its raw name.
    Custom(&'model str),
}
100
// Type-state markers (used, presumably, as the `*State` parameters of `Gemini`), in
// present/absent pairs. Every marker derives the same traits: `Gemini` derives `Debug`,
// which requires `Debug` on each state parameter, so a marker without `Debug` would make
// the builder un-printable in that state.

/// Type-state marker: model supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelPresent;
/// Type-state marker: model not supplied yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelNotPresent;

/// Type-state marker: environment-variable setting supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EnvVariablePresent;
/// Type-state marker: environment-variable setting not supplied yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EnvVariableNotPresent;

/// Type-state marker: prompt text supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TextPresent;
/// Type-state marker: prompt text not supplied yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TextNotPresent;

/// Type-state marker: config supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConfigPresent;
/// Type-state marker: config not supplied yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConfigNotPresent;

/// Type-state marker: properties supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PropertiesPresent;
/// Type-state marker: properties not supplied yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PropertiesNotPresent;
121
/// Type-state marker — presumably "memory enabled" for the builder's memory state; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Memory;

/// Type-state marker — presumably the default (no-memory) state; confirm.
///
/// NOTE: this type shadows the prelude `Default` trait (type/value namespace) everywhere
/// in this module, so `Default::default()` written here resolves to this struct. Use
/// `T::default()` or `std::default::Default::default()` instead. `#[derive(Default)]` is
/// unaffected because derive macros live in a separate namespace.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Default;
125
/// Memory setting for a request: either a [`Memorys`] backend, or no memory at all.
#[derive(Debug, Clone)]
pub enum MemoryType {
    // Memory enabled with the given backend.
    Memory(Memorys),
    // Memory disabled.
    NoMemory,
}
131
/// Memory backend kinds selectable through [`MemoryType::Memory`].
// NOTE(review): how each backend persists data (file path, JSON layout) is decided by
// code not visible here — confirm before relying on a specific format.
#[derive(Debug, Clone, Copy)]
pub enum Memorys {
    // File-based backend.
    File,
    // JSON-based backend.
    Json,
}
138
139#[derive(Serialize, Deserialize, Debug)]
140pub struct Candidate {
141 pub content: Content,
142 finishReason: String,
143 avgLogprobs: f64,
144}
145
/// Content of a candidate: an ordered list of parts.
///
/// Unknown JSON keys (such as a `role` field) are ignored during deserialization.
#[derive(Serialize, Deserialize, Debug)]
pub struct Content {
    /// The parts making up this content. Required: content without a `parts` array
    /// fails to deserialize.
    pub parts: Vec<Part>,
}
150
/// A single part of generated content; only the text form is modelled.
// NOTE(review): the Gemini API can also return non-text parts (e.g. function calls or
// inline data). Those have no `text` key and would make deserialization fail — confirm
// whether that matters for the requests this crate sends.
#[derive(Serialize, Deserialize, Debug)]
pub struct Part {
    /// The generated text (required).
    pub text: String,
}
155
156#[derive(Serialize, Deserialize, Debug)]
157pub struct UsageMetadata {
158 promptTokenCount: u32,
159 candidatesTokenCount: u32,
160 totalTokenCount: u32,
161}
162
163#[derive(Serialize, Deserialize, Debug)]
164pub struct Responses {
165 pub candidates: Vec<Candidate>,
166 usageMetadata: UsageMetadata,
167 modelVersion: String,
168}
169
170pub fn decode_gemini(raw_response: &str) -> Result<Responses, Box<dyn std::error::Error>> {
173 let raw_bytes = raw_response.as_bytes();
175
176 let mut headers_buf = [EMPTY_HEADER; 64]; let mut res = Response::new(&mut headers_buf);
178
179 let _ = res.parse(raw_bytes)?;
180
181 let code = res.code.unwrap_or(400); let reason = res.reason.unwrap_or("");
183 let parsed_len = res.parse(raw_bytes)?.unwrap();
187 let body_bytes = &raw_bytes[parsed_len..];
188
189 let mut headers_map = HashMap::new();
190 for h in res.headers {
191 let name = h.name.to_lowercase(); let value = String::from_utf8_lossy(h.value).to_string();
193 headers_map.insert(name, value);
194 }
195
196 let transfer_encoding = headers_map
197 .get("transfer-encoding")
198 .unwrap_or(&String::new())
199 .to_lowercase();
200
201 let decoded_body = if transfer_encoding.contains("chunked") {
202 let mut decoder = chunked_transfer::Decoder::new(body_bytes);
203 let mut buf = Vec::new();
204 std::io::Read::read_to_end(&mut decoder, &mut buf)?;
205 buf
206 } else {
207 body_bytes.to_vec()
208 };
209
210 let body_str = String::from_utf8_lossy(&decoded_body);
211
212 let responses: Responses = serde_json::from_str(&body_str)?;
214 Ok(responses)
215}
216
/// A borrowed key/type pair.
///
/// Derives `Debug` (previously missing on this public type) plus cheap `Copy`/equality,
/// since it only holds two string slices.
// NOTE(review): presumably a property name and its type name — a borrowed, non-nested
// counterpart of `Properties`; confirm against the code that consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Pair<'key> {
    /// Key (name).
    pub key: &'key str,
    /// Type name associated with `key`.
    pub r#type: &'key str,
}
221
// Further type-state markers in present/absent pairs. They derive the same traits as the
// other markers, so any builder parameterised by them can still derive `Debug`.

/// Type-state marker: training data supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TrainPresent;
/// Type-state marker: training data not supplied yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TrainNotPresent;

/// Type-state marker: instruction supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InstructionPresent;
/// Type-state marker: instruction not supplied yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InstructionNotPresent;

/// Type-state marker: "tell" input supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TellPresent;
/// Type-state marker: "tell" input not supplied yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TellNotPresent;

/// Type-state marker: maximum length supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxLenPresent;
/// Type-state marker: maximum length not supplied yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxLenNotPresent;

/// Type-state marker: memory setting chosen.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryOK;
/// Type-state marker: memory setting not chosen yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryNot;
236
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal successful `generateContent` body containing every field `Responses`
    /// requires.
    const SAMPLE_BODY: &str = concat!(
        r#"{"candidates":[{"content":{"parts":[{"text":"Riga"}],"role":"model"},"#,
        r#""finishReason":"STOP","avgLogprobs":-0.05}],"#,
        r#""usageMetadata":{"promptTokenCount":12,"candidatesTokenCount":1,"totalTokenCount":13},"#,
        r#""modelVersion":"gemini-1.5-flash"}"#
    );

    /// Wraps `body` in a minimal `200 OK` response with a `Content-Length` header.
    fn plain_response(body: &str) -> String {
        format!(
            "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
            body.len(),
            body
        )
    }

    #[test]
    fn decodes_plain_body() {
        let responses =
            decode_gemini(&plain_response(SAMPLE_BODY)).expect("valid response should decode");
        assert_eq!(responses.candidates.len(), 1);
        assert_eq!(responses.candidates[0].content.parts[0].text, "Riga");
    }

    #[test]
    fn decodes_chunked_body() {
        // Split the body into two chunks to exercise the de-chunking path.
        let (first, second) = SAMPLE_BODY.split_at(SAMPLE_BODY.len() / 2);
        let raw = format!(
            "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n{:x}\r\n{}\r\n{:x}\r\n{}\r\n0\r\n\r\n",
            first.len(),
            first,
            second.len(),
            second
        );
        let responses = decode_gemini(&raw).expect("valid chunked response should decode");
        assert_eq!(responses.candidates[0].content.parts[0].text, "Riga");
    }

    #[test]
    fn rejects_error_status() {
        let raw = "HTTP/1.1 400 Bad Request\r\nContent-Length: 2\r\n\r\n{}";
        assert!(decode_gemini(raw).is_err());
    }

    /// End-to-end smoke test: builds a text request, runs it and decodes the raw HTTP reply.
    ///
    /// Ignored by default because it reads `GEMINI_API_KEY` and performs a live API call;
    /// run it explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires GEMINI_API_KEY and a live Gemini API call"]
    fn test_text() {
        let raw_response = Gemini::new()
            .env("GEMINI_API_KEY")
            .model(Models::GEMINI_1_5_FLASH)
            .no_memory()
            .kind(Kind::Text)
            .instruction("You are an unhelpful assistant")
            .text("What is the capital of Latvia?")
            .max_token(TokenLen::Default)
            .build()
            .output();
        let result = decode_gemini(&raw_response);

        assert!(
            result.is_ok(),
            "failed to decode Gemini response: {:?}\nraw response: {:?}",
            result,
            raw_response
        );
    }
}