sqlx_postgres/copy.rs
1use std::borrow::Cow;
2use std::ops::{Deref, DerefMut};
3
4use futures_core::future::BoxFuture;
5use futures_core::stream::BoxStream;
6
7use sqlx_core::bytes::{BufMut, Bytes};
8
9use crate::connection::PgConnection;
10use crate::error::{Error, Result};
11use crate::ext::async_stream::TryAsyncStream;
12use crate::io::AsyncRead;
13use crate::message::{
14 BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse,
15 CopyOutResponse, CopyResponseData, Query, ReadyForQuery,
16};
17use crate::pool::{Pool, PoolConnection};
18use crate::Postgres;
19
20impl PgConnection {
21 /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
22 /// to Postgres. This is a more efficient way to import data into Postgres as compared to
23 /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
24 ///
25 /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
26 /// returned.
27 ///
28 /// Command examples and accepted formats for `COPY` data are shown here:
29 /// <https://www.postgresql.org/docs/current/sql-copy.html>
30 ///
31 /// ### Note
32 /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
33 /// will return an error the next time it is used.
34 pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
35 PgCopyIn::begin(self, statement).await
36 }
37
38 /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
39 /// from Postgres. This is a more efficient way to export data from Postgres but
40 /// arrives in chunks of one of a few data formats (text/CSV/binary).
41 ///
42 /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
43 /// an error is returned.
44 ///
45 /// Note that once this process has begun, unless you read the stream to completion,
46 /// it can only be canceled in two ways:
47 ///
48 /// 1. by closing the connection, or:
49 /// 2. by using another connection to kill the server process that is sending the data as shown
50 /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
51 ///
52 /// If you don't read the stream to completion, the next time the connection is used it will
53 /// need to read and discard all the remaining queued data, which could take some time.
54 ///
55 /// Command examples and accepted formats for `COPY` data are shown here:
56 /// <https://www.postgresql.org/docs/current/sql-copy.html>
57 #[allow(clippy::needless_lifetimes)]
58 pub async fn copy_out_raw<'c>(
59 &'c mut self,
60 statement: &str,
61 ) -> Result<BoxStream<'c, Result<Bytes>>> {
62 pg_begin_copy_out(self, statement).await
63 }
64}
65
66/// Implements methods for directly executing `COPY FROM/TO STDOUT` on a [`PgPool`][crate::PgPool].
67///
68/// This is a replacement for the inherent methods on `PgPool` which could not exist
69/// once the Postgres driver was moved out into its own crate.
70pub trait PgPoolCopyExt {
71 /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
72 /// This is a more efficient way to import data into Postgres as compared to
73 /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
74 ///
75 /// A single connection will be checked out for the duration.
76 ///
77 /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
78 /// returned.
79 ///
80 /// Command examples and accepted formats for `COPY` data are shown here:
81 /// <https://www.postgresql.org/docs/current/sql-copy.html>
82 ///
83 /// ### Note
84 /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
85 /// will return an error the next time it is used.
86 fn copy_in_raw<'a>(
87 &'a self,
88 statement: &'a str,
89 ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>>;
90
91 /// Issue a `COPY TO STDOUT` statement and begin streaming data
92 /// from Postgres. This is a more efficient way to export data from Postgres but
93 /// arrives in chunks of one of a few data formats (text/CSV/binary).
94 ///
95 /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
96 /// an error is returned.
97 ///
98 /// Note that once this process has begun, unless you read the stream to completion,
99 /// it can only be canceled in two ways:
100 ///
101 /// 1. by closing the connection, or:
102 /// 2. by using another connection to kill the server process that is sending the data as shown
103 /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
104 ///
105 /// If you don't read the stream to completion, the next time the connection is used it will
106 /// need to read and discard all the remaining queued data, which could take some time.
107 ///
108 /// Command examples and accepted formats for `COPY` data are shown here:
109 /// <https://www.postgresql.org/docs/current/sql-copy.html>
110 fn copy_out_raw<'a>(
111 &'a self,
112 statement: &'a str,
113 ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>>;
114}
115
116impl PgPoolCopyExt for Pool<Postgres> {
117 fn copy_in_raw<'a>(
118 &'a self,
119 statement: &'a str,
120 ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>> {
121 Box::pin(async { PgCopyIn::begin(self.acquire().await?, statement).await })
122 }
123
124 fn copy_out_raw<'a>(
125 &'a self,
126 statement: &'a str,
127 ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>> {
128 Box::pin(async { pg_begin_copy_out(self.acquire().await?, statement).await })
129 }
130}
131
132/// A connection in streaming `COPY FROM STDIN` mode.
133///
134/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
135///
136/// ### Note
137/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
138/// will return an error the next time it is used.
139#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
140pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
141 conn: Option<C>,
142 response: CopyResponseData,
143}
144
145impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
146 async fn begin(mut conn: C, statement: &str) -> Result<Self> {
147 conn.wait_until_ready().await?;
148 conn.inner.stream.send(Query(statement)).await?;
149
150 let response = match conn.inner.stream.recv_expect::<CopyInResponse>().await {
151 Ok(res) => res.0,
152 Err(e) => {
153 conn.inner.stream.recv().await?;
154 return Err(e);
155 }
156 };
157
158 Ok(PgCopyIn {
159 conn: Some(conn),
160 response,
161 })
162 }
163
164 /// Returns `true` if Postgres is expecting data in text or CSV format.
165 pub fn is_textual(&self) -> bool {
166 self.response.format == 0
167 }
168
169 /// Returns the number of columns expected in the input.
170 pub fn num_columns(&self) -> usize {
171 assert_eq!(
172 self.response.num_columns.unsigned_abs() as usize,
173 self.response.format_codes.len(),
174 "num_columns does not match format_codes.len()"
175 );
176 self.response.format_codes.len()
177 }
178
179 /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
180 ///
181 /// ### Panics
182 /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
183 pub fn column_is_textual(&self, column: usize) -> bool {
184 self.response.format_codes[column] == 0
185 }
186
187 /// Send a chunk of `COPY` data.
188 ///
189 /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
190 pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
191 self.conn
192 .as_deref_mut()
193 .expect("send_data: conn taken")
194 .inner
195 .stream
196 .send(CopyData(data))
197 .await?;
198
199 Ok(self)
200 }
201
202 /// Copy data directly from `source` to the database without requiring an intermediate buffer.
203 ///
204 /// `source` will be read to the end.
205 ///
206 /// ### Note: Completion Step Required
207 /// You must still call either [Self::finish] or [Self::abort] to complete the process.
208 ///
209 /// ### Note: Runtime Features
210 /// This method uses the `AsyncRead` trait which is re-exported from either Tokio or `async-std`
211 /// depending on which runtime feature is used.
212 ///
213 /// The runtime features _used_ to be mutually exclusive, but are no longer.
214 /// If both `runtime-async-std` and `runtime-tokio` features are enabled, the Tokio version
215 /// takes precedent.
216 pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
217 let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
218 loop {
219 let buf = conn.inner.stream.write_buffer_mut();
220
221 // Write the CopyData format code and reserve space for the length.
222 // This may end up sending an empty `CopyData` packet if, after this point,
223 // we get canceled or read 0 bytes, but that should be fine.
224 buf.put_slice(b"d\0\0\0\x04");
225
226 let read = buf.read_from(&mut source).await?;
227
228 if read == 0 {
229 break;
230 }
231
232 // Write the length
233 let read32 = u32::try_from(read)
234 .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
235
236 (&mut buf.get_mut()[1..]).put_u32(read32 + 4);
237
238 conn.inner.stream.flush().await?;
239 }
240
241 Ok(self)
242 }
243
244 /// Signal that the `COPY` process should be aborted and any data received should be discarded.
245 ///
246 /// The given message can be used for indicating the reason for the abort in the database logs.
247 ///
248 /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
249 pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
250 let mut conn = self
251 .conn
252 .take()
253 .expect("PgCopyIn::fail_with: conn taken illegally");
254
255 conn.inner.stream.send(CopyFail::new(msg)).await?;
256
257 match conn.inner.stream.recv().await {
258 Ok(msg) => Err(err_protocol!(
259 "fail_with: expected ErrorResponse, got: {:?}",
260 msg.format
261 )),
262 Err(Error::Database(e)) => {
263 match e.code() {
264 Some(Cow::Borrowed("57014")) => {
265 // postgres abort received error code
266 conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
267 Ok(())
268 }
269 _ => Err(Error::Database(e)),
270 }
271 }
272 Err(e) => Err(e),
273 }
274 }
275
276 /// Signal that the `COPY` process is complete.
277 ///
278 /// The number of rows affected is returned.
279 pub async fn finish(mut self) -> Result<u64> {
280 let mut conn = self
281 .conn
282 .take()
283 .expect("CopyWriter::finish: conn taken illegally");
284
285 conn.inner.stream.send(CopyDone).await?;
286 let cc: CommandComplete = match conn.inner.stream.recv_expect().await {
287 Ok(cc) => cc,
288 Err(e) => {
289 conn.inner.stream.recv().await?;
290 return Err(e);
291 }
292 };
293
294 conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
295
296 Ok(cc.rows_affected())
297 }
298}
299
300impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
301 fn drop(&mut self) {
302 if let Some(mut conn) = self.conn.take() {
303 conn.inner
304 .stream
305 .write_msg(CopyFail::new(
306 "PgCopyIn dropped without calling finish() or fail()",
307 ))
308 .expect("BUG: PgCopyIn abort message should not be too large");
309 }
310 }
311}
312
313async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
314 mut conn: C,
315 statement: &str,
316) -> Result<BoxStream<'c, Result<Bytes>>> {
317 conn.wait_until_ready().await?;
318 conn.inner.stream.send(Query(statement)).await?;
319
320 let _: CopyOutResponse = conn.inner.stream.recv_expect().await?;
321
322 let stream: TryAsyncStream<'c, Bytes> = try_stream! {
323 loop {
324 match conn.inner.stream.recv().await {
325 Err(e) => {
326 conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
327 return Err(e);
328 },
329 Ok(msg) => match msg.format {
330 BackendMessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
331 BackendMessageFormat::CopyDone => {
332 let _ = msg.decode::<CopyDone>()?;
333 conn.inner.stream.recv_expect::<CommandComplete>().await?;
334 conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
335 return Ok(())
336 },
337 _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
338 }
339 }
340 }
341 };
342
343 Ok(Box::pin(stream))
344}