sqlx_postgres/connection/
executor.rs1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::io::{PortalId, StatementId};
5use crate::logger::QueryLogger;
6use crate::message::{
7 self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
8 ParseComplete, Query, RowDescription,
9};
10use crate::statement::PgStatementMetadata;
11use crate::{
12 statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
13 PgValueFormat, Postgres,
14};
15use futures_core::future::BoxFuture;
16use futures_core::stream::BoxStream;
17use futures_core::Stream;
18use futures_util::{pin_mut, TryStreamExt};
19use sqlx_core::arguments::Arguments;
20use sqlx_core::Either;
21use std::{borrow::Cow, sync::Arc};
22
23async fn prepare(
24 conn: &mut PgConnection,
25 sql: &str,
26 parameters: &[PgTypeInfo],
27 metadata: Option<Arc<PgStatementMetadata>>,
28) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
29 let id = conn.inner.next_statement_id;
30 conn.inner.next_statement_id = id.next();
31
32 let mut param_types = Vec::with_capacity(parameters.len());
37
38 for ty in parameters {
39 param_types.push(conn.resolve_type_id(&ty.0).await?);
40 }
41
42 conn.wait_until_ready().await?;
44
45 conn.inner.stream.write_msg(Parse {
47 param_types: ¶m_types,
48 query: sql,
49 statement: id,
50 })?;
51
52 if metadata.is_none() {
53 conn.inner
55 .stream
56 .write_msg(message::Describe::Statement(id))?;
57 }
58
59 conn.write_sync();
61 conn.inner.stream.flush().await?;
62
63 conn.inner.stream.recv_expect::<ParseComplete>().await?;
65
66 let metadata = if let Some(metadata) = metadata {
67 conn.recv_ready_for_query().await?;
69
70 metadata
72 } else {
73 let parameters = recv_desc_params(conn).await?;
74
75 let rows = recv_desc_rows(conn).await?;
76
77 conn.recv_ready_for_query().await?;
79
80 let parameters = conn.handle_parameter_description(parameters).await?;
81
82 let (columns, column_names) = conn.handle_row_description(rows, true).await?;
83
84 conn.wait_until_ready().await?;
87
88 Arc::new(PgStatementMetadata {
89 parameters,
90 columns,
91 column_names: Arc::new(column_names),
92 })
93 };
94
95 Ok((id, metadata))
96}
97
98async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
99 conn.inner.stream.recv_expect().await
100}
101
102async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
103 let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
104 message if message.format == BackendMessageFormat::RowDescription => {
106 Some(message.decode()?)
107 }
108
109 message if message.format == BackendMessageFormat::NoData => None,
111
112 message => {
113 return Err(err_protocol!(
114 "expecting RowDescription or NoData but received {:?}",
115 message.format
116 ));
117 }
118 };
119
120 Ok(rows)
121}
122
123impl PgConnection {
124 pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
126 while count > 0 {
128 match self.inner.stream.recv().await? {
129 message if message.format == BackendMessageFormat::PortalSuspended => {
130 }
133
134 message if message.format == BackendMessageFormat::CloseComplete => {
135 count -= 1;
137 }
138
139 message => {
140 return Err(err_protocol!(
141 "expecting PortalSuspended or CloseComplete but received {:?}",
142 message.format
143 ));
144 }
145 }
146 }
147
148 Ok(())
149 }
150
151 #[inline(always)]
152 pub(crate) fn write_sync(&mut self) {
153 self.inner
154 .stream
155 .write_msg(message::Sync)
156 .expect("BUG: Sync should not be too big for protocol");
157
158 self.inner.pending_ready_for_query_count += 1;
160 }
161
162 async fn get_or_prepare<'a>(
163 &mut self,
164 sql: &str,
165 parameters: &[PgTypeInfo],
166 store_to_cache: bool,
168 metadata: Option<Arc<PgStatementMetadata>>,
171 ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
172 if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
173 return Ok((*statement).clone());
174 }
175
176 let statement = prepare(self, sql, parameters, metadata).await?;
177
178 if store_to_cache && self.inner.cache_statement.is_enabled() {
179 if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
180 self.inner.stream.write_msg(Close::Statement(id))?;
181 self.write_sync();
182
183 self.inner.stream.flush().await?;
184
185 self.wait_for_close_complete(1).await?;
186 self.recv_ready_for_query().await?;
187 }
188 }
189
190 Ok(statement)
191 }
192
193 pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
194 &'c mut self,
195 query: &'q str,
196 arguments: Option<PgArguments>,
197 limit: u8,
198 persistent: bool,
199 metadata_opt: Option<Arc<PgStatementMetadata>>,
200 ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
201 let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
202
203 self.wait_until_ready().await?;
205
206 let mut metadata: Arc<PgStatementMetadata>;
207
208 let format = if let Some(mut arguments) = arguments {
209 let num_params = u16::try_from(arguments.len()).map_err(|_| {
216 err_protocol!(
217 "PgConnection::run(): too many arguments for query: {}",
218 arguments.len()
219 )
220 })?;
221
222 let (statement, metadata_) = self
225 .get_or_prepare(query, &arguments.types, persistent, metadata_opt)
226 .await?;
227
228 metadata = metadata_;
229
230 arguments.apply_patches(self, &metadata.parameters).await?;
232
233 self.wait_until_ready().await?;
235
236 self.inner.stream.write_msg(Bind {
238 portal: PortalId::UNNAMED,
239 statement,
240 formats: &[PgValueFormat::Binary],
241 num_params,
242 params: &arguments.buffer,
243 result_formats: &[PgValueFormat::Binary],
244 })?;
245
246 self.inner.stream.write_msg(message::Execute {
249 portal: PortalId::UNNAMED,
250 limit: limit.into(),
251 })?;
252 self.inner
262 .stream
263 .write_msg(Close::Portal(PortalId::UNNAMED))?;
264
265 self.write_sync();
271
272 PgValueFormat::Binary
274 } else {
275 self.inner.stream.write_msg(Query(query))?;
277 self.inner.pending_ready_for_query_count += 1;
278
279 metadata = Arc::new(PgStatementMetadata::default());
281
282 PgValueFormat::Text
284 };
285
286 self.inner.stream.flush().await?;
287
288 Ok(try_stream! {
289 loop {
290 let message = self.inner.stream.recv().await?;
291
292 match message.format {
293 BackendMessageFormat::BindComplete
294 | BackendMessageFormat::ParseComplete
295 | BackendMessageFormat::ParameterDescription
296 | BackendMessageFormat::NoData
297 | BackendMessageFormat::CloseComplete
299 => {
300 }
302
303 BackendMessageFormat::CommandComplete => {
308 let cc: CommandComplete = message.decode()?;
310
311 let rows_affected = cc.rows_affected();
312 logger.increase_rows_affected(rows_affected);
313 r#yield!(Either::Left(PgQueryResult {
314 rows_affected,
315 }));
316 }
317
318 BackendMessageFormat::EmptyQueryResponse => {
319 }
321
322 BackendMessageFormat::PortalSuspended => {}
326
327 BackendMessageFormat::RowDescription => {
328 let (columns, column_names) = self
330 .handle_row_description(Some(message.decode()?), false)
331 .await?;
332
333 metadata = Arc::new(PgStatementMetadata {
334 column_names: Arc::new(column_names),
335 columns,
336 parameters: Vec::default(),
337 });
338 }
339
340 BackendMessageFormat::DataRow => {
341 logger.increment_rows_returned();
342
343 let data: DataRow = message.decode()?;
345 let row = PgRow {
346 data,
347 format,
348 metadata: Arc::clone(&metadata),
349 };
350
351 r#yield!(Either::Right(row));
352 }
353
354 BackendMessageFormat::ReadyForQuery => {
355 self.handle_ready_for_query(message)?;
357 break;
358 }
359
360 _ => {
361 return Err(err_protocol!(
362 "execute: unexpected message: {:?}",
363 message.format
364 ));
365 }
366 }
367 }
368
369 Ok(())
370 })
371 }
372}
373
374impl<'c> Executor<'c> for &'c mut PgConnection {
375 type Database = Postgres;
376
377 fn fetch_many<'e, 'q, E>(
378 self,
379 mut query: E,
380 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
381 where
382 'c: 'e,
383 E: Execute<'q, Self::Database>,
384 'q: 'e,
385 E: 'q,
386 {
387 let sql = query.sql();
388 #[allow(clippy::map_clone)]
390 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
391 let arguments = query.take_arguments().map_err(Error::Encode);
392 let persistent = query.persistent();
393
394 Box::pin(try_stream! {
395 let arguments = arguments?;
396 let s = self.run(sql, arguments, 0, persistent, metadata).await?;
397 pin_mut!(s);
398
399 while let Some(v) = s.try_next().await? {
400 r#yield!(v);
401 }
402
403 Ok(())
404 })
405 }
406
407 fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
408 where
409 'c: 'e,
410 E: Execute<'q, Self::Database>,
411 'q: 'e,
412 E: 'q,
413 {
414 let sql = query.sql();
415 #[allow(clippy::map_clone)]
417 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
418 let arguments = query.take_arguments().map_err(Error::Encode);
419 let persistent = query.persistent();
420
421 Box::pin(async move {
422 let arguments = arguments?;
423 let s = self.run(sql, arguments, 1, persistent, metadata).await?;
424 pin_mut!(s);
425
426 let mut ret = None;
431 while let Some(result) = s.try_next().await? {
432 match result {
433 Either::Right(r) if ret.is_none() => ret = Some(r),
434 _ => {}
435 }
436 }
437 Ok(ret)
438 })
439 }
440
441 fn prepare_with<'e, 'q: 'e>(
442 self,
443 sql: &'q str,
444 parameters: &'e [PgTypeInfo],
445 ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
446 where
447 'c: 'e,
448 {
449 Box::pin(async move {
450 self.wait_until_ready().await?;
451
452 let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?;
453
454 Ok(PgStatement {
455 sql: Cow::Borrowed(sql),
456 metadata,
457 })
458 })
459 }
460
461 fn describe<'e, 'q: 'e>(
462 self,
463 sql: &'q str,
464 ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
465 where
466 'c: 'e,
467 {
468 Box::pin(async move {
469 self.wait_until_ready().await?;
470
471 let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?;
472
473 let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
474
475 Ok(Describe {
476 columns: metadata.columns.clone(),
477 nullable,
478 parameters: Some(Either::Left(metadata.parameters.clone())),
479 })
480 })
481 }
482}