use std::fmt::{Arguments, Write as FmtWrite};
use std::io::{self, ErrorKind, Write};
use std::str::from_utf8;
/// Lookup table used by `escape_href`: entry `b` is 1 when ASCII byte `b` may be
/// emitted verbatim inside an href, and 0 when it must be escaped. Bytes >= 0x80
/// are always escaped and are not covered by this table.
#[rustfmt::skip]
static HREF_SAFE: [u8; 128] = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1,
    0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
];
/// Uppercase hexadecimal digits used to build `%XX` percent-escapes.
static HEX_CHARS: &[u8] = b"0123456789ABCDEF";
/// HTML entity written in place of `&` inside an href.
///
/// Must be an actual entity: writing a bare `&` back out would leave the
/// ampersand unescaped in the generated attribute.
static AMP_ESCAPE: &str = "&amp;";
/// HTML entity written in place of a single quote (`'`) inside an href.
///
/// The name is historical; this escapes the apostrophe, not a slash. It must be
/// an entity so that the href cannot terminate a single-quoted attribute.
static SLASH_ESCAPE: &str = "&#x27;";
/// Newtype wrapper around a writer `W`.
///
/// When `W: std::io::Write`, this type implements `StrWrite` by forwarding
/// string output to the wrapped writer.
pub struct WriteWrapper<W>(pub W);
/// Minimal string sink used by the escaping routines in this module.
///
/// Unlike `std::fmt::Write`, both methods report failures as `io::Error`, so
/// I/O-backed and in-memory sinks share one error type.
pub trait StrWrite {
    /// Writes `text` to the sink.
    fn write_str(&mut self, text: &str) -> io::Result<()>;

    /// Writes pre-formatted arguments (as produced by `format_args!`) to the sink.
    fn write_fmt(&mut self, fmt_args: Arguments) -> io::Result<()>;
}
impl<W: Write> StrWrite for WriteWrapper<W> {
    /// Forwards all bytes of `s` to the wrapped writer using `write_all`.
    #[inline]
    fn write_str(&mut self, s: &str) -> io::Result<()> {
        let inner = &mut self.0;
        Write::write_all(inner, s.as_bytes())
    }

    /// Forwards the formatted arguments to the wrapped writer's `io::Write::write_fmt`.
    #[inline]
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
        let inner = &mut self.0;
        Write::write_fmt(inner, args)
    }
}
// The former `impl<'w>` declared a lifetime that nothing used; it was dropped.
impl StrWrite for String {
    /// Appends `s` to the string; this never fails.
    #[inline]
    fn write_str(&mut self, s: &str) -> io::Result<()> {
        self.push_str(s);
        Ok(())
    }

    /// Formats `args` into the string via `std::fmt::Write`.
    ///
    /// Appending to a `String` cannot fail, so an error here can only come from a
    /// formatting trait implementation. It is reported as `ErrorKind::Other`.
    #[inline]
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
        FmtWrite::write_fmt(self, args).map_err(|_| ErrorKind::Other.into())
    }
}
/// Lets a mutable borrow of any `StrWrite` be passed where an owned sink is
/// expected, by delegating every call to the referenced writer.
impl<W: StrWrite> StrWrite for &'_ mut W {
    #[inline]
    fn write_str(&mut self, s: &str) -> io::Result<()> {
        W::write_str(&mut **self, s)
    }

    #[inline]
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
        W::write_fmt(&mut **self, args)
    }
}
/// Writes `s` to `w` in a form suitable for an href attribute.
///
/// Bytes that `HREF_SAFE` marks as safe are copied verbatim. `&` is replaced by
/// `AMP_ESCAPE` and `'` by `SLASH_ESCAPE`. Every other byte, including every
/// byte of a non-ASCII character, is percent-encoded as `%XX` with uppercase hex.
///
/// # Errors
///
/// Propagates the first error returned by `w`.
pub fn escape_href<W>(mut w: W, s: &str) -> io::Result<()>
where
    W: StrWrite,
{
    // Start of the run of safe bytes that has not been written yet.
    let mut flushed = 0;
    for (idx, &byte) in s.as_bytes().iter().enumerate() {
        let is_safe = byte < 0x80 && HREF_SAFE[usize::from(byte)] != 0;
        if is_safe {
            continue;
        }
        // Only flush when the pending run is non-empty. Within a multi-byte
        // character, `flushed == idx` sits on a continuation byte, so slicing
        // there would not be on a char boundary.
        if flushed < idx {
            w.write_str(&s[flushed..idx])?;
        }
        match byte {
            b'&' => w.write_str(AMP_ESCAPE)?,
            b'\'' => w.write_str(SLASH_ESCAPE)?,
            _ => {
                let encoded = [
                    b'%',
                    HEX_CHARS[usize::from(byte >> 4)],
                    HEX_CHARS[usize::from(byte & 0x0F)],
                ];
                w.write_str(from_utf8(&encoded).unwrap())?;
            }
        }
        flushed = idx + 1;
    }
    w.write_str(&s[flushed..])
}
/// Builds the byte-to-escape-kind table used by the HTML escapers.
///
/// Entry `b` is 0 when byte `b` is written verbatim. Otherwise it is an index
/// into `HTML_ESCAPES`: 1 = `"`, 2 = `&`, 3 = `<`, 4 = `>`.
const fn create_html_escape_table() -> [u8; 256] {
    let mut table = [0; 256];
    table[b'"' as usize] = 1;
    table[b'&' as usize] = 2;
    table[b'<' as usize] = 3;
    table[b'>' as usize] = 4;
    table
}

