sqlx_postgres/
arguments.rs

1use std::fmt::{self, Write};
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::encode::{Encode, IsNull};
6use crate::error::Error;
7use crate::ext::ustr::UStr;
8use crate::types::Type;
9use crate::{PgConnection, PgTypeInfo, Postgres};
10
11use crate::type_info::PgArrayOf;
12pub(crate) use sqlx_core::arguments::Arguments;
13use sqlx_core::error::BoxDynError;
14
15// TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ?
16// TODO: Extend the patch system to support dynamic lengths
17//       Considerations:
18//          - The prefixed-len offset needs to be back-tracked and updated
19//          - message::Bind needs to take a &PgArguments and use a `write` method instead of
20//            referencing a buffer directly
21//          - The basic idea is that we write bytes for the buffer until we get somewhere
22//            that has a patch, we then apply the patch which should write to &mut Vec<u8>,
23//            backtrack and update the prefixed-len, then write until the next patch offset
24
25#[derive(Default)]
26pub struct PgArgumentBuffer {
27    buffer: Vec<u8>,
28
29    // Number of arguments
30    count: usize,
31
32    // Whenever an `Encode` impl needs to defer some work until after we resolve parameter types
33    // it can use `patch`.
34    //
35    // This currently is only setup to be useful if there is a *fixed-size* slot that needs to be
36    // tweaked from the input type. However, that's the only use case we currently have.
37    patches: Vec<Patch>,
38
39    // Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID
40    // It pushes a "hole" that must be patched later.
41    //
42    // The hole is a `usize` offset into the buffer with the type name that should be resolved
43    // This is done for Records and Arrays as the OID is needed well before we are in an async
44    // function and can just ask postgres.
45    //
46    type_holes: Vec<(usize, HoleKind)>, // Vec<{ offset, type_name }>
47}
48
49enum HoleKind {
50    Type { name: UStr },
51    Array(Arc<PgArrayOf>),
52}
53
54struct Patch {
55    buf_offset: usize,
56    arg_index: usize,
57    #[allow(clippy::type_complexity)]
58    callback: Box<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
59}
60
61/// Implementation of [`Arguments`] for PostgreSQL.
62#[derive(Default)]
63pub struct PgArguments {
64    // Types of each bind parameter
65    pub(crate) types: Vec<PgTypeInfo>,
66
67    // Buffer of encoded bind parameters
68    pub(crate) buffer: PgArgumentBuffer,
69}
70
71impl PgArguments {
72    pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
73    where
74        T: Encode<'q, Postgres> + Type<Postgres>,
75    {
76        let type_info = value.produces().unwrap_or_else(T::type_info);
77
78        let buffer_snapshot = self.buffer.snapshot();
79
80        // encode the value into our buffer
81        if let Err(error) = self.buffer.encode(value) {
82            // reset the value buffer to its previous value if encoding failed,
83            // so we don't leave a half-encoded value behind
84            self.buffer.reset_to_snapshot(buffer_snapshot);
85            return Err(error);
86        };
87
88        // remember the type information for this value
89        self.types.push(type_info);
90        // increment the number of arguments we are tracking
91        self.buffer.count += 1;
92
93        Ok(())
94    }
95
96    // Apply patches
97    // This should only go out and ask postgres if we have not seen the type name yet
98    pub(crate) async fn apply_patches(
99        &mut self,
100        conn: &mut PgConnection,
101        parameters: &[PgTypeInfo],
102    ) -> Result<(), Error> {
103        let PgArgumentBuffer {
104            ref patches,
105            ref type_holes,
106            ref mut buffer,
107            ..
108        } = self.buffer;
109
110        for patch in patches {
111            let buf = &mut buffer[patch.buf_offset..];
112            let ty = &parameters[patch.arg_index];
113
114            (patch.callback)(buf, ty);
115        }
116
117        for (offset, kind) in type_holes {
118            let oid = match kind {
119                HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?,
120                HoleKind::Array(array) => conn.fetch_array_type_id(array).await?,
121            };
122            buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
123        }
124
125        Ok(())
126    }
127}
128
129impl<'q> Arguments<'q> for PgArguments {
130    type Database = Postgres;
131
132    fn reserve(&mut self, additional: usize, size: usize) {
133        self.types.reserve(additional);
134        self.buffer.reserve(size);
135    }
136
137    fn add<T>(&mut self, value: T) -> Result<(), BoxDynError>
138    where
139        T: Encode<'q, Self::Database> + Type<Self::Database>,
140    {
141        self.add(value)
142    }
143
144    fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
145        write!(writer, "${}", self.buffer.count)
146    }
147
148    #[inline(always)]
149    fn len(&self) -> usize {
150        self.buffer.count
151    }
152}
153
154impl PgArgumentBuffer {
155    pub(crate) fn encode<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
156    where
157        T: Encode<'q, Postgres>,
158    {
159        // Won't catch everything but is a good sanity check
160        value_size_int4_checked(value.size_hint())?;
161
162        // reserve space to write the prefixed length of the value
163        let offset = self.len();
164
165        self.extend(&[0; 4]);
166
167        // encode the value into our buffer
168        let len = if let IsNull::No = value.encode(self)? {
169            // Ensure that the value size does not overflow i32
170            value_size_int4_checked(self.len() - offset - 4)?
171        } else {
172            // Write a -1 to indicate NULL
173            // NOTE: It is illegal for [encode] to write any data
174            debug_assert_eq!(self.len(), offset + 4);
175            -1_i32
176        };
177
178        // write the len to the beginning of the value
179        // (offset + 4) cannot overflow because it would have failed at `self.extend()`.
180        self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
181
182        Ok(())
183    }
184
185    // Adds a callback to be invoked later when we know the parameter type
186    #[allow(dead_code)]
187    pub(crate) fn patch<F>(&mut self, callback: F)
188    where
189        F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
190    {
191        let offset = self.len();
192        let arg_index = self.count;
193
194        self.patches.push(Patch {
195            buf_offset: offset,
196            arg_index,
197            callback: Box::new(callback),
198        });
199    }
200
201    // Extends the inner buffer by enough space to have an OID
202    // Remembers where the OID goes and type name for the OID
203    pub(crate) fn patch_type_by_name(&mut self, type_name: &UStr) {
204        let offset = self.len();
205
206        self.extend_from_slice(&0_u32.to_be_bytes());
207        self.type_holes.push((
208            offset,
209            HoleKind::Type {
210                name: type_name.clone(),
211            },
212        ));
213    }
214
215    pub(crate) fn patch_array_type(&mut self, array: Arc<PgArrayOf>) {
216        let offset = self.len();
217
218        self.extend_from_slice(&0_u32.to_be_bytes());
219        self.type_holes.push((offset, HoleKind::Array(array)));
220    }
221
222    fn snapshot(&self) -> PgArgumentBufferSnapshot {
223        let Self {
224            buffer,
225            count,
226            patches,
227            type_holes,
228        } = self;
229
230        PgArgumentBufferSnapshot {
231            buffer_length: buffer.len(),
232            count: *count,
233            patches_length: patches.len(),
234            type_holes_length: type_holes.len(),
235        }
236    }
237
238    fn reset_to_snapshot(
239        &mut self,
240        PgArgumentBufferSnapshot {
241            buffer_length,
242            count,
243            patches_length,
244            type_holes_length,
245        }: PgArgumentBufferSnapshot,
246    ) {
247        self.buffer.truncate(buffer_length);
248        self.count = count;
249        self.patches.truncate(patches_length);
250        self.type_holes.truncate(type_holes_length);
251    }
252}
253
/// Sizes recorded from a `PgArgumentBuffer` at one point in time, used to
/// roll the buffer back after a failed encode.
struct PgArgumentBufferSnapshot {
    /// Length of the encoded byte buffer.
    buffer_length: usize,
    /// Number of arguments.
    count: usize,
    /// Number of registered patches.
    patches_length: usize,
    /// Number of recorded type holes.
    type_holes_length: usize,
}
260
261impl Deref for PgArgumentBuffer {
262    type Target = Vec<u8>;
263
264    #[inline]
265    fn deref(&self) -> &Self::Target {
266        &self.buffer
267    }
268}
269
270impl DerefMut for PgArgumentBuffer {
271    #[inline]
272    fn deref_mut(&mut self) -> &mut Self::Target {
273        &mut self.buffer
274    }
275}
276
/// Converts a byte length to the `i32` used for length prefixes.
///
/// # Errors
///
/// Returns a descriptive message when `size` is larger than `i32::MAX`.
pub(crate) fn value_size_int4_checked(size: usize) -> Result<i32, String> {
    match i32::try_from(size) {
        Ok(len) => Ok(len),
        Err(_) => Err(format!(
            "value size would overflow in the binary protocol encoding: {size} > {}",
            i32::MAX
        )),
    }
}