postgres_protocol/message/
frontend.rs1#![allow(missing_docs)]
3
4use byteorder::{BigEndian, ByteOrder};
5use bytes::{Buf, BufMut, BytesMut};
6use std::convert::TryFrom;
7use std::error::Error;
8use std::io;
9use std::marker;
10
11use crate::{write_nullable, FromUsize, IsNull, Oid};
12
13#[inline]
14fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
15where
16 F: FnOnce(&mut BytesMut) -> Result<(), E>,
17 E: From<io::Error>,
18{
19 let base = buf.len();
20 buf.extend_from_slice(&[0; 4]);
21
22 f(buf)?;
23
24 let size = i32::from_usize(buf.len() - base)?;
25 BigEndian::write_i32(&mut buf[base..], size);
26 Ok(())
27}
28
/// Error returned by [`bind`].
///
/// Derives `Debug` so callers can `unwrap`/`expect` a `Result<_, BindError>`
/// or log it with `{:?}`. Both payload types are already `Debug`.
#[derive(Debug)]
pub enum BindError {
    /// Produced while writing the parameter values. Errors returned by the
    /// caller-supplied serializer passed to [`bind`] end up here.
    Conversion(Box<dyn Error + marker::Sync + Send>),
    /// An `io::Error` produced while encoding the rest of the message, for
    /// example an embedded null byte in the portal or statement name.
    Serialization(io::Error),
}

impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
    /// Wraps a boxed conversion error as [`BindError::Conversion`].
    #[inline]
    fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
        BindError::Conversion(e)
    }
}

