1use std::char::decode_utf16;
2use std::ffi::{OsStr, OsString};
3use std::os::windows::ffi::{OsStrExt, OsStringExt};
4use std::{io, str};
5
6pub fn bytes_to_host(bytes: &[u8]) -> io::Result<OsString> {
9 let s = str::from_utf8(bytes).map_err(|_| encoding_error())?;
10 str_to_host(s)
11}
12
13pub fn str_to_host(s: &str) -> io::Result<OsString> {
16 if let Some(nul_position) = s.chars().position(|c| c == '\0') {
17 from_arf(s, nul_position)
18 } else {
19 Ok(OsString::from_wide(&s.encode_utf16().collect::<Vec<_>>()))
20 }
21}
22
23pub fn host_to_str(host: &OsStr) -> io::Result<String> {
26 let wide = host.encode_wide().collect::<Vec<_>>();
27 if wide.contains(&0) {
28 return Err(encoding_error());
29 }
30 Ok(if let Ok(s) = String::from_utf16(&wide) {
31 s
32 } else {
33 to_arf(&wide)
34 })
35}
36
37pub fn host_to_bytes(host: &OsStr) -> io::Result<Vec<u8>> {
40 host_to_str(host).map(String::into_bytes)
41}
42
43#[cold]
45fn from_arf(s: &str, nul: usize) -> io::Result<OsString> {
46 let mut lossy = s.chars();
47 if lossy.next() != Some('\u{feff}') {
48 return Err(encoding_error());
49 }
50
51 let mut nul_escaped = s.chars().skip(nul + 1);
52 let mut any_invalid = false;
53 let mut vec = Vec::new();
54 while let Some(c) = nul_escaped.next() {
55 if c == '\0' {
56 let more = nul_escaped.next().ok_or_else(encoding_error)?;
57 if more > '\u{7ff}' {
58 return Err(encoding_error());
59 }
60 let l = lossy.next().ok_or_else(encoding_error)?;
62 if l != '\u{fffd}' {
63 return Err(encoding_error());
64 }
65 any_invalid = true;
66 let unit = u16::try_from((more as u16) + 0xd800).map_err(|_| encoding_error())?;
67 vec.push(unit);
68 } else {
69 if lossy.next() != Some(c) {
70 return Err(encoding_error());
71 }
72 let mut buf = [0; 2];
73 let utf16 = c.encode_utf16(&mut buf);
74 for unit in utf16 {
75 vec.push(*unit);
76 }
77 }
78 }
79 if !any_invalid {
80 return Err(encoding_error());
81 }
82 if lossy.next() != Some('\0') {
83 return Err(encoding_error());
84 }
85
86 Ok(OsString::from_wide(&vec))
88}
89
/// Renders UTF-16 code units in ARF form.
///
/// The output is U+FEFF, then a lossy rendering of `units` in which every
/// unpaired surrogate becomes U+FFFD, then a NUL, then a second rendering in
/// which every unpaired surrogate becomes a NUL followed by the character
/// whose code point is the surrogate's offset from 0xD800. Correctly decoded
/// characters appear unchanged in both renderings.
///
/// # Panics
///
/// Panics if the decoder reports an unpaired surrogate outside
/// `0xd800..=0xdfff`.
#[cold]
fn to_arf(units: &[u16]) -> String {
    // Both renderings are produced in one decoding pass and joined afterwards.
    let mut lossy = String::from("\u{feff}");
    let mut escaped = String::new();

    for decoded in decode_utf16(units.iter().copied()) {
        match decoded {
            Ok(c) => {
                lossy.push(c);
                escaped.push(c);
            }
            Err(err) => {
                let surrogate = err.unpaired_surrogate();
                assert!((0xd800..=0xdfff).contains(&surrogate));
                lossy.push('\u{fffd}');
                escaped.push('\0');
                // The offset is at most 0x7ff, which is always a valid scalar.
                escaped.push(std::char::from_u32(u32::from(surrogate - 0xd800)).unwrap());
            }
        }
    }

    lossy.push('\0');
    lossy.push_str(&escaped);
    lossy
}
120
/// Builds the `InvalidData` I/O error used for every conversion failure in
/// this file.
#[cold]
fn encoding_error() -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, "invalid path string")
}
125
126#[test]
127fn utf16_inputs() {
128 assert_eq!(
129 String::from_utf16(&str_to_host("").unwrap().encode_wide().collect::<Vec<_>>()).unwrap(),
130 ""
131 );
132 str_to_host("\0").unwrap_err();
133 assert_eq!(
134 String::from_utf16(&str_to_host("f").unwrap().encode_wide().collect::<Vec<_>>()).unwrap(),
135 "f"
136 );
137 assert_eq!(
138 String::from_utf16(
139 &str_to_host("foo")
140 .unwrap()
141 .encode_wide()
142 .collect::<Vec<_>>()
143 )
144 .unwrap(),
145 "foo"
146 );
147 assert_eq!(
148 String::from_utf16(
149 &str_to_host("\u{fffd}")
150 .unwrap()
151 .encode_wide()
152 .collect::<Vec<_>>()
153 )
154 .unwrap(),
155 "\u{fffd}"
156 );
157 assert_eq!(
158 String::from_utf16(
159 &str_to_host("\u{fffd}foo")
160 .unwrap()
161 .encode_wide()
162 .collect::<Vec<_>>()
163 )
164 .unwrap(),
165 "\u{fffd}foo"
166 );
167 assert_eq!(
168 String::from_utf16(
169 &str_to_host("\u{feff}foo")
170 .unwrap()
171 .encode_wide()
172 .collect::<Vec<_>>()
173 )
174 .unwrap(),
175 "\u{feff}foo"
176 );
177}
178
179#[test]
180fn arf_inputs() {
181 assert_eq!(
182 str_to_host("\u{feff}hello\u{fffd}world\0hello\0\x05world")
183 .unwrap()
184 .encode_wide()
185 .collect::<Vec<_>>(),
186 [
187 'h' as u16, 'e' as u16, 'l' as u16, 'l' as u16, 'o' as u16, 0xd805_u16, 'w' as u16,
188 'o' as u16, 'r' as u16, 'l' as u16, 'd' as u16
189 ]
190 );
191 assert_eq!(
192 str_to_host("\u{feff}hello\u{fffd}\0hello\0\x05")
193 .unwrap()
194 .encode_wide()
195 .collect::<Vec<_>>(),
196 ['h' as u16, 'e' as u16, 'l' as u16, 'l' as u16, 'o' as u16, 0xd805_u16]
197 );
198}
199
/// Byte sequences that are not valid UTF-8 are rejected.
#[test]
fn errors_from_bytes() {
    let invalid: [&[u8]; 2] = [b"\xfe", b"\xc0\xff"];
    for &bytes in &invalid {
        assert!(bytes_to_host(bytes).is_err());
    }
}
205
/// Strings that contain a NUL but are not well-formed ARF are rejected.
#[test]
fn errors_from_str() {
    let malformed = [
        "\u{feff}hello world\0hello world",
        "\u{feff}hello world\0\0hello world\0",
        "\u{feff}hello\u{fffd}world\0\0hello\0\x05world\0",
        "\u{fffe}hello\u{fffd}world\0hello\0\x05world",
        "\u{feff}hello\u{fffd}\0hello\0",
    ];
    for &input in &malformed {
        assert!(str_to_host(input).is_err());
    }
}
214
/// Host strings that are valid UTF-16 come back as the same text.
#[test]
fn valid_utf16() {
    for &text in &["", "foo"] {
        let converted = host_to_str(OsStr::new(text)).unwrap();
        assert_eq!(converted, text);
    }
}
220
221#[test]
222fn not_utf16() {
223 assert_eq!(
224 host_to_str(&OsString::from_wide(&[0xd800_u16])).unwrap(),
225 "\u{feff}\u{fffd}\0\0\u{0}"
226 );
227 assert_eq!(
228 host_to_str(&OsString::from_wide(&[0xdfff_u16])).unwrap(),
229 "\u{feff}\u{fffd}\0\0\u{7ff}"
230 );
231}
232
233#[test]
234fn round_trip() {
235 assert_eq!(host_to_str(&bytes_to_host(b"").unwrap()).unwrap(), "");
236 assert_eq!(
237 host_to_str(&bytes_to_host(b"hello").unwrap()).unwrap(),
238 "hello"
239 );
240 assert_eq!(
241 str_to_host(&host_to_str(OsStr::new("hello")).unwrap()).unwrap(),
242 OsStr::new("hello")
243 );
244 assert_eq!(
245 str_to_host(&host_to_str(&OsString::from_wide(&[0x47_u16, 0xd800_u16, 0x48_u16])).unwrap())
246 .unwrap(),
247 OsString::from_wide(&[0x47_u16, 0xd800_u16, 0x48_u16])
248 );
249 assert_eq!(
250 str_to_host(&host_to_str(&OsString::from_wide(&[0x49_u16, 0xdfff_u16, 0x50_u16])).unwrap())
251 .unwrap(),
252 OsString::from_wide(&[0x49_u16, 0xdfff_u16, 0x50_u16])
253 );
254 assert_eq!(
255 str_to_host(&host_to_str(OsStr::new("")).unwrap()).unwrap(),
256 OsStr::new("")
257 );
258}