use std::borrow::Cow;
use std::fmt;
#[derive(Debug, PartialEq, Eq)]
enum EscapeErrorKind {
ExpectedSurrogatePair(String),
InvalidEscapeCharacter(char),
InvalidSurrogatePair(u16, u16),
InvalidUnicodeEscape(String),
InvalidUtf8,
UnexpectedEndOfString,
}
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct EscapeError {
kind: EscapeErrorKind,
}
impl std::error::Error for EscapeError {}
impl fmt::Display for EscapeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use EscapeErrorKind::*;
match &self.kind {
ExpectedSurrogatePair(low) => {
write!(
f,
"expected a UTF-16 surrogate pair, but got {} as the low word",
low
)
}
InvalidEscapeCharacter(chr) => write!(f, "invalid JSON escape: \\{}", chr),
InvalidSurrogatePair(high, low) => {
write!(f, "invalid surrogate pair: \\u{:04X}\\u{:04X}", high, low)
}
InvalidUnicodeEscape(escape) => write!(f, "invalid JSON Unicode escape: \\u{}", escape),
InvalidUtf8 => write!(f, "invalid UTF-8 codepoint in JSON string"),
UnexpectedEndOfString => write!(f, "unexpected end of string"),
}
}
}
impl From<EscapeErrorKind> for EscapeError {
fn from(kind: EscapeErrorKind) -> Self {
Self { kind }
}
}
pub(crate) fn escape_string(value: &str) -> Cow<'_, str> {
let bytes = value.as_bytes();
for (index, byte) in bytes.iter().enumerate() {
match byte {
0..=0x1F | b'"' | b'\\' => {
return Cow::Owned(escape_string_inner(&bytes[0..index], &bytes[index..]))
}
_ => {}
}
}
Cow::Borrowed(value)
}
fn escape_string_inner(start: &[u8], rest: &[u8]) -> String {
let mut escaped = Vec::with_capacity(start.len() + rest.len() + 1);
escaped.extend(start);
for byte in rest {
match byte {
b'"' => escaped.extend(b"\\\""),
b'\\' => escaped.extend(b"\\\\"),
0x08 => escaped.extend(b"\\b"),
0x0C => escaped.extend(b"\\f"),
b'\n' => escaped.extend(b"\\n"),
b'\r' => escaped.extend(b"\\r"),
b'\t' => escaped.extend(b"\\t"),
0..=0x1F => escaped.extend(format!("\\u{:04x}", byte).bytes()),
_ => escaped.push(*byte),
}
}
debug_assert!(std::str::from_utf8(&escaped).is_ok());
unsafe { String::from_utf8_unchecked(escaped) }
}
pub(crate) fn unescape_string(value: &str) -> Result<Cow<'_, str>, EscapeError> {
let bytes = value.as_bytes();
for (index, byte) in bytes.iter().enumerate() {
if *byte == b'\\' {
return unescape_string_inner(&bytes[0..index], &bytes[index..]).map(Cow::Owned);
}
}
Ok(Cow::Borrowed(value))
}
fn unescape_string_inner(start: &[u8], rest: &[u8]) -> Result<String, EscapeError> {
let mut unescaped = Vec::with_capacity(start.len() + rest.len());
unescaped.extend(start);
let mut index = 0;
while index < rest.len() {
match rest[index] {
b'\\' => {
index += 1;
if index == rest.len() {
return Err(EscapeErrorKind::UnexpectedEndOfString.into());
}
match rest[index] {
b'u' => {
index -= 1;
index += read_unicode_escapes(&rest[index..], &mut unescaped)?;
}
byte => {
match byte {
b'\\' => unescaped.push(b'\\'),
b'/' => unescaped.push(b'/'),
b'"' => unescaped.push(b'"'),
b'b' => unescaped.push(0x08),
b'f' => unescaped.push(0x0C),
b'n' => unescaped.push(b'\n'),
b'r' => unescaped.push(b'\r'),
b't' => unescaped.push(b'\t'),
_ => {
return Err(
EscapeErrorKind::InvalidEscapeCharacter(byte.into()).into()
)
}
}
index += 1;
}
}
}
byte => {
unescaped.push(byte);
index += 1
}
}
}
String::from_utf8(unescaped).map_err(|_| EscapeErrorKind::InvalidUtf8.into())
}
fn is_utf16_low_surrogate(codepoint: u16) -> bool {
codepoint & 0xFC00 == 0xDC00
}
fn is_utf16_high_surrogate(codepoint: u16) -> bool {
codepoint & 0xFC00 == 0xD800
}
fn read_codepoint(rest: &[u8]) -> Result<u16, EscapeError> {
if rest.len() < 6 {
return Err(EscapeErrorKind::UnexpectedEndOfString.into());
}
if &rest[0..2] != b"\\u" {
return Err(EscapeErrorKind::ExpectedSurrogatePair(
String::from_utf8_lossy(&rest[0..6]).into(),
)
.into());
}
let codepoint_str =
std::str::from_utf8(&rest[2..6]).map_err(|_| EscapeErrorKind::InvalidUtf8)?;
if codepoint_str.bytes().any(|byte| !byte.is_ascii_hexdigit()) {
return Err(EscapeErrorKind::InvalidUnicodeEscape(codepoint_str.into()).into());
}
Ok(u16::from_str_radix(codepoint_str, 16).expect("hex string is valid 16-bit value"))
}
fn read_unicode_escapes(bytes: &[u8], into: &mut Vec<u8>) -> Result<usize, EscapeError> {
let high = read_codepoint(bytes)?;
let (bytes_read, chr) = if is_utf16_high_surrogate(high) {
let low = read_codepoint(&bytes[6..])?;
if !is_utf16_low_surrogate(low) {
return Err(EscapeErrorKind::InvalidSurrogatePair(high, low).into());
}
let codepoint =
std::char::from_u32(0x10000 + (high - 0xD800) as u32 * 0x400 + (low - 0xDC00) as u32)
.ok_or(EscapeErrorKind::InvalidSurrogatePair(high, low))?;
(12, codepoint)
} else {
let codepoint = std::char::from_u32(high as u32).ok_or_else(|| {
EscapeErrorKind::InvalidUnicodeEscape(String::from_utf8_lossy(&bytes[0..6]).into())
})?;
(6, codepoint)
};
match chr.len_utf8() {
1 => into.push(chr as u8),
_ => into.extend_from_slice(chr.encode_utf8(&mut [0; 4]).as_bytes()),
}
Ok(bytes_read)
}
#[cfg(test)]
mod test {
use super::escape_string;
use crate::escape::{unescape_string, EscapeErrorKind};
use std::borrow::Cow;
#[test]
fn escape() {
assert_eq!("", escape_string("").as_ref());
assert_eq!("foo", escape_string("foo").as_ref());
assert_eq!("foo\\r\\n", escape_string("foo\r\n").as_ref());
assert_eq!("foo\\r\\nbar", escape_string("foo\r\nbar").as_ref());
assert_eq!(r"foo\\bar", escape_string(r"foo\bar").as_ref());
assert_eq!(r"\\foobar", escape_string(r"\foobar").as_ref());
assert_eq!(
r"\bf\fo\to\r\n",
escape_string("\u{08}f\u{0C}o\to\r\n").as_ref()
);
assert_eq!("\\\"test\\\"", escape_string("\"test\"").as_ref());
assert_eq!("\\u0000", escape_string("\u{0}").as_ref());
assert_eq!("\\u001f", escape_string("\u{1f}").as_ref());
}
#[test]
fn unescape_no_escapes() {
let unescaped = unescape_string("test test").unwrap();
assert_eq!("test test", unescaped);
assert!(matches!(unescaped, Cow::Borrowed(_)));
}
#[test]
fn unescape() {
assert_eq!(
"\x08f\x0Co\to\r\n",
unescape_string(r"\bf\fo\to\r\n").unwrap()
);
assert_eq!("\"test\"", unescape_string(r#"\"test\""#).unwrap());
assert_eq!("\x00", unescape_string("\\u0000").unwrap());
assert_eq!("\x1f", unescape_string("\\u001f").unwrap());
assert_eq!("foo\r\nbar", unescape_string("foo\\r\\nbar").unwrap());
assert_eq!("foo\r\n", unescape_string("foo\\r\\n").unwrap());
assert_eq!("\r\nbar", unescape_string("\\r\\nbar").unwrap());
assert_eq!("\u{10437}", unescape_string("\\uD801\\uDC37").unwrap());
assert_eq!(
Err(EscapeErrorKind::UnexpectedEndOfString.into()),
unescape_string("\\")
);
assert_eq!(
Err(EscapeErrorKind::UnexpectedEndOfString.into()),
unescape_string("\\u")
);
assert_eq!(
Err(EscapeErrorKind::UnexpectedEndOfString.into()),
unescape_string("\\u00")
);
assert_eq!(
Err(EscapeErrorKind::InvalidEscapeCharacter('z').into()),
unescape_string("\\z")
);
assert_eq!(
Err(EscapeErrorKind::ExpectedSurrogatePair("\\nasdf".into()).into()),
unescape_string("\\uD801\\nasdf")
);
assert_eq!(
Err(EscapeErrorKind::UnexpectedEndOfString.into()),
unescape_string("\\uD801\\u00")
);
assert_eq!(
Err(EscapeErrorKind::InvalidSurrogatePair(0xD801, 0xC501).into()),
unescape_string("\\uD801\\uC501")
);
assert_eq!(
Err(EscapeErrorKind::InvalidUnicodeEscape("+04D".into()).into()),
unescape_string("\\u+04D")
);
}
use proptest::proptest;
proptest! {
#[test]
fn matches_serde_json(s in ".*") {
let serde_escaped = serde_json::to_string(&s).unwrap();
let serde_escaped = &serde_escaped[1..(serde_escaped.len() - 1)];
assert_eq!(serde_escaped,escape_string(&s))
}
#[test]
fn round_trip(chr in proptest::char::any()) {
let mut original = String::new();
original.push(chr);
let escaped = escape_string(&original);
let unescaped = unescape_string(&escaped).unwrap();
assert_eq!(original, unescaped);
}
#[test]
fn unicode_surrogates(chr in proptest::char::range(
std::char::from_u32(0x10000).unwrap(),
std::char::from_u32(0x10FFFF).unwrap(),
)) {
let mut codepoints = [0; 2];
chr.encode_utf16(&mut codepoints);
let escaped = format!("\\u{:04X}\\u{:04X}", codepoints[0], codepoints[1]);
let unescaped = unescape_string(&escaped).unwrap();
let expected = format!("{}", chr);
assert_eq!(expected, unescaped);
}
}
#[test]
#[ignore] fn all_codepoints() {
for value in 0..u32::MAX {
if let Some(chr) = char::from_u32(value) {
let string = String::from(chr);
let escaped = escape_string(&string);
let serde_escaped = serde_json::to_string(&string).unwrap();
let serde_escaped = &serde_escaped[1..(serde_escaped.len() - 1)];
assert_eq!(&escaped, serde_escaped);
}
}
}
}