sqlx_postgres/
copy.rs

1use std::borrow::Cow;
2use std::ops::{Deref, DerefMut};
3
4use futures_core::future::BoxFuture;
5use futures_core::stream::BoxStream;
6
7use sqlx_core::bytes::{BufMut, Bytes};
8
9use crate::connection::PgConnection;
10use crate::error::{Error, Result};
11use crate::ext::async_stream::TryAsyncStream;
12use crate::io::AsyncRead;
13use crate::message::{
14    BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse,
15    CopyOutResponse, CopyResponseData, Query, ReadyForQuery,
16};
17use crate::pool::{Pool, PoolConnection};
18use crate::Postgres;
19
20impl PgConnection {
21    /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
22    /// to Postgres. This is a more efficient way to import data into Postgres as compared to
23    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
24    ///
25    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
26    /// returned.
27    ///
28    /// Command examples and accepted formats for `COPY` data are shown here:
29    /// <https://www.postgresql.org/docs/current/sql-copy.html>
30    ///
31    /// ### Note
32    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
33    /// will return an error the next time it is used.
34    pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
35        PgCopyIn::begin(self, statement).await
36    }
37
38    /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
39    /// from Postgres. This is a more efficient way to export data from Postgres but
40    /// arrives in chunks of one of a few data formats (text/CSV/binary).
41    ///
42    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
43    /// an error is returned.
44    ///
45    /// Note that once this process has begun, unless you read the stream to completion,
46    /// it can only be canceled in two ways:
47    ///
48    /// 1. by closing the connection, or:
49    /// 2. by using another connection to kill the server process that is sending the data as shown
50    ///    [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
51    ///
52    /// If you don't read the stream to completion, the next time the connection is used it will
53    /// need to read and discard all the remaining queued data, which could take some time.
54    ///
55    /// Command examples and accepted formats for `COPY` data are shown here:
56    /// <https://www.postgresql.org/docs/current/sql-copy.html>
57    #[allow(clippy::needless_lifetimes)]
58    pub async fn copy_out_raw<'c>(
59        &'c mut self,
60        statement: &str,
61    ) -> Result<BoxStream<'c, Result<Bytes>>> {
62        pg_begin_copy_out(self, statement).await
63    }
64}
65
66/// Implements methods for directly executing `COPY FROM/TO STDOUT` on a [`PgPool`][crate::PgPool].
67///
68/// This is a replacement for the inherent methods on `PgPool` which could not exist
69/// once the Postgres driver was moved out into its own crate.
70pub trait PgPoolCopyExt {
71    /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
72    /// This is a more efficient way to import data into Postgres as compared to
73    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
74    ///
75    /// A single connection will be checked out for the duration.
76    ///
77    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
78    /// returned.
79    ///
80    /// Command examples and accepted formats for `COPY` data are shown here:
81    /// <https://www.postgresql.org/docs/current/sql-copy.html>
82    ///
83    /// ### Note
84    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
85    /// will return an error the next time it is used.
86    fn copy_in_raw<'a>(
87        &'a self,
88        statement: &'a str,
89    ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>>;
90
91    /// Issue a `COPY TO STDOUT` statement and begin streaming data
92    /// from Postgres. This is a more efficient way to export data from Postgres but
93    /// arrives in chunks of one of a few data formats (text/CSV/binary).
94    ///
95    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
96    /// an error is returned.
97    ///
98    /// Note that once this process has begun, unless you read the stream to completion,
99    /// it can only be canceled in two ways:
100    ///
101    /// 1. by closing the connection, or:
102    /// 2. by using another connection to kill the server process that is sending the data as shown
103    ///    [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
104    ///
105    /// If you don't read the stream to completion, the next time the connection is used it will
106    /// need to read and discard all the remaining queued data, which could take some time.
107    ///
108    /// Command examples and accepted formats for `COPY` data are shown here:
109    /// <https://www.postgresql.org/docs/current/sql-copy.html>
110    fn copy_out_raw<'a>(
111        &'a self,
112        statement: &'a str,
113    ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>>;
114}
115
116impl PgPoolCopyExt for Pool<Postgres> {
117    fn copy_in_raw<'a>(
118        &'a self,
119        statement: &'a str,
120    ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>> {
121        Box::pin(async { PgCopyIn::begin(self.acquire().await?, statement).await })
122    }
123
124    fn copy_out_raw<'a>(
125        &'a self,
126        statement: &'a str,
127    ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>> {
128        Box::pin(async { pg_begin_copy_out(self.acquire().await?, statement).await })
129    }
130}
131
132/// A connection in streaming `COPY FROM STDIN` mode.
133///
134/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
135///
136/// ### Note
137/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
138/// will return an error the next time it is used.
139#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
140pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
141    conn: Option<C>,
142    response: CopyResponseData,
143}
144
145impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
146    async fn begin(mut conn: C, statement: &str) -> Result<Self> {
147        conn.wait_until_ready().await?;
148        conn.inner.stream.send(Query(statement)).await?;
149
150        let response = match conn.inner.stream.recv_expect::<CopyInResponse>().await {
151            Ok(res) => res.0,
152            Err(e) => {
153                conn.inner.stream.recv().await?;
154                return Err(e);
155            }
156        };
157
158        Ok(PgCopyIn {
159            conn: Some(conn),
160            response,
161        })
162    }
163
164    /// Returns `true` if Postgres is expecting data in text or CSV format.
165    pub fn is_textual(&self) -> bool {
166        self.response.format == 0
167    }
168
169    /// Returns the number of columns expected in the input.
170    pub fn num_columns(&self) -> usize {
171        assert_eq!(
172            self.response.num_columns.unsigned_abs() as usize,
173            self.response.format_codes.len(),
174            "num_columns does not match format_codes.len()"
175        );
176        self.response.format_codes.len()
177    }
178
179    /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
180    ///
181    /// ### Panics
182    /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
183    pub fn column_is_textual(&self, column: usize) -> bool {
184        self.response.format_codes[column] == 0
185    }
186
187    /// Send a chunk of `COPY` data.
188    ///
189    /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
190    pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
191        self.conn
192            .as_deref_mut()
193            .expect("send_data: conn taken")
194            .inner
195            .stream
196            .send(CopyData(data))
197            .await?;
198
199        Ok(self)
200    }
201
202    /// Copy data directly from `source` to the database without requiring an intermediate buffer.
203    ///
204    /// `source` will be read to the end.
205    ///
206    /// ### Note: Completion Step Required
207    /// You must still call either [Self::finish] or [Self::abort] to complete the process.
208    ///
209    /// ### Note: Runtime Features
210    /// This method uses the `AsyncRead` trait which is re-exported from either Tokio or `async-std`
211    /// depending on which runtime feature is used.
212    ///
213    /// The runtime features _used_ to be mutually exclusive, but are no longer.
214    /// If both `runtime-async-std` and `runtime-tokio` features are enabled, the Tokio version
215    /// takes precedent.
216    pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
217        let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
218        loop {
219            let buf = conn.inner.stream.write_buffer_mut();
220
221            // Write the CopyData format code and reserve space for the length.
222            // This may end up sending an empty `CopyData` packet if, after this point,
223            // we get canceled or read 0 bytes, but that should be fine.
224            buf.put_slice(b"d\0\0\0\x04");
225
226            let read = buf.read_from(&mut source).await?;
227
228            if read == 0 {
229                break;
230            }
231
232            // Write the length
233            let read32 = u32::try_from(read)
234                .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
235
236            (&mut buf.get_mut()[1..]).put_u32(read32 + 4);
237
238            conn.inner.stream.flush().await?;
239        }
240
241        Ok(self)
242    }
243
244    /// Signal that the `COPY` process should be aborted and any data received should be discarded.
245    ///
246    /// The given message can be used for indicating the reason for the abort in the database logs.
247    ///
248    /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
249    pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
250        let mut conn = self
251            .conn
252            .take()
253            .expect("PgCopyIn::fail_with: conn taken illegally");
254
255        conn.inner.stream.send(CopyFail::new(msg)).await?;
256
257        match conn.inner.stream.recv().await {
258            Ok(msg) => Err(err_protocol!(
259                "fail_with: expected ErrorResponse, got: {:?}",
260                msg.format
261            )),
262            Err(Error::Database(e)) => {
263                match e.code() {
264                    Some(Cow::Borrowed("57014")) => {
265                        // postgres abort received error code
266                        conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
267                        Ok(())
268                    }
269                    _ => Err(Error::Database(e)),
270                }
271            }
272            Err(e) => Err(e),
273        }
274    }
275
276    /// Signal that the `COPY` process is complete.
277    ///
278    /// The number of rows affected is returned.
279    pub async fn finish(mut self) -> Result<u64> {
280        let mut conn = self
281            .conn
282            .take()
283            .expect("CopyWriter::finish: conn taken illegally");
284
285        conn.inner.stream.send(CopyDone).await?;
286        let cc: CommandComplete = match conn.inner.stream.recv_expect().await {
287            Ok(cc) => cc,
288            Err(e) => {
289                conn.inner.stream.recv().await?;
290                return Err(e);
291            }
292        };
293
294        conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
295
296        Ok(cc.rows_affected())
297    }
298}
299
300impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
301    fn drop(&mut self) {
302        if let Some(mut conn) = self.conn.take() {
303            conn.inner
304                .stream
305                .write_msg(CopyFail::new(
306                    "PgCopyIn dropped without calling finish() or fail()",
307                ))
308                .expect("BUG: PgCopyIn abort message should not be too large");
309        }
310    }
311}
312
313async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
314    mut conn: C,
315    statement: &str,
316) -> Result<BoxStream<'c, Result<Bytes>>> {
317    conn.wait_until_ready().await?;
318    conn.inner.stream.send(Query(statement)).await?;
319
320    let _: CopyOutResponse = conn.inner.stream.recv_expect().await?;
321
322    let stream: TryAsyncStream<'c, Bytes> = try_stream! {
323        loop {
324            match conn.inner.stream.recv().await {
325                Err(e) => {
326                    conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
327                    return Err(e);
328                },
329                Ok(msg) => match msg.format {
330                    BackendMessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
331                    BackendMessageFormat::CopyDone => {
332                        let _ = msg.decode::<CopyDone>()?;
333                        conn.inner.stream.recv_expect::<CommandComplete>().await?;
334                        conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
335                        return Ok(())
336                    },
337                    _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
338                }
339            }
340        }
341    };
342
343    Ok(Box::pin(stream))
344}