odbc_api/handles/
sql_char.rs

1//! The idea is to handle most of the conditional compilation around different SQL character types
2//! in this module, so the rest of the crate doesn't have to.
3
// The rather awkward expression:
// `#[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]` is used to
// annotate things which should only compile if we use UTF-16 to communicate to the data source.
// We use its negation:
// `#[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]` to
// indicate a "narrow" charset for communicating with the data source, which we assume to be UTF-8.
//
// Currently I did not find a better way to use narrow functions on non-windows platforms and wide
// functions on windows platforms by default. I also want to enable explicitly overwriting the
// default on both platforms. See also the documentation of the `narrow` and `wide` features in the
// Cargo.toml manifest.
15
16use super::buffer::{buf_ptr, mut_buf_ptr};
17use std::{
18    borrow::Cow,
19    mem::{size_of, size_of_val},
20};
21
22#[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
23use std::{ffi::CStr, string::FromUtf8Error};
24
25#[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
26use std::{
27    char::{decode_utf16, DecodeUtf16Error},
28    marker::PhantomData,
29};
30
31#[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
32use widestring::{U16CStr, U16String};
33
/// Code unit used for text exchanged through the ODBC API. Wide compilation: `u16`, the text is
/// treated as UTF-16.
#[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
pub type SqlChar = u16;
/// Code unit used for text exchanged through the ODBC API. Narrow compilation: `u8`, the text is
/// assumed to be UTF-8.
#[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
pub type SqlChar = u8;

