1use crate::{
4 alphabet::Alphabet,
5 errors::{Error, InvalidEncodingError, InvalidLengthError},
6};
7use core::str;
8
9#[cfg(feature = "alloc")]
10use alloc::{string::String, vec::Vec};
11
12#[cfg(doc)]
13use crate::{Base64, Base64Bcrypt, Base64Crypt, Base64Unpadded, Base64Url, Base64UrlUnpadded};
14
/// Padding character (`=`) that padded alphabets append to a trailing partial block.
const PAD: u8 = b'=';
17
/// Base64 encoding and decoding operations.
///
/// Every [`Alphabet`] receives this trait through the blanket
/// `impl<T: Alphabet> Encoding for T` in this module, so these methods are
/// invoked on the alphabet type itself.
pub trait Encoding: Alphabet {
    /// Decode the Base64 string `src` into `dst`, returning the decoded bytes
    /// as a prefix of `dst`.
    ///
    /// # Errors
    /// Returns [`Error::InvalidLength`] if `dst` is too short for the decoded
    /// output, or an error if `src` is not a valid encoding for this alphabet.
    fn decode(src: impl AsRef<[u8]>, dst: &mut [u8]) -> Result<&[u8], Error>;

    /// Decode the Base64 string held in `buf` in place, returning the decoded
    /// bytes as a prefix of `buf`.
    ///
    /// Decoded bytes are written over the input while decoding proceeds, so
    /// `buf` may be partially overwritten even when an error is returned.
    ///
    /// # Errors
    /// Returns [`InvalidEncodingError`] if `buf` is not a valid encoding.
    fn decode_in_place(buf: &mut [u8]) -> Result<&[u8], InvalidEncodingError>;

    /// Decode the Base64 string `input` into a newly allocated `Vec<u8>`.
    ///
    /// # Errors
    /// Same conditions as [`Encoding::decode`].
    #[cfg(feature = "alloc")]
    fn decode_vec(input: &str) -> Result<Vec<u8>, Error>;

    /// Encode `src` as Base64 into `dst`, returning the encoded text as a
    /// `&str` borrowed from the start of `dst`.
    ///
    /// # Errors
    /// Returns [`InvalidLengthError`] if `dst` is too short, or if the encoded
    /// length would overflow `usize`.
    fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, InvalidLengthError>;

    /// Encode `input` as Base64 into a newly allocated `String`.
    ///
    /// # Panics
    /// Panics if the encoded length of `input` would overflow `usize`.
    #[cfg(feature = "alloc")]
    fn encode_string(input: &[u8]) -> String;

