sqlx_postgres/connection/executor.rs

use crate::describe::Describe;
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::io::{PortalId, StatementId};
use crate::logger::QueryLogger;
use crate::message::{
    self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
    ParseComplete, Query, RowDescription,
};
use crate::statement::PgStatementMetadata;
use crate::{
    statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
    PgValueFormat, Postgres,
};
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_core::Stream;
use futures_util::{pin_mut, TryStreamExt};
use sqlx_core::arguments::Arguments;
use sqlx_core::Either;
use std::{borrow::Cow, sync::Arc};

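/// Prepare `sql` on the server under a fresh statement ID.
///
/// Any custom parameter types are resolved to OIDs first, the statement is
/// parsed (and described, unless `metadata` was already supplied by the
/// caller), and the new statement ID plus its metadata are returned.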
async fn prepare(
    conn: &mut PgConnection,
    sql: &str,
    parameters: &[PgTypeInfo],
    metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
    let id = conn.inner.next_statement_id;
    conn.inner.next_statement_id = id.next();

    // build a list of type OIDs to send to the database in the PARSE command
    // we have not yet started the query sequence, so we are *safe* to cleanly make
    // additional queries here to get any missing OIDs

    let mut param_types = Vec::with_capacity(parameters.len());

    for ty in parameters {
        param_types.push(conn.resolve_type_id(&ty.0).await?);
    }

    // flush and wait until we are re-ready
    conn.wait_until_ready().await?;

    // next we send the PARSE command to the server
    conn.inner.stream.write_msg(Parse {
        param_types: &param_types,
        query: sql,
        statement: id,
    })?;

    if metadata.is_none() {
        // get the statement columns and parameters
        conn.inner
            .stream
            .write_msg(message::Describe::Statement(id))?;
    }

    // we ask for the server to immediately send us the result of the PARSE command
    conn.write_sync();
    conn.inner.stream.flush().await?;

    // indicates that the SQL query string is now successfully parsed and has semantic validity
    conn.inner.stream.recv_expect::<ParseComplete>().await?;

    let metadata = if let Some(metadata) = metadata {
        // each SYNC produces one READY FOR QUERY
        conn.recv_ready_for_query().await?;

        // we already have metadata
        metadata
    } else {
        let parameters = recv_desc_params(conn).await?;

        let rows = recv_desc_rows(conn).await?;

        // each SYNC produces one READY FOR QUERY
        conn.recv_ready_for_query().await?;

        let parameters = conn.handle_parameter_description(parameters).await?;

        let (columns, column_names) = conn.handle_row_description(rows, true).await?;

        // ensure that if we did fetch custom data, we wait until we are fully ready before
        // continuing
        conn.wait_until_ready().await?;

        Arc::new(PgStatementMetadata {
            parameters,
            columns,
            column_names: Arc::new(column_names),
        })
    };

    Ok((id, metadata))
}

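/// Receive the ParameterDescription reply to a statement-level Describe.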
async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
    conn.inner.stream.recv_expect().await
}

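/// Receive either a RowDescription (the statement returns rows) or NoData
/// (it does not) in reply to a statement-level Describe.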
async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
    let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
        // describes the rows that will be returned when the statement is eventually executed
        message if message.format == BackendMessageFormat::RowDescription => {
            Some(message.decode()?)
        }

        // no data would be returned if this statement was executed
        message if message.format == BackendMessageFormat::NoData => None,

        message => {
            return Err(err_protocol!(
                "expecting RowDescription or NoData but received {:?}",
                message.format
            ));
        }
    };

    Ok(rows)
}

impl PgConnection {
    // wait for CloseComplete to indicate a statement was closed
    pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
        // we need to wait for the [CloseComplete] to be returned from the server
        while count > 0 {
            match self.inner.stream.recv().await? {
                message if message.format == BackendMessageFormat::PortalSuspended => {
                    // there was an open portal
                    // this can happen if the last time a statement was used it was not fully executed
                }

                message if message.format == BackendMessageFormat::CloseComplete => {
                    // successfully closed the statement (and freed up the server resources)
                    count -= 1;
                }

                message => {
                    return Err(err_protocol!(
                        "expecting PortalSuspended or CloseComplete but received {:?}",
                        message.format
                    ));
                }
            }
        }

        Ok(())
    }

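    /// Buffer a [Sync] message and record that one more [ReadyForQuery] is now
    /// expected from the server.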
    #[inline(always)]
    pub(crate) fn write_sync(&mut self) {
        self.inner
            .stream
            .write_msg(message::Sync)
            .expect("BUG: Sync should not be too big for protocol");

        // all SYNC messages will return a ReadyForQuery
        self.inner.pending_ready_for_query_count += 1;
    }

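    /// Fetch the prepared statement for `sql` from the per-connection cache,
    /// or prepare it now; when inserting into the cache evicts an older
    /// statement, close that statement on the server before returning.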
    async fn get_or_prepare<'a>(
        &mut self,
        sql: &str,
        parameters: &[PgTypeInfo],
        // whether to store the result of this prepare in the cache
        store_to_cache: bool,
        // optional metadata that was provided by the user; this means they are reusing
        // a statement object
        metadata: Option<Arc<PgStatementMetadata>>,
    ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
        if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
            return Ok((*statement).clone());
        }

        let statement = prepare(self, sql, parameters, metadata).await?;

        if store_to_cache && self.inner.cache_statement.is_enabled() {
            if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
                self.inner.stream.write_msg(Close::Statement(id))?;
                self.write_sync();

                self.inner.stream.flush().await?;

                self.wait_for_close_complete(1).await?;
                self.recv_ready_for_query().await?;
            }
        }

        Ok(statement)
    }

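    /// Execute `query`, streaming back zero or more [PgQueryResult]s and [PgRow]s.
    ///
    /// With arguments this uses the extended query protocol (Parse/Bind/Execute);
    /// without arguments it falls back to a simple [Query] message.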
    pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
        &'c mut self,
        query: &'q str,
        arguments: Option<PgArguments>,
        limit: u8,
        persistent: bool,
        metadata_opt: Option<Arc<PgStatementMetadata>>,
    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
        let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());

        // before we continue, wait until we are "ready" to accept more queries
        self.wait_until_ready().await?;

        let mut metadata: Arc<PgStatementMetadata>;

        let format = if let Some(mut arguments) = arguments {
            // Check this before we write anything to the stream.
            //
            // Note: Postgres actually interprets this value as unsigned,
            // making the max number of parameters 65535, not 32767
            // https://github.com/launchbadge/sqlx/issues/3464
            // https://www.postgresql.org/docs/current/limits.html
            let num_params = u16::try_from(arguments.len()).map_err(|_| {
                err_protocol!(
                    "PgConnection::run(): too many arguments for query: {}",
                    arguments.len()
                )
            })?;

            // prepare the statement if this is our first time executing it
            // always return the statement ID here
            let (statement, metadata_) = self
                .get_or_prepare(query, &arguments.types, persistent, metadata_opt)
                .await?;

            metadata = metadata_;

            // patch holes created during encoding
            arguments.apply_patches(self, &metadata.parameters).await?;

            // consume messages until `ReadyForQuery` before bind and execute
            self.wait_until_ready().await?;

            // bind to attach the arguments to the statement and create a portal
            self.inner.stream.write_msg(Bind {
                portal: PortalId::UNNAMED,
                statement,
                formats: &[PgValueFormat::Binary],
                num_params,
                params: &arguments.buffer,
                result_formats: &[PgValueFormat::Binary],
            })?;

            // executes the portal up to the passed limit
            // the protocol-level limit acts nearly identically to the `LIMIT` in SQL
            self.inner.stream.write_msg(message::Execute {
                portal: PortalId::UNNAMED,
                limit: limit.into(),
            })?;
            // From https://www.postgresql.org/docs/current/protocol-flow.html:
            //
            // "An unnamed portal is destroyed at the end of the transaction, or as
            // soon as the next Bind statement specifying the unnamed portal as
            // destination is issued. (Note that a simple Query message also
            // destroys the unnamed portal.)"

            // we ask the database server to close the unnamed portal and free the associated
            // resources earlier, i.e. right after the execution of the current query.
            self.inner
                .stream
                .write_msg(Close::Portal(PortalId::UNNAMED))?;

            // finally, [Sync] asks postgres to process the messages that we sent and respond with
            // a [ReadyForQuery] message when it's completely done. Theoretically, we could send
            // dozens of queries before a [Sync] and postgres can handle that. Execution on the server
            // is still serial but it would reduce round-trips. Some kind of builder-pattern batching
            // API might suit this.
            self.write_sync();

            // prepared statements are binary
            PgValueFormat::Binary
        } else {
            // Query will trigger a ReadyForQuery
            self.inner.stream.write_msg(Query(query))?;
            self.inner.pending_ready_for_query_count += 1;

            // metadata starts out as "nothing"
            metadata = Arc::new(PgStatementMetadata::default());

            // and unprepared statements are text
            PgValueFormat::Text
        };

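        // everything written above is still buffered client-side; flush it to the
        // server in one round trip, then stream the responses back below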
        self.inner.stream.flush().await?;

        Ok(try_stream! {
            loop {
                let message = self.inner.stream.recv().await?;

                match message.format {
                    BackendMessageFormat::BindComplete
                    | BackendMessageFormat::ParseComplete
                    | BackendMessageFormat::ParameterDescription
                    | BackendMessageFormat::NoData
                    // unnamed portal has been closed
                    | BackendMessageFormat::CloseComplete
                    => {
                        // harmless messages to ignore
                    }

                    // "Execute phase is always terminated by the appearance of
                    // exactly one of these messages: CommandComplete,
                    // EmptyQueryResponse (if the portal was created from an
                    // empty query string), ErrorResponse, or PortalSuspended"
                    BackendMessageFormat::CommandComplete => {
                        // a SQL command completed normally
                        let cc: CommandComplete = message.decode()?;

                        let rows_affected = cc.rows_affected();
                        logger.increase_rows_affected(rows_affected);
                        r#yield!(Either::Left(PgQueryResult {
                            rows_affected,
                        }));
                    }

                    BackendMessageFormat::EmptyQueryResponse => {
                        // empty query string passed to an unprepared execute
                    }

                    // Message::ErrorResponse is handled in self.stream.recv()

                    // execution of the portal stopped early (the row limit was reached)
                    BackendMessageFormat::PortalSuspended => {}

                    BackendMessageFormat::RowDescription => {
                        // indicates that a *new* set of rows is about to be returned
                        let (columns, column_names) = self
                            .handle_row_description(Some(message.decode()?), false)
                            .await?;

                        metadata = Arc::new(PgStatementMetadata {
                            column_names: Arc::new(column_names),
                            columns,
                            parameters: Vec::default(),
                        });
                    }

                    BackendMessageFormat::DataRow => {
                        logger.increment_rows_returned();

                        // one of the set of rows returned by a SELECT, FETCH, etc. query
                        let data: DataRow = message.decode()?;
                        let row = PgRow {
                            data,
                            format,
                            metadata: Arc::clone(&metadata),
                        };

                        r#yield!(Either::Right(row));
                    }

                    BackendMessageFormat::ReadyForQuery => {
                        // processing of the query string is complete
                        self.handle_ready_for_query(message)?;
                        break;
                    }

                    _ => {
                        return Err(err_protocol!(
                            "execute: unexpected message: {:?}",
                            message.format
                        ));
                    }
                }
            }

            Ok(())
        })
    }
}

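// The `Executor` impl is a thin adapter over `run()`: `fetch_many` streams every
// result, `fetch_optional` drains the stream and keeps only the first row, and
// `prepare_with`/`describe` reuse `get_or_prepare` for statement metadata.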
impl<'c> Executor<'c> for &'c mut PgConnection {
    type Database = Postgres;

    fn fetch_many<'e, 'q, E>(
        self,
        mut query: E,
    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
    where
        'c: 'e,
        E: Execute<'q, Self::Database>,
        'q: 'e,
        E: 'q,
    {
        let sql = query.sql();
        // False positive: https://github.com/rust-lang/rust-clippy/issues/12560
        #[allow(clippy::map_clone)]
        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
        let arguments = query.take_arguments().map_err(Error::Encode);
        let persistent = query.persistent();

        Box::pin(try_stream! {
            let arguments = arguments?;
            let s = self.run(sql, arguments, 0, persistent, metadata).await?;
            pin_mut!(s);

            while let Some(v) = s.try_next().await? {
                r#yield!(v);
            }

            Ok(())
        })
    }

    fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
    where
        'c: 'e,
        E: Execute<'q, Self::Database>,
        'q: 'e,
        E: 'q,
    {
        let sql = query.sql();
        // False positive: https://github.com/rust-lang/rust-clippy/issues/12560
        #[allow(clippy::map_clone)]
        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
        let arguments = query.take_arguments().map_err(Error::Encode);
        let persistent = query.persistent();

        Box::pin(async move {
            let arguments = arguments?;
            let s = self.run(sql, arguments, 1, persistent, metadata).await?;
            pin_mut!(s);

            // With deferred constraints we need to check all responses as we
            // could get an OK response (with uncommitted data), only to get an
            // error response after (when the deferred constraint is actually
            // checked).
            let mut ret = None;
            while let Some(result) = s.try_next().await? {
                match result {
                    Either::Right(r) if ret.is_none() => ret = Some(r),
                    _ => {}
                }
            }
            Ok(ret)
        })
    }

    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [PgTypeInfo],
    ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
    where
        'c: 'e,
    {
        Box::pin(async move {
            self.wait_until_ready().await?;

            let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?;

            Ok(PgStatement {
                sql: Cow::Borrowed(sql),
                metadata,
            })
        })
    }

    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
    where
        'c: 'e,
    {
        Box::pin(async move {
            self.wait_until_ready().await?;

            let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?;

            let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;

            Ok(Describe {
                columns: metadata.columns.clone(),
                nullable,
                parameters: Some(Either::Left(metadata.parameters.clone())),
            })
        })
    }
}