// base64ct/encoding.rs

//! Base64 encodings
2
3use crate::{
4    alphabet::Alphabet,
5    errors::{Error, InvalidEncodingError, InvalidLengthError},
6};
7use core::str;
8
9#[cfg(feature = "alloc")]
10use alloc::{string::String, vec::Vec};
11
12#[cfg(doc)]
13use crate::{Base64, Base64Bcrypt, Base64Crypt, Base64Unpadded, Base64Url, Base64UrlUnpadded};
14
/// Byte appended to padded Base64 output: ASCII `=` (`0x3D`).
const PAD: u8 = 0x3D;
17
18/// Base64 encoding trait.
19///
20/// This trait must be imported to make use of any Base64 alphabet defined
21/// in this crate.
22///
23/// The following encoding types impl this trait:
24///
25/// - [`Base64`]: standard Base64 encoding with `=` padding.
26/// - [`Base64Bcrypt`]: bcrypt Base64 encoding.
27/// - [`Base64Crypt`]: `crypt(3)` Base64 encoding.
28/// - [`Base64Unpadded`]: standard Base64 encoding *without* padding.
29/// - [`Base64Url`]: URL-safe Base64 encoding with `=` padding.
30/// - [`Base64UrlUnpadded`]: URL-safe Base64 encoding *without* padding.
31pub trait Encoding: Alphabet {
32    /// Decode a Base64 string into the provided destination buffer.
33    fn decode(src: impl AsRef<[u8]>, dst: &mut [u8]) -> Result<&[u8], Error>;
34
35    /// Decode a Base64 string in-place.
36    ///
37    /// NOTE: this method does not (yet) validate that padding is well-formed,
38    /// if the given Base64 encoding is padded.
39    fn decode_in_place(buf: &mut [u8]) -> Result<&[u8], InvalidEncodingError>;
40
41    /// Decode a Base64 string into a byte vector.
42    #[cfg(feature = "alloc")]
43    fn decode_vec(input: &str) -> Result<Vec<u8>, Error>;
44
45    /// Encode the input byte slice as Base64.
46    ///
47    /// Writes the result into the provided destination slice, returning an
48    /// ASCII-encoded Base64 string value.
49    fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, InvalidLengthError>;
50
51    /// Encode input byte slice into a [`String`] containing Base64.
52    ///
53    /// # Panics
54    /// If `input` length is greater than `usize::MAX/4`.
55    #[cfg(feature = "alloc")]
56    fn encode_string(input: &[u8]) -> String;
57
58    /// Get the length of Base64 produced by encoding the given bytes.
59    ///
60    /// WARNING: this function will return `0` for lengths greater than `usize::MAX/4`!
61    fn encoded_len(bytes: &[u8]) -> usize;
62}
63
64impl<T: Alphabet> Encoding for T {
65    fn decode(src: impl AsRef<[u8]>, dst: &mut [u8]) -> Result<&[u8], Error> {
66        let (src_unpadded, mut err) = if T::PADDED {
67            let (unpadded_len, e) = decode_padding(src.as_ref())?;
68            (&src.as_ref()[..unpadded_len], e)
69        } else {
70            (src.as_ref(), 0)
71        };
72
73        let dlen = decoded_len(src_unpadded.len());
74
75        if dlen > dst.len() {
76            return Err(Error::InvalidLength);
77        }
78
79        let dst = &mut dst[..dlen];
80
81        let mut src_chunks = src_unpadded.chunks_exact(4);
82        let mut dst_chunks = dst.chunks_exact_mut(3);
83        for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
84            err |= Self::decode_3bytes(s, d);
85        }
86        let src_rem = src_chunks.remainder();
87        let dst_rem = dst_chunks.into_remainder();
88
89        err |= !(src_rem.is_empty() || src_rem.len() >= 2) as i16;
90        let mut tmp_out = [0u8; 3];
91        let mut tmp_in = [b'A'; 4];
92        tmp_in[..src_rem.len()].copy_from_slice(src_rem);
93        err |= Self::decode_3bytes(&tmp_in, &mut tmp_out);
94        dst_rem.copy_from_slice(&tmp_out[..dst_rem.len()]);
95
96        if err == 0 {
97            validate_last_block::<T>(src.as_ref(), dst)?;
98            Ok(dst)
99        } else {
100            Err(Error::InvalidEncoding)
101        }
102    }
103
104    // TODO(tarcieri): explicitly checked/wrapped arithmetic
105    #[allow(clippy::arithmetic_side_effects)]
106    fn decode_in_place(mut buf: &mut [u8]) -> Result<&[u8], InvalidEncodingError> {
107        // TODO: eliminate unsafe code when LLVM12 is stable
108        // See: https://github.com/rust-lang/rust/issues/80963
109        let mut err = if T::PADDED {
110            let (unpadded_len, e) = decode_padding(buf)?;
111            buf = &mut buf[..unpadded_len];
112            e
113        } else {
114            0
115        };
116
117        let dlen = decoded_len(buf.len());
118        let full_chunks = buf.len() / 4;
119
120        for chunk in 0..full_chunks {
121            // SAFETY: `p3` and `p4` point inside `buf`, while they may overlap,
122            // read and write are clearly separated from each other and done via
123            // raw pointers.
124            #[allow(unsafe_code)]
125            unsafe {
126                debug_assert!(3 * chunk + 3 <= buf.len());
127                debug_assert!(4 * chunk + 4 <= buf.len());
128
129                let p3 = buf.as_mut_ptr().add(3 * chunk) as *mut [u8; 3];
130                let p4 = buf.as_ptr().add(4 * chunk) as *const [u8; 4];
131
132                let mut tmp_out = [0u8; 3];
133                err |= Self::decode_3bytes(&*p4, &mut tmp_out);
134                *p3 = tmp_out;
135            }
136        }
137
138        let src_rem_pos = 4 * full_chunks;
139        let src_rem_len = buf.len() - src_rem_pos;
140        let dst_rem_pos = 3 * full_chunks;
141        let dst_rem_len = dlen - dst_rem_pos;
142
143        err |= !(src_rem_len == 0 || src_rem_len >= 2) as i16;
144        let mut tmp_in = [b'A'; 4];
145        tmp_in[..src_rem_len].copy_from_slice(&buf[src_rem_pos..]);
146        let mut tmp_out = [0u8; 3];
147
148        err |= Self::decode_3bytes(&tmp_in, &mut tmp_out);
149
150        if err == 0 {
151            // SAFETY: `dst_rem_len` is always smaller than 4, so we don't
152            // read outside of `tmp_out`, write and the final slicing never go
153            // outside of `buf`.
154            #[allow(unsafe_code)]
155            unsafe {
156                debug_assert!(dst_rem_pos + dst_rem_len <= buf.len());
157                debug_assert!(dst_rem_len <= tmp_out.len());
158                debug_assert!(dlen <= buf.len());
159
160                core::ptr::copy_nonoverlapping(
161                    tmp_out.as_ptr(),
162                    buf.as_mut_ptr().add(dst_rem_pos),
163                    dst_rem_len,
164                );
165                Ok(buf.get_unchecked(..dlen))
166            }
167        } else {
168            Err(InvalidEncodingError)
169        }
170    }
171
172    #[cfg(feature = "alloc")]
173    fn decode_vec(input: &str) -> Result<Vec<u8>, Error> {
174        let mut output = vec![0u8; decoded_len(input.len())];
175        let len = Self::decode(input, &mut output)?.len();
176
177        if len <= output.len() {
178            output.truncate(len);
179            Ok(output)
180        } else {
181            Err(Error::InvalidLength)
182        }
183    }
184
185    fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, InvalidLengthError> {
186        let elen = match encoded_len_inner(src.len(), T::PADDED) {
187            Some(v) => v,
188            None => return Err(InvalidLengthError),
189        };
190
191        if elen > dst.len() {
192            return Err(InvalidLengthError);
193        }
194
195        let dst = &mut dst[..elen];
196
197        let mut src_chunks = src.chunks_exact(3);
198        let mut dst_chunks = dst.chunks_exact_mut(4);
199
200        for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
201            Self::encode_3bytes(s, d);
202        }
203
204        let src_rem = src_chunks.remainder();
205
206        if T::PADDED {
207            if let Some(dst_rem) = dst_chunks.next() {
208                let mut tmp = [0u8; 3];
209                tmp[..src_rem.len()].copy_from_slice(src_rem);
210                Self::encode_3bytes(&tmp, dst_rem);
211
212                let flag = src_rem.len() == 1;
213                let mask = (flag as u8).wrapping_sub(1);
214                dst_rem[2] = (dst_rem[2] & mask) | (PAD & !mask);
215                dst_rem[3] = PAD;
216            }
217        } else {
218            let dst_rem = dst_chunks.into_remainder();
219
220            let mut tmp_in = [0u8; 3];
221            let mut tmp_out = [0u8; 4];
222            tmp_in[..src_rem.len()].copy_from_slice(src_rem);
223            Self::encode_3bytes(&tmp_in, &mut tmp_out);
224            dst_rem.copy_from_slice(&tmp_out[..dst_rem.len()]);
225        }
226
227        debug_assert!(str::from_utf8(dst).is_ok());
228
229        // SAFETY: values written by `encode_3bytes` are valid one-byte UTF-8 chars
230        #[allow(unsafe_code)]
231        Ok(unsafe { str::from_utf8_unchecked(dst) })
232    }
233
234    #[cfg(feature = "alloc")]
235    fn encode_string(input: &[u8]) -> String {
236        let elen = encoded_len_inner(input.len(), T::PADDED).expect("input is too big");
237        let mut dst = vec![0u8; elen];
238        let res = Self::encode(input, &mut dst).expect("encoding error");
239
240        debug_assert_eq!(elen, res.len());
241        debug_assert!(str::from_utf8(&dst).is_ok());
242
243        // SAFETY: `dst` is fully written and contains only valid one-byte UTF-8 chars
244        #[allow(unsafe_code)]
245        unsafe {
246            String::from_utf8_unchecked(dst)
247        }
248    }
249
250    fn encoded_len(bytes: &[u8]) -> usize {
251        encoded_len_inner(bytes.len(), T::PADDED).unwrap_or(0)
252    }
253}
254
255/// Validate padding is of the expected length compute unpadded length.
256///
257/// Note that this method does not explicitly check that the padded data
258/// is valid in and of itself: that is performed by `validate_last_block` as a
259/// final step.
260///
261/// Returns length-related errors eagerly as a [`Result`], and data-dependent
262/// errors (i.e. malformed padding bytes) as `i16` to be combined with other
263/// encoding-related errors prior to branching.
264#[inline(always)]
265pub(crate) fn decode_padding(input: &[u8]) -> Result<(usize, i16), InvalidEncodingError> {
266    if input.len() % 4 != 0 {
267        return Err(InvalidEncodingError);
268    }
269
270    let unpadded_len = match *input {
271        [.., b0, b1] => is_pad_ct(b0)
272            .checked_add(is_pad_ct(b1))
273            .and_then(|len| len.try_into().ok())
274            .and_then(|len| input.len().checked_sub(len))
275            .ok_or(InvalidEncodingError)?,
276        _ => input.len(),
277    };
278
279    let padding_len = input
280        .len()
281        .checked_sub(unpadded_len)
282        .ok_or(InvalidEncodingError)?;
283
284    let err = match *input {
285        [.., b0] if padding_len == 1 => is_pad_ct(b0) ^ 1,
286        [.., b0, b1] if padding_len == 2 => (is_pad_ct(b0) & is_pad_ct(b1)) ^ 1,
287        _ => {
288            if padding_len == 0 {
289                0
290            } else {
291                return Err(InvalidEncodingError);
292            }
293        }
294    };
295
296    Ok((unpadded_len, err))
297}
298
299/// Validate that the last block of the decoded data round-trips back to the
300/// encoded data.
301fn validate_last_block<T: Alphabet>(encoded: &[u8], decoded: &[u8]) -> Result<(), Error> {
302    if encoded.is_empty() && decoded.is_empty() {
303        return Ok(());
304    }
305
306    // TODO(tarcieri): explicitly checked/wrapped arithmetic
307    #[allow(clippy::arithmetic_side_effects)]
308    fn last_block_start(bytes: &[u8], block_size: usize) -> usize {
309        (bytes.len().saturating_sub(1) / block_size) * block_size
310    }
311
312    let enc_block = encoded
313        .get(last_block_start(encoded, 4)..)
314        .ok_or(Error::InvalidEncoding)?;
315
316    let dec_block = decoded
317        .get(last_block_start(decoded, 3)..)
318        .ok_or(Error::InvalidEncoding)?;
319
320    // Round-trip encode the decoded block
321    let mut buf = [0u8; 4];
322    let block = T::encode(dec_block, &mut buf)?;
323
324    // Non-short-circuiting comparison of padding
325    // TODO(tarcieri): better constant-time mechanisms (e.g. `subtle`)?
326    if block
327        .as_bytes()
328        .iter()
329        .zip(enc_block.iter())
330        .fold(0, |acc, (a, b)| acc | (a ^ b))
331        == 0
332    {
333        Ok(())
334    } else {
335        Err(Error::InvalidEncoding)
336    }
337}
338
/// Number of bytes obtained by decoding `input_len` characters of *unpadded*
/// Base64.
///
/// This does not check that the Base64 is well-formed, so the result may be
/// meaningless for malformed input.
// TODO(tarcieri): explicitly checked/wrapped arithmetic
#[allow(clippy::arithmetic_side_effects)]
#[inline(always)]
pub(crate) fn decoded_len(input_len: usize) -> usize {
    // Work per 4-character group rather than computing `3 * input_len / 4`
    // directly, which could overflow for very large lengths.
    let groups = input_len / 4;
    let leftover = input_len % 4;
    groups * 3 + leftover * 3 / 4
}
353
354/// Branchless match that a given byte is the `PAD` character
355// TODO(tarcieri): explicitly checked/wrapped arithmetic
356#[allow(clippy::arithmetic_side_effects)]
357#[inline(always)]
358fn is_pad_ct(input: u8) -> i16 {
359    ((((PAD as i16 - 1) - input as i16) & (input as i16 - (PAD as i16 + 1))) >> 8) & 1
360}
361
/// Base64 length of `n` input bytes, with `=` padding included when `padded`
/// is set. Returns `None` if `n * 4` overflows `usize`.
// TODO(tarcieri): explicitly checked/wrapped arithmetic
#[allow(clippy::arithmetic_side_effects)]
#[inline(always)]
const fn encoded_len_inner(n: usize, padded: bool) -> Option<usize> {
    let quadrupled = match n.checked_mul(4) {
        Some(value) => value,
        None => return None,
    };
    let whole = quadrupled / 3;

    let len = if padded {
        // Round up to the next multiple of four characters.
        (whole + 3) & !3
    } else if quadrupled % 3 == 0 {
        whole
    } else {
        // One extra character carries the leftover bits.
        whole + 1
    };

    Some(len)
}