    /// Length in bytes of the Base64 encoding of `bytes`, including padding for
    /// padded alphabets. Returns 0 if that length would overflow `usize`.
    fn encoded_len(bytes: &[u8]) -> usize;
}
63
64impl<T: Alphabet> Encoding for T {
65 fn decode(src: impl AsRef<[u8]>, dst: &mut [u8]) -> Result<&[u8], Error> {
66 let (src_unpadded, mut err) = if T::PADDED {
67 let (unpadded_len, e) = decode_padding(src.as_ref())?;
68 (&src.as_ref()[..unpadded_len], e)
69 } else {
70 (src.as_ref(), 0)
71 };
72
73 let dlen = decoded_len(src_unpadded.len());
74
75 if dlen > dst.len() {
76 return Err(Error::InvalidLength);
77 }
78
79 let dst = &mut dst[..dlen];
80
81 let mut src_chunks = src_unpadded.chunks_exact(4);
82 let mut dst_chunks = dst.chunks_exact_mut(3);
83 for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
84 err |= Self::decode_3bytes(s, d);
85 }
86 let src_rem = src_chunks.remainder();
87 let dst_rem = dst_chunks.into_remainder();
88
89 err |= !(src_rem.is_empty() || src_rem.len() >= 2) as i16;
90 let mut tmp_out = [0u8; 3];
91 let mut tmp_in = [b'A'; 4];
92 tmp_in[..src_rem.len()].copy_from_slice(src_rem);
93 err |= Self::decode_3bytes(&tmp_in, &mut tmp_out);
94 dst_rem.copy_from_slice(&tmp_out[..dst_rem.len()]);
95
96 if err == 0 {
97 validate_last_block::<T>(src.as_ref(), dst)?;
98 Ok(dst)
99 } else {
100 Err(Error::InvalidEncoding)
101 }
102 }
103
104 #[allow(clippy::arithmetic_side_effects)]
106 fn decode_in_place(mut buf: &mut [u8]) -> Result<&[u8], InvalidEncodingError> {
107 let mut err = if T::PADDED {
110 let (unpadded_len, e) = decode_padding(buf)?;
111 buf = &mut buf[..unpadded_len];
112 e
113 } else {
114 0
115 };
116
117 let dlen = decoded_len(buf.len());
118 let full_chunks = buf.len() / 4;
119
120 for chunk in 0..full_chunks {
121 #[allow(unsafe_code)]
125 unsafe {
126 debug_assert!(3 * chunk + 3 <= buf.len());
127 debug_assert!(4 * chunk + 4 <= buf.len());
128
129 let p3 = buf.as_mut_ptr().add(3 * chunk) as *mut [u8; 3];
130 let p4 = buf.as_ptr().add(4 * chunk) as *const [u8; 4];
131
132 let mut tmp_out = [0u8; 3];
133 err |= Self::decode_3bytes(&*p4, &mut tmp_out);
134 *p3 = tmp_out;
135 }
136 }
137
138 let src_rem_pos = 4 * full_chunks;
139 let src_rem_len = buf.len() - src_rem_pos;
140 let dst_rem_pos = 3 * full_chunks;
141 let dst_rem_len = dlen - dst_rem_pos;
142
143 err |= !(src_rem_len == 0 || src_rem_len >= 2) as i16;
144 let mut tmp_in = [b'A'; 4];
145 tmp_in[..src_rem_len].copy_from_slice(&buf[src_rem_pos..]);
146 let mut tmp_out = [0u8; 3];
147
148 err |= Self::decode_3bytes(&tmp_in, &mut tmp_out);
149
150 if err == 0 {
151 #[allow(unsafe_code)]
155 unsafe {
156 debug_assert!(dst_rem_pos + dst_rem_len <= buf.len());
157 debug_assert!(dst_rem_len <= tmp_out.len());
158 debug_assert!(dlen <= buf.len());
159
160 core::ptr::copy_nonoverlapping(
161 tmp_out.as_ptr(),
162 buf.as_mut_ptr().add(dst_rem_pos),
163 dst_rem_len,
164 );
165 Ok(buf.get_unchecked(..dlen))
166 }
167 } else {
168 Err(InvalidEncodingError)
169 }
170 }
171
172 #[cfg(feature = "alloc")]
173 fn decode_vec(input: &str) -> Result<Vec<u8>, Error> {
174 let mut output = vec![0u8; decoded_len(input.len())];
175 let len = Self::decode(input, &mut output)?.len();
176
177 if len <= output.len() {
178 output.truncate(len);
179 Ok(output)
180 } else {
181 Err(Error::InvalidLength)
182 }
183 }
184
185 fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, InvalidLengthError> {
186 let elen = match encoded_len_inner(src.len(), T::PADDED) {
187 Some(v) => v,
188 None => return Err(InvalidLengthError),
189 };
190
191 if elen > dst.len() {
192 return Err(InvalidLengthError);
193 }
194
195 let dst = &mut dst[..elen];
196
197 let mut src_chunks = src.chunks_exact(3);
198 let mut dst_chunks = dst.chunks_exact_mut(4);
199
200 for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
201 Self::encode_3bytes(s, d);
202 }
203
204 let src_rem = src_chunks.remainder();
205
206 if T::PADDED {
207 if let Some(dst_rem) = dst_chunks.next() {
208 let mut tmp = [0u8; 3];
209 tmp[..src_rem.len()].copy_from_slice(src_rem);
210 Self::encode_3bytes(&tmp, dst_rem);
211
212 let flag = src_rem.len() == 1;
213 let mask = (flag as u8).wrapping_sub(1);
214 dst_rem[2] = (dst_rem[2] & mask) | (PAD & !mask);
215 dst_rem[3] = PAD;
216 }
217 } else {
218 let dst_rem = dst_chunks.into_remainder();
219
220 let mut tmp_in = [0u8; 3];
221 let mut tmp_out = [0u8; 4];
222 tmp_in[..src_rem.len()].copy_from_slice(src_rem);
223 Self::encode_3bytes(&tmp_in, &mut tmp_out);
224 dst_rem.copy_from_slice(&tmp_out[..dst_rem.len()]);
225 }
226
227 debug_assert!(str::from_utf8(dst).is_ok());
228
229 #[allow(unsafe_code)]
231 Ok(unsafe { str::from_utf8_unchecked(dst) })
232 }
233
234 #[cfg(feature = "alloc")]
235 fn encode_string(input: &[u8]) -> String {
236 let elen = encoded_len_inner(input.len(), T::PADDED).expect("input is too big");
237 let mut dst = vec![0u8; elen];
238 let res = Self::encode(input, &mut dst).expect("encoding error");
239
240 debug_assert_eq!(elen, res.len());
241 debug_assert!(str::from_utf8(&dst).is_ok());
242
243 #[allow(unsafe_code)]
245 unsafe {
246 String::from_utf8_unchecked(dst)
247 }
248 }
249
250 fn encoded_len(bytes: &[u8]) -> usize {
251 encoded_len_inner(bytes.len(), T::PADDED).unwrap_or(0)
252 }
253}
254
255#[inline(always)]
265pub(crate) fn decode_padding(input: &[u8]) -> Result<(usize, i16), InvalidEncodingError> {
266 if input.len() % 4 != 0 {
267 return Err(InvalidEncodingError);
268 }
269
270 let unpadded_len = match *input {
271 [.., b0, b1] => is_pad_ct(b0)
272 .checked_add(is_pad_ct(b1))
273 .and_then(|len| len.try_into().ok())
274 .and_then(|len| input.len().checked_sub(len))
275 .ok_or(InvalidEncodingError)?,
276 _ => input.len(),
277 };
278
279 let padding_len = input
280 .len()
281 .checked_sub(unpadded_len)
282 .ok_or(InvalidEncodingError)?;
283
284 let err = match *input {
285 [.., b0] if padding_len == 1 => is_pad_ct(b0) ^ 1,
286 [.., b0, b1] if padding_len == 2 => (is_pad_ct(b0) & is_pad_ct(b1)) ^ 1,
287 _ => {
288 if padding_len == 0 {
289 0
290 } else {
291 return Err(InvalidEncodingError);
292 }
293 }
294 };
295
296 Ok((unpadded_len, err))
297}
298
299fn validate_last_block<T: Alphabet>(encoded: &[u8], decoded: &[u8]) -> Result<(), Error> {
302 if encoded.is_empty() && decoded.is_empty() {
303 return Ok(());
304 }
305
306 #[allow(clippy::arithmetic_side_effects)]
308 fn last_block_start(bytes: &[u8], block_size: usize) -> usize {
309 (bytes.len().saturating_sub(1) / block_size) * block_size
310 }
311
312 let enc_block = encoded
313 .get(last_block_start(encoded, 4)..)
314 .ok_or(Error::InvalidEncoding)?;
315
316 let dec_block = decoded
317 .get(last_block_start(decoded, 3)..)
318 .ok_or(Error::InvalidEncoding)?;
319
320 let mut buf = [0u8; 4];
322 let block = T::encode(dec_block, &mut buf)?;
323
324 if block
327 .as_bytes()
328 .iter()
329 .zip(enc_block.iter())
330 .fold(0, |acc, (a, b)| acc | (a ^ b))
331 == 0
332 {
333 Ok(())
334 } else {
335 Err(Error::InvalidEncoding)
336 }
337}
338
#[allow(clippy::arithmetic_side_effects)]
#[inline(always)]
/// Number of bytes obtained by decoding `input_len` unpadded Base64 characters.
///
/// Each full group of 4 characters yields 3 bytes. A trailing group of `r`
/// characters yields `3 * r / 4` bytes (0, 0, 1, 2 for `r` = 0..=3). For padded
/// input (a multiple of 4 characters) this is an upper bound.
pub(crate) fn decoded_len(input_len: usize) -> usize {
    let full_groups = input_len / 4;
    let tail_chars = input_len % 4;
    full_groups * 3 + tail_chars * 3 / 4
}
353
354#[allow(clippy::arithmetic_side_effects)]
357#[inline(always)]
358fn is_pad_ct(input: u8) -> i16 {
359 ((((PAD as i16 - 1) - input as i16) & (input as i16 - (PAD as i16 + 1))) >> 8) & 1
360}
361
#[allow(clippy::arithmetic_side_effects)]
#[inline(always)]
/// Length of the Base64 encoding of `n` bytes, or `None` if `4 * n` overflows.
///
/// Padded output is rounded up to a multiple of 4. Unpadded output uses just
/// enough characters to carry every input bit: `ceil(4 * n / 3)`.
const fn encoded_len_inner(n: usize, padded: bool) -> Option<usize> {
    let bits_x4 = match n.checked_mul(4) {
        Some(q) => q,
        None => return None,
    };

    let whole = bits_x4 / 3;
    let len = if padded {
        (whole + 3) & !3
    } else if bits_x4 % 3 != 0 {
        whole + 1
    } else {
        whole
    };

    Some(len)
}
376}