/// Error type of a failed strict conversion from `SqlChar` text to a `String`. Wide compilation:
/// the error reported by UTF-16 decoding.
#[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
pub type DecodingError = DecodeUtf16Error;
/// Error type of a failed strict conversion from `SqlChar` text to a `String`. Narrow
/// compilation: the error reported by UTF-8 validation.
#[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
pub type DecodingError = FromUtf8Error;
43
/// Strictly converts narrow text into an owned `String`.
///
/// The bytes are copied into a fresh allocation and validated as UTF-8. Invalid input is
/// reported as an error instead of being replaced.
#[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
pub fn slice_to_utf8(text: &[u8]) -> Result<String, FromUtf8Error> {
    let bytes = text.to_vec();
    String::from_utf8(bytes)
}
/// Strictly converts wide text into an owned `String`.
///
/// The code units are decoded as UTF-16. The first unpaired surrogate aborts the conversion and
/// is reported as an error instead of being replaced.
#[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
pub fn slice_to_utf8(text: &[u16]) -> Result<String, DecodeUtf16Error> {
    let code_units = text.iter().copied();
    decode_utf16(code_units).collect::<Result<String, DecodeUtf16Error>>()
}
52
/// Lossy conversion of narrow text into UTF-8.
///
/// Valid UTF-8 is borrowed without copying. Invalid sequences are replaced with U+FFFD, in which
/// case an owned string is returned.
#[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
pub fn slice_to_cow_utf8(text: &[u8]) -> Cow<'_, str> {
    String::from_utf8_lossy(text)
}
/// Lossy conversion of wide (UTF-16) text into UTF-8.
///
/// Always returns an owned string, since the text has to be transcoded. Unpaired surrogates are
/// replaced with U+FFFD rather than causing a panic, which matches the lossy behaviour of the
/// narrow variant of this function.
#[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
pub fn slice_to_cow_utf8(text: &[u16]) -> Cow<'_, str> {
    Cow::Owned(String::from_utf16_lossy(text))
}
62
/// Lossily converts the zero terminated wide string at the start of `buffer` into a `String`.
/// Everything from the first zero onwards is ignored.
///
/// # Panics
///
/// Panics if the truncating conversion to a `U16CStr` fails.
#[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
fn sz_to_utf8(buffer: &[u16]) -> String {
    U16CStr::from_slice_truncate(buffer)
        .unwrap()
        .to_string_lossy()
}
/// Lossily converts the zero terminated narrow string at the start of `buffer` into a `String`.
/// Everything from the first zero onwards is ignored; invalid UTF-8 is replaced with U+FFFD.
///
/// # Panics
///
/// Panics if `buffer` does not contain a zero.
#[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
fn sz_to_utf8(buffer: &[u8]) -> String {
    // Index of the first zero, which is where the string ends.
    let nul_index = buffer
        .iter()
        .position(|&byte| byte == b'\0')
        .expect("Buffer must contain terminating zero.");
    let terminated = &buffer[..=nul_index];
    // SAFETY: `terminated` ends with the first zero of `buffer`, so it is nul terminated and
    // contains no interior nul bytes.
    let text = unsafe { CStr::from_bytes_with_nul_unchecked(terminated) };
    text.to_string_lossy().into_owned()
}
80
81/// Buffer length in bytes, not characters
82pub fn binary_length(buffer: &[SqlChar]) -> usize {
83    size_of_val(buffer)
84}
85
86/// `true` if the buffer has not been large enough to hold the entire string.
87///
88/// # Parameters
89///
90/// - `actuel_length_bin`: Actual length in bytes, but excluding the terminating zero.
91pub fn is_truncated_bin(buffer: &[SqlChar], actual_length_bin: usize) -> bool {
92    size_of_val(buffer) <= actual_length_bin
93}
94
95/// Resizes the underlying buffer to fit the size required to hold the entire string including
96/// terminating zero. Required length is provided in bytes (not characters), excluding the
97/// terminating zero.
98pub fn resize_to_fit_with_tz(buffer: &mut Vec<SqlChar>, required_binary_length: usize) {
99    // In order to use only minimal memory for drivers which stick to the ODBC standard we would
100    // use `+1` in the statement beneath. However it turns out the PostgreSQL driver will fill the
101    // last value with `0` instead of the last latter when used with a wide `SqlChar`. So we use
102    // `+2` to make it work with PostgreSql on windows, too.
103    buffer.resize((required_binary_length / size_of::<SqlChar>()) + 2, 0);
104}
105
106/// Resizes the underlying buffer to fit the size required to hold the entire string excluding
107/// terminating zero. Required length is provided in bytes (not characters), excluding the
108/// terminating zero.
109pub fn resize_to_fit_without_tz(buffer: &mut Vec<SqlChar>, required_binary_length: usize) {
110    buffer.resize(required_binary_length / size_of::<SqlChar>(), 0);
111}
112
113/// Handles conversion from UTF-8 string slices to ODBC SQL char encoding. Depending on the
114/// conditional compiliation due to feature flags, the UTF-8 strings are either passed without
115/// conversion to narrow method calls, or they are converted to UTF-16, before passed to the wide
116/// methods.
117pub struct SqlText<'a> {
118    /// In case we use wide methods we need to convert to UTF-16. We'll take ownership of the buffer
119    /// here.
120    #[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
121    text: U16String,
122    /// We include the lifetime in the declaration of the type still, so the borrow checker
123    /// complains, if we would mess up the compilation for narrow methods.
124    #[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
125    _ref: PhantomData<&'a str>,
126    /// In the case of narrow compiliation we just forward the string silce unchanged
127    #[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
128    text: &'a str,
129}
130
131impl<'a> SqlText<'a> {
132    #[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
133    /// Create an SqlText buffer from an UTF-8 string slice
134    pub fn new(text: &'a str) -> Self {
135        Self {
136            text: U16String::from_str(text),
137            _ref: PhantomData,
138        }
139    }
140    #[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
141    /// Create an SqlText buffer from an UTF-8 string slice
142    pub fn new(text: &'a str) -> Self {
143        Self { text }
144    }
145
146    #[cfg(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows")))]
147    pub fn ptr(&self) -> *const u16 {
148        buf_ptr(self.text.as_slice())
149    }
150    #[cfg(not(any(feature = "wide", all(not(feature = "narrow"), target_os = "windows"))))]
151    pub fn ptr(&self) -> *const u8 {
152        buf_ptr(self.text.as_bytes())
153    }
154
155    /// Length in characters
156    pub fn len_char(&self) -> usize {
157        self.text.len()
158    }
159}
160
161/// Use this buffer type to fetch zero terminated strings from the ODBC API. Either allocates a
162/// buffer for wide or narrow strings dependend on the features set.
163pub struct SzBuffer {
164    buffer: Vec<SqlChar>,
165}
166
167impl SzBuffer {
168    /// Creates a buffer which can hold at least `capacity` characters, excluding the terminating
169    /// zero. Or phrased differently. It will allocate one additional character to hold the
170    /// terminating zero, so the caller should not factor it into the size of capacity.
171    pub fn with_capacity(capacity: usize) -> Self {
172        Self {
173            // Allocate +1 character extra for terminating zero
174            buffer: vec![0; capacity + 1],
175        }
176    }
177
178    pub fn mut_buf(&mut self) -> &mut [SqlChar] {
179        // Use full capacity
180        self.buffer.resize(self.buffer.capacity(), 0);
181        &mut self.buffer
182    }
183
184    /// Create an owned utf-8 string from the internal buffer representation.
185    pub fn to_utf8(&self) -> String {
186        sz_to_utf8(&self.buffer)
187    }
188}
189
190/// We use this as an output buffer for strings. Allows for detecting truncation.
191pub struct OutputStringBuffer {
192    /// Buffer holding the string. Must also contains space for a terminating zero.
193    buffer: Vec<SqlChar>,
194    /// After the buffer has been filled, this should contain the actual length of the string. Can
195    /// be used to detect truncation.
196    actual_length: i16,
197}
198
199impl OutputStringBuffer {
200    /// Creates an empty string buffer. Useful if you want to e.g. use a prompt to complete the
201    /// connection string, but are not interested in the actual completed connection string.
202    pub fn empty() -> Self {
203        Self::with_buffer_size(0)
204    }
205
206    /// Creates a new instance of an output string buffer which can hold strings up to a size of
207    /// `max_str_len - 1` characters. `-1 because one place is needed for the terminating zero.
208    /// To hold a connection string the size should be at least 1024.
209    pub fn with_buffer_size(max_str_len: usize) -> Self {
210        Self {
211            buffer: vec![0; max_str_len],
212            actual_length: 0,
213        }
214    }
215
216    /// Ptr to the internal buffer. Used by ODBC API calls to fill the buffer.
217    pub fn mut_buf_ptr(&mut self) -> *mut SqlChar {
218        mut_buf_ptr(&mut self.buffer)
219    }
220
221    /// Length of the internal buffer in characters including the terminating zero.
222    pub fn buf_len(&self) -> i16 {
223        // Since buffer must always be able to hold at least one element, substracting `1` is always
224        // defined
225        self.buffer.len().try_into().unwrap()
226    }
227
228    /// Mutable pointer to actual output string length. Used by ODBC API calls to report truncation.
229    pub fn mut_actual_len_ptr(&mut self) -> *mut i16 {
230        &mut self.actual_length as *mut i16
231    }
232
233    /// Call this method to extract string from buffer after ODBC has filled it.
234    pub fn to_utf8(&self) -> String {
235        if self.buffer.is_empty() {
236            return String::new();
237        }
238
239        if self.is_truncated() {
240            // If the string is truncated we return the entire buffer excluding the terminating
241            // zero.
242            slice_to_utf8(&self.buffer[0..(self.buffer.len() - 1)]).unwrap()
243        } else {
244            // If the string is not truncated, we return not the entire buffer, but only the slice
245            // containing the actual string.
246            let actual_length: usize = self.actual_length.try_into().unwrap();
247            slice_to_utf8(&self.buffer[0..actual_length]).unwrap()
248        }
249    }
250
251    /// True if the buffer had not been large enough to hold the string.
252    pub fn is_truncated(&self) -> bool {
253        self.actual_length >= self.buffer.len().try_into().unwrap()
254    }
255}