sqlx_postgres/
arguments.rs1use std::fmt::{self, Write};
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::encode::{Encode, IsNull};
6use crate::error::Error;
7use crate::ext::ustr::UStr;
8use crate::types::Type;
9use crate::{PgConnection, PgTypeInfo, Postgres};
10
11use crate::type_info::PgArrayOf;
12pub(crate) use sqlx_core::arguments::Arguments;
13use sqlx_core::error::BoxDynError;
14
15#[derive(Default)]
26pub struct PgArgumentBuffer {
27 buffer: Vec<u8>,
28
29 count: usize,
31
32 patches: Vec<Patch>,
38
39 type_holes: Vec<(usize, HoleKind)>, }
48
49enum HoleKind {
50 Type { name: UStr },
51 Array(Arc<PgArrayOf>),
52}
53
54struct Patch {
55 buf_offset: usize,
56 arg_index: usize,
57 #[allow(clippy::type_complexity)]
58 callback: Box<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
59}
60
61#[derive(Default)]
63pub struct PgArguments {
64 pub(crate) types: Vec<PgTypeInfo>,
66
67 pub(crate) buffer: PgArgumentBuffer,
69}
70
71impl PgArguments {
72 pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
73 where
74 T: Encode<'q, Postgres> + Type<Postgres>,
75 {
76 let type_info = value.produces().unwrap_or_else(T::type_info);
77
78 let buffer_snapshot = self.buffer.snapshot();
79
80 if let Err(error) = self.buffer.encode(value) {
82 self.buffer.reset_to_snapshot(buffer_snapshot);
85 return Err(error);
86 };
87
88 self.types.push(type_info);
90 self.buffer.count += 1;
92
93 Ok(())
94 }
95
96 pub(crate) async fn apply_patches(
99 &mut self,
100 conn: &mut PgConnection,
101 parameters: &[PgTypeInfo],
102 ) -> Result<(), Error> {
103 let PgArgumentBuffer {
104 ref patches,
105 ref type_holes,
106 ref mut buffer,
107 ..
108 } = self.buffer;
109
110 for patch in patches {
111 let buf = &mut buffer[patch.buf_offset..];
112 let ty = ¶meters[patch.arg_index];
113
114 (patch.callback)(buf, ty);
115 }
116
117 for (offset, kind) in type_holes {
118 let oid = match kind {
119 HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?,
120 HoleKind::Array(array) => conn.fetch_array_type_id(array).await?,
121 };
122 buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
123 }
124
125 Ok(())
126 }
127}
128
129impl<'q> Arguments<'q> for PgArguments {
130 type Database = Postgres;
131
132 fn reserve(&mut self, additional: usize, size: usize) {
133 self.types.reserve(additional);
134 self.buffer.reserve(size);
135 }
136
137 fn add<T>(&mut self, value: T) -> Result<(), BoxDynError>
138 where
139 T: Encode<'q, Self::Database> + Type<Self::Database>,
140 {
141 self.add(value)
142 }
143
144 fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
145 write!(writer, "${}", self.buffer.count)
146 }
147
148 #[inline(always)]
149 fn len(&self) -> usize {
150 self.buffer.count
151 }
152}
153
154impl PgArgumentBuffer {
155 pub(crate) fn encode<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
156 where
157 T: Encode<'q, Postgres>,
158 {
159 value_size_int4_checked(value.size_hint())?;
161
162 let offset = self.len();
164
165 self.extend(&[0; 4]);
166
167 let len = if let IsNull::No = value.encode(self)? {
169 value_size_int4_checked(self.len() - offset - 4)?
171 } else {
172 debug_assert_eq!(self.len(), offset + 4);
175 -1_i32
176 };
177
178 self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
181
182 Ok(())
183 }
184
185 #[allow(dead_code)]
187 pub(crate) fn patch<F>(&mut self, callback: F)
188 where
189 F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
190 {
191 let offset = self.len();
192 let arg_index = self.count;
193
194 self.patches.push(Patch {
195 buf_offset: offset,
196 arg_index,
197 callback: Box::new(callback),
198 });
199 }
200
201 pub(crate) fn patch_type_by_name(&mut self, type_name: &UStr) {
204 let offset = self.len();
205
206 self.extend_from_slice(&0_u32.to_be_bytes());
207 self.type_holes.push((
208 offset,
209 HoleKind::Type {
210 name: type_name.clone(),
211 },
212 ));
213 }
214
215 pub(crate) fn patch_array_type(&mut self, array: Arc<PgArrayOf>) {
216 let offset = self.len();
217
218 self.extend_from_slice(&0_u32.to_be_bytes());
219 self.type_holes.push((offset, HoleKind::Array(array)));
220 }
221
222 fn snapshot(&self) -> PgArgumentBufferSnapshot {
223 let Self {
224 buffer,
225 count,
226 patches,
227 type_holes,
228 } = self;
229
230 PgArgumentBufferSnapshot {
231 buffer_length: buffer.len(),
232 count: *count,
233 patches_length: patches.len(),
234 type_holes_length: type_holes.len(),
235 }
236 }
237
238 fn reset_to_snapshot(
239 &mut self,
240 PgArgumentBufferSnapshot {
241 buffer_length,
242 count,
243 patches_length,
244 type_holes_length,
245 }: PgArgumentBufferSnapshot,
246 ) {
247 self.buffer.truncate(buffer_length);
248 self.count = count;
249 self.patches.truncate(patches_length);
250 self.type_holes.truncate(type_holes_length);
251 }
252}
253
/// Saved lengths of a `PgArgumentBuffer`, used to undo a failed encode.
struct PgArgumentBufferSnapshot {
    /// Length of the byte buffer when the snapshot was taken.
    buffer_length: usize,
    /// Argument count when the snapshot was taken.
    count: usize,
    /// Number of registered patches when the snapshot was taken.
    patches_length: usize,
    /// Number of registered type holes when the snapshot was taken.
    type_holes_length: usize,
}
260
261impl Deref for PgArgumentBuffer {
262 type Target = Vec<u8>;
263
264 #[inline]
265 fn deref(&self) -> &Self::Target {
266 &self.buffer
267 }
268}
269
270impl DerefMut for PgArgumentBuffer {
271 #[inline]
272 fn deref_mut(&mut self) -> &mut Self::Target {
273 &mut self.buffer
274 }
275}
276
/// Converts a byte length into the `i32` used for value length prefixes.
///
/// # Errors
/// Returns a descriptive message when `size` is larger than `i32::MAX`.
pub(crate) fn value_size_int4_checked(size: usize) -> Result<i32, String> {
    match i32::try_from(size) {
        Ok(length) => Ok(length),
        Err(_) => Err(format!(
            "value size would overflow in the binary protocol encoding: {size} > {}",
            i32::MAX
        )),
    }
}