postgres_protocol/message/frontend.rs

1//! Frontend message serialization.
2#![allow(missing_docs)]
3
4use byteorder::{BigEndian, ByteOrder};
5use bytes::{Buf, BufMut, BytesMut};
6use std::convert::TryFrom;
7use std::error::Error;
8use std::io;
9use std::marker;
10
11use crate::{write_nullable, FromUsize, IsNull, Oid};
12
13#[inline]
14fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
15where
16    F: FnOnce(&mut BytesMut) -> Result<(), E>,
17    E: From<io::Error>,
18{
19    let base = buf.len();
20    buf.extend_from_slice(&[0; 4]);
21
22    f(buf)?;
23
24    let size = i32::from_usize(buf.len() - base)?;
25    BigEndian::write_i32(&mut buf[base..], size);
26    Ok(())
27}
28
/// Error returned by [`bind`].
#[derive(Debug)]
pub enum BindError {
    /// An error on the parameter-value path of [`bind`], such as one returned by the
    /// caller-supplied value serializer.
    Conversion(Box<dyn Error + marker::Sync + Send>),
    /// An I/O-level error while encoding the message, such as a portal or statement
    /// name containing an embedded NUL byte.
    Serialization(io::Error),
}
33
34impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
35    #[inline]
36    fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
37        BindError::Conversion(e)
38    }
39}
40
41impl From<io::Error> for BindError {
42    #[inline]
43    fn from(e: io::Error) -> BindError {
44        BindError::Serialization(e)
45    }
46}
47
48#[inline]
49pub fn bind<I, J, F, T, K>(
50    portal: &str,
51    statement: &str,
52    formats: I,
53    values: J,
54    mut serializer: F,
55    result_formats: K,
56    buf: &mut BytesMut,
57) -> Result<(), BindError>
58where
59    I: IntoIterator<Item = i16>,
60    J: IntoIterator<Item = T>,
61    F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
62    K: IntoIterator<Item = i16>,
63{
64    buf.put_u8(b'B');
65
66    write_body(buf, |buf| {
67        write_cstr(portal.as_bytes(), buf)?;
68        write_cstr(statement.as_bytes(), buf)?;
69        write_counted(
70            formats,
71            |f, buf| {
72                buf.put_i16(f);
73                Ok::<_, io::Error>(())
74            },
75            buf,
76        )?;
77        write_counted(
78            values,
79            |v, buf| write_nullable(|buf| serializer(v, buf), buf),
80            buf,
81        )?;
82        write_counted(
83            result_formats,
84            |f, buf| {
85                buf.put_i16(f);
86                Ok::<_, io::Error>(())
87            },
88            buf,
89        )?;
90
91        Ok(())
92    })
93}
94
95#[inline]
96fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
97where
98    I: IntoIterator<Item = T>,
99    F: FnMut(T, &mut BytesMut) -> Result<(), E>,
100    E: From<io::Error>,
101{
102    let base = buf.len();
103    buf.extend_from_slice(&[0; 2]);
104    let mut count = 0;
105    for item in items {
106        serializer(item, buf)?;
107        count += 1;
108    }
109    let count = i16::from_usize(count)?;
110    BigEndian::write_i16(&mut buf[base..], count);
111
112    Ok(())
113}
114
115#[inline]
116pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
117    write_body(buf, |buf| {
118        buf.put_i32(80_877_102);
119        buf.put_i32(process_id);
120        buf.put_i32(secret_key);
121        Ok::<_, io::Error>(())
122    })
123    .unwrap();
124}
125
126#[inline]
127pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
128    buf.put_u8(b'C');
129    write_body(buf, |buf| {
130        buf.put_u8(variant);
131        write_cstr(name.as_bytes(), buf)
132    })
133}
134
135pub struct CopyData<T> {
136    buf: T,
137    len: i32,
138}
139
140impl<T> CopyData<T>
141where
142    T: Buf,
143{
144    pub fn new(buf: T) -> io::Result<CopyData<T>> {
145        let len = buf
146            .remaining()
147            .checked_add(4)
148            .and_then(|l| i32::try_from(l).ok())
149            .ok_or_else(|| {
150                io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
151            })?;
152
153        Ok(CopyData { buf, len })
154    }
155
156    pub fn write(self, out: &mut BytesMut) {
157        out.put_u8(b'd');
158        out.put_i32(self.len);
159        out.put(self.buf);
160    }
161}
162
163#[inline]
164pub fn copy_done(buf: &mut BytesMut) {
165    buf.put_u8(b'c');
166    write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
167}
168
169#[inline]
170pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
171    buf.put_u8(b'f');
172    write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
173}
174
175#[inline]
176pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
177    buf.put_u8(b'D');
178    write_body(buf, |buf| {
179        buf.put_u8(variant);
180        write_cstr(name.as_bytes(), buf)
181    })
182}
183
184#[inline]
185pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
186    buf.put_u8(b'E');
187    write_body(buf, |buf| {
188        write_cstr(portal.as_bytes(), buf)?;
189        buf.put_i32(max_rows);
190        Ok(())
191    })
192}
193
194#[inline]
195pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
196where
197    I: IntoIterator<Item = Oid>,
198{
199    buf.put_u8(b'P');
200    write_body(buf, |buf| {
201        write_cstr(name.as_bytes(), buf)?;
202        write_cstr(query.as_bytes(), buf)?;
203        write_counted(
204            param_types,
205            |t, buf| {
206                buf.put_u32(t);
207                Ok::<_, io::Error>(())
208            },
209            buf,
210        )?;
211        Ok(())
212    })
213}
214
215#[inline]
216pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
217    buf.put_u8(b'p');
218    write_body(buf, |buf| write_cstr(password, buf))
219}
220
221#[inline]
222pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
223    buf.put_u8(b'Q');
224    write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
225}
226
227#[inline]
228pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
229    buf.put_u8(b'p');
230    write_body(buf, |buf| {
231        write_cstr(mechanism.as_bytes(), buf)?;
232        let len = i32::from_usize(data.len())?;
233        buf.put_i32(len);
234        buf.put_slice(data);
235        Ok(())
236    })
237}
238
239#[inline]
240pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
241    buf.put_u8(b'p');
242    write_body(buf, |buf| {
243        buf.put_slice(data);
244        Ok(())
245    })
246}
247
248#[inline]
249pub fn ssl_request(buf: &mut BytesMut) {
250    write_body(buf, |buf| {
251        buf.put_i32(80_877_103);
252        Ok::<_, io::Error>(())
253    })
254    .unwrap();
255}
256
257#[inline]
258pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()>
259where
260    I: IntoIterator<Item = (&'a str, &'a str)>,
261{
262    write_body(buf, |buf| {
263        // postgres protocol version 3.0(196608) in bigger-endian
264        buf.put_i32(0x00_03_00_00);
265        for (key, value) in parameters {
266            write_cstr(key.as_bytes(), buf)?;
267            write_cstr(value.as_bytes(), buf)?;
268        }
269        buf.put_u8(0);
270        Ok(())
271    })
272}
273
274#[inline]
275pub fn flush(buf: &mut BytesMut) {
276    buf.put_u8(b'H');
277    write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
278}
279
280#[inline]
281pub fn sync(buf: &mut BytesMut) {
282    buf.put_u8(b'S');
283    write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
284}
285
286#[inline]
287pub fn terminate(buf: &mut BytesMut) {
288    buf.put_u8(b'X');
289    write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
290}
291
292#[inline]
293fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
294    if s.contains(&0) {
295        return Err(io::Error::new(
296            io::ErrorKind::InvalidInput,
297            "string contains embedded null",
298        ));
299    }
300    buf.put_slice(s);
301    buf.put_u8(0);
302    Ok(())
303}