data_url/
forgiving_base64.rs

1//! <https://infra.spec.whatwg.org/#forgiving-base64-decode>
2
3use alloc::vec::Vec;
4use core::fmt;
5
6#[derive(Debug)]
7pub struct InvalidBase64(InvalidBase64Details);
8
9impl fmt::Display for InvalidBase64 {
10    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
11        match self.0 {
12            InvalidBase64Details::UnexpectedSymbol(code_point) => {
13                write!(f, "symbol with codepoint {} not expected", code_point)
14            }
15            InvalidBase64Details::AlphabetSymbolAfterPadding => {
16                write!(f, "alphabet symbol present after padding")
17            }
18            InvalidBase64Details::LoneAlphabetSymbol => write!(f, "lone alphabet symbol present"),
19            InvalidBase64Details::Padding => write!(f, "incorrect padding"),
20        }
21    }
22}
23
24#[cfg(feature = "std")]
25impl std::error::Error for InvalidBase64 {}
26
/// The specific reason an input was rejected, kept private behind `InvalidBase64`.
#[derive(Debug)]
enum InvalidBase64Details {
    /// A byte that is neither ASCII whitespace, `=`, nor part of the base64 alphabet.
    /// Carries the offending byte value.
    UnexpectedSymbol(u8),
    /// An alphabet symbol was found after one or more `=` padding symbols.
    AlphabetSymbolAfterPadding,
    /// A single alphabet symbol was left over after the last complete group of four.
    LoneAlphabetSymbol,
    /// The number of `=` symbols does not fit the number of leftover alphabet symbols.
    Padding,
}
34
35#[derive(Debug)]
36pub enum DecodeError<E> {
37    InvalidBase64(InvalidBase64),
38    WriteError(E),
39}
40
41impl<E: fmt::Display> fmt::Display for DecodeError<E> {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match self {
44            Self::InvalidBase64(inner) => write!(f, "base64 not valid: {}", inner),
45            Self::WriteError(err) => write!(f, "write error: {}", err),
46        }
47    }
48}
49
50#[cfg(feature = "std")]
51impl<E: std::error::Error> std::error::Error for DecodeError<E> {}
52
53impl<E> From<InvalidBase64Details> for DecodeError<E> {
54    fn from(e: InvalidBase64Details) -> Self {
55        DecodeError::InvalidBase64(InvalidBase64(e))
56    }
57}
58
59pub(crate) enum Impossible {}
60
61impl From<DecodeError<Impossible>> for InvalidBase64 {
62    fn from(e: DecodeError<Impossible>) -> Self {
63        match e {
64            DecodeError::InvalidBase64(e) => e,
65            DecodeError::WriteError(e) => match e {},
66        }
67    }
68}
69
70/// `input` is assumed to be in an ASCII-compatible encoding
71pub fn decode_to_vec(input: &[u8]) -> Result<Vec<u8>, InvalidBase64> {
72    let mut v = Vec::new();
73    {
74        let mut decoder = Decoder::new(|bytes| {
75            v.extend_from_slice(bytes);
76            Ok(())
77        });
78        decoder.feed(input)?;
79        decoder.finish()?;
80    }
81    Ok(v)
82}
83
/// Streaming decoder for <https://infra.spec.whatwg.org/#forgiving-base64-decode>.
///
/// Input is supplied in pieces through `feed`, and `finish` signals its end. Decoded
/// bytes are passed to the `write_bytes` callback as soon as they are available.
pub struct Decoder<F, E>
where
    F: FnMut(&[u8]) -> Result<(), E>,
{
    /// Callback that receives each chunk of decoded bytes.
    write_bytes: F,
    /// Number of `=` symbols seen so far (saturating at `u8::MAX`).
    padding_symbols: u8,
    /// Bits accumulated for the current, incomplete group of four symbols
    /// (0, 6, 12 or 18).
    buffer_bit_length: u8,
    /// Shift register of 6-bit symbol values. Only the low `buffer_bit_length` bits
    /// belong to the current group; older bits are stale and never read.
    bit_buffer: u32,
}
94
95impl<F, E> Decoder<F, E>
96where
97    F: FnMut(&[u8]) -> Result<(), E>,
98{
99    pub fn new(write_bytes: F) -> Self {
100        Self {
101            write_bytes,
102            bit_buffer: 0,
103            buffer_bit_length: 0,
104            padding_symbols: 0,
105        }
106    }
107
108    /// Feed to the decoder partial input in an ASCII-compatible encoding
109    pub fn feed(&mut self, input: &[u8]) -> Result<(), DecodeError<E>> {
110        for &byte in input.iter() {
111            let value = BASE64_DECODE_TABLE[byte as usize];
112            if value < 0 {
113                // A character that’s not part of the alphabet
114
115                // Remove ASCII whitespace
116                if matches!(byte, b' ' | b'\t' | b'\n' | b'\r' | b'\x0C') {
117                    continue;
118                }
119
120                if byte == b'=' {
121                    self.padding_symbols = self.padding_symbols.saturating_add(1);
122                    continue;
123                }
124
125                return Err(InvalidBase64Details::UnexpectedSymbol(byte).into());
126            }
127            if self.padding_symbols > 0 {
128                return Err(InvalidBase64Details::AlphabetSymbolAfterPadding.into());
129            }
130            self.bit_buffer <<= 6;
131            self.bit_buffer |= value as u32;
132            // 18 before incrementing means we’ve just reached 24
133            if self.buffer_bit_length < 18 {
134                self.buffer_bit_length += 6;
135            } else {
136                // We’ve accumulated four times 6 bits, which equals three times 8 bits.
137                let byte_buffer = [
138                    (self.bit_buffer >> 16) as u8,
139                    (self.bit_buffer >> 8) as u8,
140                    self.bit_buffer as u8,
141                ];
142                (self.write_bytes)(&byte_buffer).map_err(DecodeError::WriteError)?;
143                self.buffer_bit_length = 0;
144                // No need to reset bit_buffer,
145                // since next time we’re only gonna read relevant bits.
146            }
147        }
148        Ok(())
149    }
150
151    /// Call this to signal the end of the input
152    pub fn finish(mut self) -> Result<(), DecodeError<E>> {
153        match (self.buffer_bit_length, self.padding_symbols) {
154            (0, 0) => {
155                // A multiple of four of alphabet symbols, and nothing else.
156            }
157            (12, 2) | (12, 0) => {
158                // A multiple of four of alphabet symbols, followed by two more symbols,
159                // optionally followed by two padding characters (which make a total multiple of four).
160                let byte_buffer = [(self.bit_buffer >> 4) as u8];
161                (self.write_bytes)(&byte_buffer).map_err(DecodeError::WriteError)?;
162            }
163            (18, 1) | (18, 0) => {
164                // A multiple of four of alphabet symbols, followed by three more symbols,
165                // optionally followed by one padding character (which make a total multiple of four).
166                let byte_buffer = [(self.bit_buffer >> 10) as u8, (self.bit_buffer >> 2) as u8];
167                (self.write_bytes)(&byte_buffer).map_err(DecodeError::WriteError)?;
168            }
169            (6, _) => return Err(InvalidBase64Details::LoneAlphabetSymbol.into()),
170            _ => return Err(InvalidBase64Details::Padding.into()),
171        }
172        Ok(())
173    }
174}
175
/// Reverse lookup for "Table 1: The Base 64 Alphabet" at
/// <https://tools.ietf.org/html/rfc4648#section-4>, built at compile time.
///
/// The array is indexed by a symbol's byte value. Each entry is the symbol's position
/// in the base64 alphabet (0..=63), or -1 if the byte is not in the alphabet.
/// The position contributes 6 bits to the decoded bytes.
const BASE64_DECODE_TABLE: [i8; 256] = {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut table = [-1_i8; 256];
    let mut position = 0;
    while position < ALPHABET.len() {
        table[ALPHABET[position] as usize] = position as i8;
        position += 1;
    }
    table
};