impl From<io::Error> for BindError {
    /// Wraps an I/O-level encoding error as [`BindError::Serialization`].
    #[inline]
    fn from(e: io::Error) -> BindError {
        BindError::Serialization(e)
    }
}
47
48#[inline]
49pub fn bind<I, J, F, T, K>(
50 portal: &str,
51 statement: &str,
52 formats: I,
53 values: J,
54 mut serializer: F,
55 result_formats: K,
56 buf: &mut BytesMut,
57) -> Result<(), BindError>
58where
59 I: IntoIterator<Item = i16>,
60 J: IntoIterator<Item = T>,
61 F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
62 K: IntoIterator<Item = i16>,
63{
64 buf.put_u8(b'B');
65
66 write_body(buf, |buf| {
67 write_cstr(portal.as_bytes(), buf)?;
68 write_cstr(statement.as_bytes(), buf)?;
69 write_counted(
70 formats,
71 |f, buf| {
72 buf.put_i16(f);
73 Ok::<_, io::Error>(())
74 },
75 buf,
76 )?;
77 write_counted(
78 values,
79 |v, buf| write_nullable(|buf| serializer(v, buf), buf),
80 buf,
81 )?;
82 write_counted(
83 result_formats,
84 |f, buf| {
85 buf.put_i16(f);
86 Ok::<_, io::Error>(())
87 },
88 buf,
89 )?;
90
91 Ok(())
92 })
93}
94
95#[inline]
96fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
97where
98 I: IntoIterator<Item = T>,
99 F: FnMut(T, &mut BytesMut) -> Result<(), E>,
100 E: From<io::Error>,
101{
102 let base = buf.len();
103 buf.extend_from_slice(&[0; 2]);
104 let mut count = 0;
105 for item in items {
106 serializer(item, buf)?;
107 count += 1;
108 }
109 let count = i16::from_usize(count)?;
110 BigEndian::write_i16(&mut buf[base..], count);
111
112 Ok(())
113}
114
115#[inline]
116pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
117 write_body(buf, |buf| {
118 buf.put_i32(80_877_102);
119 buf.put_i32(process_id);
120 buf.put_i32(secret_key);
121 Ok::<_, io::Error>(())
122 })
123 .unwrap();
124}
125
126#[inline]
127pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
128 buf.put_u8(b'C');
129 write_body(buf, |buf| {
130 buf.put_u8(variant);
131 write_cstr(name.as_bytes(), buf)
132 })
133}
134
135pub struct CopyData<T> {
136 buf: T,
137 len: i32,
138}
139
140impl<T> CopyData<T>
141where
142 T: Buf,
143{
144 pub fn new(buf: T) -> io::Result<CopyData<T>> {
145 let len = buf
146 .remaining()
147 .checked_add(4)
148 .and_then(|l| i32::try_from(l).ok())
149 .ok_or_else(|| {
150 io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
151 })?;
152
153 Ok(CopyData { buf, len })
154 }
155
156 pub fn write(self, out: &mut BytesMut) {
157 out.put_u8(b'd');
158 out.put_i32(self.len);
159 out.put(self.buf);
160 }
161}
162
163#[inline]
164pub fn copy_done(buf: &mut BytesMut) {
165 buf.put_u8(b'c');
166 write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
167}
168
169#[inline]
170pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
171 buf.put_u8(b'f');
172 write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
173}
174
175#[inline]
176pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
177 buf.put_u8(b'D');
178 write_body(buf, |buf| {
179 buf.put_u8(variant);
180 write_cstr(name.as_bytes(), buf)
181 })
182}
183
184#[inline]
185pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
186 buf.put_u8(b'E');
187 write_body(buf, |buf| {
188 write_cstr(portal.as_bytes(), buf)?;
189 buf.put_i32(max_rows);
190 Ok(())
191 })
192}
193
194#[inline]
195pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
196where
197 I: IntoIterator<Item = Oid>,
198{
199 buf.put_u8(b'P');
200 write_body(buf, |buf| {
201 write_cstr(name.as_bytes(), buf)?;
202 write_cstr(query.as_bytes(), buf)?;
203 write_counted(
204 param_types,
205 |t, buf| {
206 buf.put_u32(t);
207 Ok::<_, io::Error>(())
208 },
209 buf,
210 )?;
211 Ok(())
212 })
213}
214
215#[inline]
216pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
217 buf.put_u8(b'p');
218 write_body(buf, |buf| write_cstr(password, buf))
219}
220
221#[inline]
222pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
223 buf.put_u8(b'Q');
224 write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
225}
226
227#[inline]
228pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
229 buf.put_u8(b'p');
230 write_body(buf, |buf| {
231 write_cstr(mechanism.as_bytes(), buf)?;
232 let len = i32::from_usize(data.len())?;
233 buf.put_i32(len);
234 buf.put_slice(data);
235 Ok(())
236 })
237}
238
239#[inline]
240pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
241 buf.put_u8(b'p');
242 write_body(buf, |buf| {
243 buf.put_slice(data);
244 Ok(())
245 })
246}
247
248#[inline]
249pub fn ssl_request(buf: &mut BytesMut) {
250 write_body(buf, |buf| {
251 buf.put_i32(80_877_103);
252 Ok::<_, io::Error>(())
253 })
254 .unwrap();
255}
256
257#[inline]
258pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()>
259where
260 I: IntoIterator<Item = (&'a str, &'a str)>,
261{
262 write_body(buf, |buf| {
263 buf.put_i32(0x00_03_00_00);
265 for (key, value) in parameters {
266 write_cstr(key.as_bytes(), buf)?;
267 write_cstr(value.as_bytes(), buf)?;
268 }
269 buf.put_u8(0);
270 Ok(())
271 })
272}
273
274#[inline]
275pub fn flush(buf: &mut BytesMut) {
276 buf.put_u8(b'H');
277 write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
278}
279
280#[inline]
281pub fn sync(buf: &mut BytesMut) {
282 buf.put_u8(b'S');
283 write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
284}
285
286#[inline]
287pub fn terminate(buf: &mut BytesMut) {
288 buf.put_u8(b'X');
289 write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
290}
291
292#[inline]
293fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
294 if s.contains(&0) {
295 return Err(io::Error::new(
296 io::ErrorKind::InvalidInput,
297 "string contains embedded null",
298 ));
299 }
300 buf.put_slice(s);
301 buf.put_u8(0);
302 Ok(())
303}