/// Escape kind for every byte value; computed at compile time.
static HTML_ESCAPE_TABLE: [u8; 256] = create_html_escape_table();

/// Replacement text for each escape kind in `HTML_ESCAPE_TABLE`; index 0 is unused.
///
/// These must be real character references. Writing the raw characters back
/// out would leave the output unescaped, which is an HTML-injection hazard.
static HTML_ESCAPES: [&str; 5] = ["", "&quot;", "&amp;", "&lt;", "&gt;"];
/// Writes `s` to `w`, replacing each `"`, `&`, `<` and `>` byte with the
/// corresponding entry of `HTML_ESCAPES`; all other text is copied unchanged.
///
/// On x86_64 builds with the `simd` feature this dispatches to `simd::escape_html`,
/// which itself falls back to the scalar path at runtime when SSSE3 is unavailable
/// or the input is shorter than one vector. Other builds always use
/// `escape_html_scalar`.
///
/// # Errors
///
/// Propagates the first error returned by `w`.
pub fn escape_html<W: StrWrite>(w: W, s: &str) -> io::Result<()> {
// Exactly one of the two cfg-gated blocks below is compiled in; it is the
// function's tail expression.
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
{
simd::escape_html(w, s)
}
#[cfg(not(all(target_arch = "x86_64", feature = "simd")))]
{
escape_html_scalar(w, s)
}
}
/// Portable HTML escaper that works one byte at a time.
///
/// Copies `s` to `w`, replacing every byte flagged in `HTML_ESCAPE_TABLE` with
/// the matching string from `HTML_ESCAPES`. Each replacement is preceded by the
/// (possibly empty) run of unescaped text before it. The remaining tail is
/// written last.
///
/// # Errors
///
/// Propagates the first error returned by `w`.
fn escape_html_scalar<W: StrWrite>(mut w: W, s: &str) -> io::Result<()> {
    let mut pending_start = 0;
    for (pos, &byte) in s.as_bytes().iter().enumerate() {
        let kind = HTML_ESCAPE_TABLE[usize::from(byte)];
        if kind == 0 {
            continue;
        }
        // Flagged bytes are ASCII, so `pos` is always a char boundary.
        w.write_str(&s[pending_start..pos])?;
        w.write_str(HTML_ESCAPES[usize::from(kind)])?;
        pending_start = pos + 1;
    }
    w.write_str(&s[pending_start..])
}
/// SSSE3-accelerated HTML escaping.
///
/// Compiled only on x86_64 with the `simd` feature. The entry point falls back to
/// `super::escape_html_scalar` when SSSE3 is not detected at runtime or the
/// input is shorter than one vector.
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
mod simd {
    use super::StrWrite;
    use std::arch::x86_64::*;
    use std::io;
    use std::mem::size_of;

    /// Number of bytes examined per 128-bit SSE register.
    const VECTOR_SIZE: usize = size_of::<__m128i>();

    /// Nibble-indexed `pshufb` table (see `create_lookup`). It is built once at
    /// compile time instead of on every `compute_mask` call.
    static LOOKUP: [u8; 16] = create_lookup();

    /// Escapes `s` into `w` like `super::escape_html_scalar`, scanning 16 bytes at
    /// a time for special bytes.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `w`.
    pub(crate) fn escape_html<W: StrWrite>(mut w: W, s: &str) -> io::Result<()> {
        if is_x86_feature_detected!("ssse3") && s.len() >= VECTOR_SIZE {
            let bytes = s.as_bytes();
            let mut mark = 0;
            // SAFETY:
            // - SSSE3 support was verified above, and `bytes.len() >= VECTOR_SIZE`
            //   with offset 0, which is what `foreach_special_simd` requires.
            // - Every index `i` passed to the callback is `< bytes.len()`.
            // - The indices arrive in strictly increasing order, so
            //   `mark <= i <= s.len()`.
            // - `bytes[i]` is one of the ASCII bytes `"`, `&`, `<`, `>`, and `mark` is
            //   either 0 or one past such a byte, so both slice ends are char
            //   boundaries.
            unsafe {
                foreach_special_simd(bytes, 0, |i| {
                    let escape_ix = *bytes.get_unchecked(i) as usize;
                    let replacement =
                        super::HTML_ESCAPES[super::HTML_ESCAPE_TABLE[escape_ix] as usize];
                    w.write_str(s.get_unchecked(mark..i))?;
                    mark = i + 1;
                    w.write_str(replacement)
                })?;
                w.write_str(s.get_unchecked(mark..))
            }
        } else {
            super::escape_html_scalar(w, s)
        }
    }

