1use super::MySqlStream;
2use crate::connection::stream::Waiting;
3use crate::describe::Describe;
4use crate::error::Error;
5use crate::executor::{Execute, Executor};
6use crate::ext::ustr::UStr;
7use crate::io::MySqlBufExt;
8use crate::logger::QueryLogger;
9use crate::protocol::response::Status;
10use crate::protocol::statement::{
11 BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, StmtClose,
12};
13use crate::protocol::text::{ColumnDefinition, ColumnFlags, Query, TextRow};
14use crate::statement::{MySqlStatement, MySqlStatementMetadata};
15use crate::HashMap;
16use crate::{
17 MySql, MySqlArguments, MySqlColumn, MySqlConnection, MySqlQueryResult, MySqlRow, MySqlTypeInfo,
18 MySqlValueFormat,
19};
20use either::Either;
21use futures_core::future::BoxFuture;
22use futures_core::stream::BoxStream;
23use futures_core::Stream;
24use futures_util::TryStreamExt;
25use std::{borrow::Cow, pin::pin, sync::Arc};
26
27impl MySqlConnection {
28 async fn prepare_statement<'c>(
29 &mut self,
30 sql: &str,
31 ) -> Result<(u32, MySqlStatementMetadata), Error> {
32 self.inner
36 .stream
37 .send_packet(Prepare { query: sql })
38 .await?;
39
40 let ok: PrepareOk = self.inner.stream.recv().await?;
41
42 if ok.params > 0 {
46 for _ in 0..ok.params {
47 let _def: ColumnDefinition = self.inner.stream.recv().await?;
48 }
49
50 self.inner.stream.maybe_recv_eof().await?;
51 }
52
53 let mut columns = Vec::new();
58
59 let column_names = if ok.columns > 0 {
60 recv_result_metadata(&mut self.inner.stream, ok.columns as usize, &mut columns).await?
61 } else {
62 Default::default()
63 };
64
65 let id = ok.statement_id;
66 let metadata = MySqlStatementMetadata {
67 parameters: ok.params as usize,
68 columns: Arc::new(columns),
69 column_names: Arc::new(column_names),
70 };
71
72 Ok((id, metadata))
73 }
74
75 async fn get_or_prepare_statement<'c>(
76 &mut self,
77 sql: &str,
78 ) -> Result<(u32, MySqlStatementMetadata), Error> {
79 if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
80 return Ok((*statement).clone());
82 }
83
84 let (id, metadata) = self.prepare_statement(sql).await?;
85
86 if let Some((id, _)) = self
88 .inner
89 .cache_statement
90 .insert(sql, (id, metadata.clone()))
91 {
92 self.inner
93 .stream
94 .send_packet(StmtClose { statement: id })
95 .await?;
96 }
97
98 Ok((id, metadata))
99 }
100
101 #[allow(clippy::needless_lifetimes)]
102 pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
103 &'c mut self,
104 sql: &'q str,
105 arguments: Option<MySqlArguments>,
106 persistent: bool,
107 ) -> Result<impl Stream<Item = Result<Either<MySqlQueryResult, MySqlRow>, Error>> + 'e, Error>
108 {
109 let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone());
110
111 self.inner.stream.wait_until_ready().await?;
112 self.inner.stream.waiting.push_back(Waiting::Result);
113
114 Ok(try_stream! {
115 let mut columns = Arc::new(Vec::new());
119
120 let (mut column_names, format, mut needs_metadata) = if let Some(arguments) = arguments {
121 if persistent && self.inner.cache_statement.is_enabled() {
122 let (id, metadata) = self
123 .get_or_prepare_statement(sql)
124 .await?;
125
126 self.inner.stream
128 .send_packet(StatementExecute {
129 statement: id,
130 arguments: &arguments,
131 })
132 .await?;
133
134 (metadata.column_names, MySqlValueFormat::Binary, false)
135 } else {
136 let (id, metadata) = self
137 .prepare_statement(sql)
138 .await?;
139
140 self.inner.stream
142 .send_packet(StatementExecute {
143 statement: id,
144 arguments: &arguments,
145 })
146 .await?;
147
148 self.inner.stream.send_packet(StmtClose { statement: id }).await?;
149
150 (metadata.column_names, MySqlValueFormat::Binary, false)
151 }
152 } else {
153 self.inner.stream.send_packet(Query(sql)).await?;
155
156 (Arc::default(), MySqlValueFormat::Text, true)
157 };
158
159 loop {
160 let mut packet = self.inner.stream.recv_packet().await?;
163
164 if packet[0] == 0x00 || packet[0] == 0xff {
165 let ok = packet.ok()?;
168
169 self.inner.status_flags = ok.status;
170
171 let rows_affected = ok.affected_rows;
172 logger.increase_rows_affected(rows_affected);
173 let done = MySqlQueryResult {
174 rows_affected,
175 last_insert_id: ok.last_insert_id,
176 };
177
178 r#yield!(Either::Left(done));
179
180 if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
181 continue;
183 }
184
185 self.inner.stream.waiting.pop_front();
186 return Ok(());
187 }
188
189 *self.inner.stream.waiting.front_mut().unwrap() = Waiting::Row;
191
192 let num_columns = packet.get_uint_lenenc(); let num_columns = usize::try_from(num_columns)
194 .map_err(|_| err_protocol!("column count overflows usize: {num_columns}"))?;
195
196 if needs_metadata {
197 column_names = Arc::new(recv_result_metadata(&mut self.inner.stream, num_columns, Arc::make_mut(&mut columns)).await?);
198 } else {
199 needs_metadata = true;
202
203 recv_result_columns(&mut self.inner.stream, num_columns, Arc::make_mut(&mut columns)).await?;
204 }
205
206 loop {
208 let packet = self.inner.stream.recv_packet().await?;
209
210 if packet[0] == 0xfe && packet.len() < 9 {
211 let eof = packet.eof(self.inner.stream.capabilities)?;
212
213 self.inner.status_flags = eof.status;
214
215 r#yield!(Either::Left(MySqlQueryResult {
216 rows_affected: 0,
217 last_insert_id: 0,
218 }));
219
220 if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
221 *self.inner.stream.waiting.front_mut().unwrap() = Waiting::Result;
223 break;
224 }
225
226 self.inner.stream.waiting.pop_front();
227 return Ok(());
228 }
229
230 let row = match format {
231 MySqlValueFormat::Binary => packet.decode_with::<BinaryRow, _>(&columns)?.0,
232 MySqlValueFormat::Text => packet.decode_with::<TextRow, _>(&columns)?.0,
233 };
234
235 let v = Either::Right(MySqlRow {
236 row,
237 format,
238 columns: Arc::clone(&columns),
239 column_names: Arc::clone(&column_names),
240 });
241
242 logger.increment_rows_returned();
243
244 r#yield!(v);
245 }
246 }
247 })
248 }
249}
250
251impl<'c> Executor<'c> for &'c mut MySqlConnection {
252 type Database = MySql;
253
254 fn fetch_many<'e, 'q, E>(
255 self,
256 mut query: E,
257 ) -> BoxStream<'e, Result<Either<MySqlQueryResult, MySqlRow>, Error>>
258 where
259 'c: 'e,
260 E: Execute<'q, Self::Database>,
261 'q: 'e,
262 E: 'q,
263 {
264 let sql = query.sql();
265 let arguments = query.take_arguments().map_err(Error::Encode);
266 let persistent = query.persistent();
267
268 Box::pin(try_stream! {
269 let arguments = arguments?;
270 let mut s = pin!(self.run(sql, arguments, persistent).await?);
271
272 while let Some(v) = s.try_next().await? {
273 r#yield!(v);
274 }
275
276 Ok(())
277 })
278 }
279
280 fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<MySqlRow>, Error>>
281 where
282 'c: 'e,
283 E: Execute<'q, Self::Database>,
284 'q: 'e,
285 E: 'q,
286 {
287 let mut s = self.fetch_many(query);
288
289 Box::pin(async move {
290 while let Some(v) = s.try_next().await? {
291 if let Either::Right(r) = v {
292 return Ok(Some(r));
293 }
294 }
295
296 Ok(None)
297 })
298 }
299
300 fn prepare_with<'e, 'q: 'e>(
301 self,
302 sql: &'q str,
303 _parameters: &'e [MySqlTypeInfo],
304 ) -> BoxFuture<'e, Result<MySqlStatement<'q>, Error>>
305 where
306 'c: 'e,
307 {
308 Box::pin(async move {
309 self.inner.stream.wait_until_ready().await?;
310
311 let metadata = if self.inner.cache_statement.is_enabled() {
312 self.get_or_prepare_statement(sql).await?.1
313 } else {
314 let (id, metadata) = self.prepare_statement(sql).await?;
315
316 self.inner
317 .stream
318 .send_packet(StmtClose { statement: id })
319 .await?;
320
321 metadata
322 };
323
324 Ok(MySqlStatement {
325 sql: Cow::Borrowed(sql),
326 metadata: metadata.clone(),
328 })
329 })
330 }
331
332 #[doc(hidden)]
333 fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result<Describe<MySql>, Error>>
334 where
335 'c: 'e,
336 {
337 Box::pin(async move {
338 self.inner.stream.wait_until_ready().await?;
339
340 let (id, metadata) = self.prepare_statement(sql).await?;
341
342 self.inner
343 .stream
344 .send_packet(StmtClose { statement: id })
345 .await?;
346
347 let columns = (*metadata.columns).clone();
348
349 let nullable = columns
350 .iter()
351 .map(|col| {
352 col.flags
353 .map(|flags| !flags.contains(ColumnFlags::NOT_NULL))
354 })
355 .collect();
356
357 Ok(Describe {
358 parameters: Some(Either::Right(metadata.parameters)),
359 columns,
360 nullable,
361 })
362 })
363 }
364}
365
366async fn recv_result_columns(
367 stream: &mut MySqlStream,
368 num_columns: usize,
369 columns: &mut Vec<MySqlColumn>,
370) -> Result<(), Error> {
371 columns.clear();
372 columns.reserve(num_columns);
373
374 for ordinal in 0..num_columns {
375 columns.push(recv_next_result_column(&stream.recv().await?, ordinal)?);
376 }
377
378 if num_columns > 0 {
379 stream.maybe_recv_eof().await?;
380 }
381
382 Ok(())
383}
384
385fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result<MySqlColumn, Error> {
386 let name = match (def.name()?, def.alias()?) {
389 (_, alias) if !alias.is_empty() => UStr::new(alias),
390 (name, _) => UStr::new(name),
391 };
392
393 let type_info = MySqlTypeInfo::from_column(def);
394
395 Ok(MySqlColumn {
396 name,
397 type_info,
398 ordinal,
399 flags: Some(def.flags),
400 })
401}
402
403async fn recv_result_metadata(
404 stream: &mut MySqlStream,
405 num_columns: usize,
406 columns: &mut Vec<MySqlColumn>,
407) -> Result<HashMap<UStr, usize>, Error> {
408 let mut column_names = HashMap::with_capacity(num_columns);
412
413 columns.clear();
414 columns.reserve(num_columns);
415
416 for ordinal in 0..num_columns {
417 let def: ColumnDefinition = stream.recv().await?;
418
419 let column = recv_next_result_column(&def, ordinal)?;
420
421 column_names.insert(column.name.clone(), ordinal);
422 columns.push(column);
423 }
424
425 stream.maybe_recv_eof().await?;
426
427 Ok(column_names)
428}