    /// Builds the `pshufb` lookup table.
    ///
    /// The four special bytes `<` (0x3C), `>` (0x3E), `&` (0x26) and `"` (0x22)
    /// have distinct low nibbles, so `table[v & 0x0f] == v` holds exactly for
    /// those bytes. Entry 0 is set to 0x7F (nibble F) so that the byte 0x00,
    /// whose nibble-0 entry would otherwise be 0, does not match itself.
    const fn create_lookup() -> [u8; 16] {
        let mut table = [0; 16];
        table[(b'<' & 0x0f) as usize] = b'<';
        table[(b'>' & 0x0f) as usize] = b'>';
        table[(b'&' & 0x0f) as usize] = b'&';
        table[(b'"' & 0x0f) as usize] = b'"';
        table[0] = 0b0111_1111;
        table
    }

    /// Returns a bitmask with bit `j` set iff `bytes[offset + j]` is one of
    /// `<`, `>`, `&`, `"`, for `j` in `0..VECTOR_SIZE`.
    ///
    /// # Safety
    ///
    /// The CPU must support SSSE3, and `offset + VECTOR_SIZE <= bytes.len()`.
    #[target_feature(enable = "ssse3")]
    unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 {
        debug_assert!(bytes.len() >= offset + VECTOR_SIZE);
        let lookup = _mm_loadu_si128(LOOKUP.as_ptr() as *const __m128i);
        let raw_ptr = bytes.as_ptr().add(offset) as *const __m128i;
        let vector = _mm_loadu_si128(raw_ptr);
        // `pshufb` yields LOOKUP[v & 0x0f] for each byte v, or 0 when v's high bit
        // is set. Bytes >= 0x80 can therefore never equal their shuffle result.
        let expected = _mm_shuffle_epi8(lookup, vector);
        let matches = _mm_cmpeq_epi8(expected, vector);
        _mm_movemask_epi8(matches)
    }

    /// Calls `callback` with the index of every special byte in `bytes[offset..]`,
    /// in increasing order. It stops at, and returns, the first error.
    ///
    /// Full 16-byte chunks are scanned from `offset`. The last chunk is read
    /// aligned to the end of the slice, and the bytes it re-reads are shifted out
    /// of the mask so that no index is reported twice.
    ///
    /// # Safety
    ///
    /// The CPU must support SSSE3, `bytes.len() >= VECTOR_SIZE`, and
    /// `offset <= bytes.len()`.
    #[target_feature(enable = "ssse3")]
    unsafe fn foreach_special_simd<F>(
        bytes: &[u8],
        mut offset: usize,
        mut callback: F,
    ) -> io::Result<()>
    where
        F: FnMut(usize) -> io::Result<()>,
    {
        debug_assert!(bytes.len() >= VECTOR_SIZE);
        let upperbound = bytes.len() - VECTOR_SIZE;
        while offset < upperbound {
            let mut mask = compute_mask(bytes, offset);
            while mask != 0 {
                let ix = mask.trailing_zeros();
                callback(offset + ix as usize)?;
                // Clear the lowest set bit.
                mask ^= mask & -mask;
            }
            offset += VECTOR_SIZE;
        }
        // Final chunk: aligned to the end of the slice, dropping the leading
        // `offset - upperbound` bytes that were already scanned.
        let mut mask = compute_mask(bytes, upperbound);
        mask >>= offset - upperbound;
        while mask != 0 {
            let ix = mask.trailing_zeros();
            callback(offset + ix as usize)?;
            mask ^= mask & -mask;
        }
        Ok(())
    }

    #[cfg(test)]
    mod html_scan_tests {
        #[test]
        fn multichunk() {
            let mut vec = Vec::new();
            unsafe {
                super::foreach_special_simd("&aXaaaa.a'aa9a<>aab&".as_bytes(), 0, |ix| {
                    Ok(vec.push(ix))
                })
                .unwrap();
            }
            assert_eq!(vec, vec![0, 14, 15, 19]);
        }

        #[test]
        fn only_right_bytes_matched() {
            // Check every byte value exhaustively, including 0xFF (which `0..255`
            // used to skip).
            for b in 0..=255u8 {
                let right_byte = b == b'&' || b == b'<' || b == b'>' || b == b'"';
                let vek = vec![b; super::VECTOR_SIZE];
                let mut match_count = 0;
                unsafe {
                    super::foreach_special_simd(&vek, 0, |_| {
                        match_count += 1;
                        Ok(())
                    })
                    .unwrap();
                }
                assert!((match_count > 0) == (match_count == super::VECTOR_SIZE));
                assert_eq!(
                    (match_count == super::VECTOR_SIZE),
                    right_byte,
                    "match_count: {}, byte: {:?}",
                    match_count,
                    b as char
                );
            }
        }
    }
}