1use std::io::{IoSlice, IoSliceMut};
4use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
5
6use nix::{sys::socket, Result as NixResult};
7
8use crate::wire::{ArgumentType, Message, MessageParseError, MessageWriteError};
9
/// Capacity of the outgoing file-descriptor buffer; the receive path also
/// reserves control-message space for this many descriptors per call.
pub const MAX_FDS_OUT: usize = 28;
/// Size in bytes of the outgoing data buffer. The buffer stores `u32` words,
/// so this value is divided by 4 to size it.
pub const MAX_BYTES_OUT: usize = 4 * 1024;
14
/// An owned socket file descriptor used for message passing with
/// `SCM_RIGHTS` file-descriptor transfer.
///
/// The descriptor is closed when the `Socket` is dropped.
#[derive(Debug)]
pub struct Socket {
    // Raw descriptor owned by this value.
    fd: RawFd,
}
24
25impl Socket {
26 pub fn send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> NixResult<()> {
34 let flags = socket::MsgFlags::MSG_DONTWAIT | socket::MsgFlags::MSG_NOSIGNAL;
35 let iov = [IoSlice::new(bytes)];
36
37 if !fds.is_empty() {
38 let cmsgs = [socket::ControlMessage::ScmRights(fds)];
39 socket::sendmsg::<()>(self.fd, &iov, &cmsgs, flags, None)?;
40 } else {
41 socket::sendmsg::<()>(self.fd, &iov, &[], flags, None)?;
42 };
43 Ok(())
44 }
45
46 pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> NixResult<(usize, usize)> {
58 let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
59 let mut iov = [IoSliceMut::new(buffer)];
60
61 let msg = socket::recvmsg::<()>(
62 self.fd,
63 &mut iov[..],
64 Some(&mut cmsg),
65 socket::MsgFlags::MSG_DONTWAIT
66 | socket::MsgFlags::MSG_CMSG_CLOEXEC
67 | socket::MsgFlags::MSG_NOSIGNAL,
68 )?;
69
70 let mut fd_count = 0;
71 let received_fds = msg.cmsgs().flat_map(|cmsg| match cmsg {
72 socket::ControlMessageOwned::ScmRights(s) => s,
73 _ => Vec::new(),
74 });
75 for (fd, place) in received_fds.zip(fds.iter_mut()) {
76 fd_count += 1;
77 *place = fd;
78 }
79 Ok((msg.bytes, fd_count))
80 }
81
82 pub fn opt<O: socket::GetSockOpt>(&self, opt: O) -> NixResult<O::Val> {
84 socket::getsockopt(self.fd, opt)
85 }
86}
87
88impl FromRawFd for Socket {
89 unsafe fn from_raw_fd(fd: RawFd) -> Socket {
90 Socket { fd }
91 }
92}
93
94impl AsRawFd for Socket {
95 fn as_raw_fd(&self) -> RawFd {
96 self.fd
97 }
98}
99
100impl IntoRawFd for Socket {
101 fn into_raw_fd(self) -> RawFd {
102 self.fd
103 }
104}
105
106impl Drop for Socket {
107 fn drop(&mut self) {
108 let _ = ::nix::unistd::close(self.fd);
109 }
110}
111
/// A [`Socket`] with incoming and outgoing buffers for message data (stored
/// as `u32` words) and for file descriptors.
#[derive(Debug)]
pub struct BufferedSocket {
    socket: Socket,
    // Received data words / fds not yet parsed into messages.
    in_data: Buffer<u32>,
    in_fds: Buffer<RawFd>,
    // Serialized messages / fds waiting to be sent by `flush`.
    out_data: Buffer<u32>,
    out_fds: Buffer<RawFd>,
}
126
127impl BufferedSocket {
128 pub fn new(socket: Socket) -> BufferedSocket {
130 BufferedSocket {
131 socket,
132 in_data: Buffer::new(2 * MAX_BYTES_OUT / 4), in_fds: Buffer::new(2 * MAX_FDS_OUT), out_data: Buffer::new(MAX_BYTES_OUT / 4),
135 out_fds: Buffer::new(MAX_FDS_OUT),
136 }
137 }
138
139 pub fn get_socket(&mut self) -> &mut Socket {
141 &mut self.socket
142 }
143
144 pub fn into_socket(self) -> Socket {
148 self.socket
149 }
150
151 pub fn flush(&mut self) -> NixResult<()> {
153 {
154 let words = self.out_data.get_contents();
155 if words.is_empty() {
156 return Ok(());
157 }
158 let bytes = unsafe {
159 ::std::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 4)
160 };
161 let fds = self.out_fds.get_contents();
162 self.socket.send_msg(bytes, fds)?;
163 for &fd in fds {
164 let _ = ::nix::unistd::close(fd);
166 }
167 }
168 self.out_data.clear();
169 self.out_fds.clear();
170 Ok(())
171 }
172
173 fn attempt_write_message(&mut self, msg: &Message) -> NixResult<bool> {
181 match msg.write_to_buffers(
182 self.out_data.get_writable_storage(),
183 self.out_fds.get_writable_storage(),
184 ) {
185 Ok((bytes_out, fds_out)) => {
186 self.out_data.advance(bytes_out);
187 self.out_fds.advance(fds_out);
188 Ok(true)
189 }
190 Err(MessageWriteError::BufferTooSmall) => Ok(false),
191 Err(MessageWriteError::DupFdFailed(e)) => Err(e),
192 }
193 }
194
195 pub fn write_message(&mut self, msg: &Message) -> NixResult<()> {
202 if !self.attempt_write_message(msg)? {
203 self.flush()?;
206 if !self.attempt_write_message(msg)? {
207 return Err(::nix::Error::E2BIG);
210 }
211 }
212 Ok(())
213 }
214
215 pub fn fill_incoming_buffers(&mut self) -> NixResult<()> {
218 if !self.in_data.has_content() {
220 self.in_data.clear();
221 }
222 if !self.in_fds.has_content() {
223 self.in_fds.clear();
224 }
225 let (in_bytes, in_fds) = {
227 let words = self.in_data.get_writable_storage();
228 let bytes = unsafe {
229 ::std::slice::from_raw_parts_mut(words.as_ptr() as *mut u8, words.len() * 4)
230 };
231 let fds = self.in_fds.get_writable_storage();
232 self.socket.rcv_msg(bytes, fds)?
233 };
234 if in_bytes == 0 {
235 return Err(::nix::Error::EPIPE);
237 }
238 self.in_data.advance(in_bytes / 4 + if in_bytes % 4 > 0 { 1 } else { 0 });
240 self.in_fds.advance(in_fds);
241 Ok(())
242 }
243
244 pub fn read_one_message<F>(&mut self, mut signature: F) -> Result<Message, MessageParseError>
260 where
261 F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
262 {
263 let (msg, read_data, read_fd) = {
264 let data = self.in_data.get_contents();
265 let fds = self.in_fds.get_contents();
266 if data.len() < 2 {
267 return Err(MessageParseError::MissingData);
268 }
269 let object_id = data[0];
270 let opcode = (data[1] & 0x0000_FFFF) as u16;
271 if let Some(sig) = signature(object_id, opcode) {
272 match Message::from_raw(data, sig, fds) {
273 Ok((msg, rest_data, rest_fds)) => {
274 (msg, data.len() - rest_data.len(), fds.len() - rest_fds.len())
275 }
276 Err(e) => return Err(e),
278 }
279 } else {
280 return Err(MessageParseError::Malformed);
282 }
283 };
284
285 self.in_data.offset(read_data);
286 self.in_fds.offset(read_fd);
287
288 Ok(msg)
289 }
290
291 pub fn read_messages<F1, F2>(
317 &mut self,
318 mut signature: F1,
319 mut callback: F2,
320 ) -> NixResult<Result<usize, MessageParseError>>
321 where
322 F1: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
323 F2: FnMut(Message) -> bool,
324 {
325 let mut dispatched = 0;
327
328 loop {
329 let mut err = None;
330 loop {
332 match self.read_one_message(&mut signature) {
333 Ok(msg) => {
334 let keep_going = callback(msg);
335 dispatched += 1;
336 if !keep_going {
337 break;
338 }
339 }
340 Err(e) => {
341 err = Some(e);
342 break;
343 }
344 }
345 }
346
347 self.in_data.move_to_front();
349 self.in_fds.move_to_front();
350
351 if let Some(MessageParseError::Malformed) = err {
352 return Ok(Err(MessageParseError::Malformed));
354 }
355
356 if err.is_none() && self.in_data.has_content() {
357 return Ok(Ok(dispatched));
360 }
361
362 match self.fill_incoming_buffers() {
364 Ok(()) => (),
365 Err(e @ ::nix::Error::EAGAIN) => {
366 if dispatched == 0 {
369 return Err(e);
370 } else {
371 break;
372 }
373 }
374 Err(e) => return Err(e),
375 }
376 }
377
378 Ok(Ok(dispatched))
379 }
380}
381
/// Fixed-capacity buffer with separate read and write cursors.
///
/// `storage[offset..occupied]` holds content that has been written but not yet
/// consumed, and `storage[occupied..]` is free space available for writing.
#[derive(Debug)]
struct Buffer<T: Copy> {
    storage: Vec<T>,
    occupied: usize,
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Creates an empty buffer with room for `size` elements.
    fn new(size: usize) -> Buffer<T> {
        Buffer { storage: vec![T::default(); size], occupied: 0, offset: 0 }
    }

    /// Returns `true` if some written content has not been consumed yet.
    fn has_content(&self) -> bool {
        self.occupied > self.offset
    }

    /// Marks `count` more elements of the writable storage as written.
    fn advance(&mut self, count: usize) {
        self.occupied += count;
    }

    /// Marks `count` more elements of the content as consumed.
    fn offset(&mut self, count: usize) {
        self.offset += count;
    }

    /// Discards all content by resetting both cursors. The storage itself is
    /// left untouched and will be overwritten by later writes.
    fn clear(&mut self) {
        self.occupied = 0;
        self.offset = 0;
    }

    /// Returns the written but not yet consumed elements.
    fn get_contents(&self) -> &[T] {
        &self.storage[self.offset..self.occupied]
    }

    /// Returns the free space that follows the written content.
    fn get_writable_storage(&mut self) -> &mut [T] {
        &mut self.storage[self.occupied..]
    }

    /// Moves unconsumed content to the start of the storage so that all
    /// remaining capacity is available for writing.
    fn move_to_front(&mut self) {
        // `copy_within` is a safe, overlap-aware copy. Unlike indexing
        // `storage[offset]` to obtain a raw pointer, it does not panic when
        // the buffer was filled and fully consumed (`offset == storage.len()`)
        // or has zero capacity. It also avoids aliasing a `&mut` borrow with a
        // pointer taken from a shared borrow.
        self.storage.copy_within(self.offset..self.occupied, 0);
        self.occupied -= self.offset;
        self.offset = 0;
    }
}
445
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wire::{Argument, ArgumentType, Message};

    use std::ffi::CString;

    use smallvec::smallvec;

    /// True when both descriptors refer to the same file (same device and inode).
    fn same_file(a: RawFd, b: RawFd) -> bool {
        let lhs = ::nix::sys::stat::fstat(a).unwrap();
        let rhs = ::nix::sys::stat::fstat(b).unwrap();
        (lhs.st_dev, lhs.st_ino) == (rhs.st_dev, rhs.st_ino)
    }

    /// Asserts that two messages are equivalent. Fd arguments are compared by
    /// the file they refer to, since descriptor numbers may differ after transfer.
    fn assert_eq_msgs(msg1: &Message, msg2: &Message) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for pair in msg1.args.iter().zip(msg2.args.iter()) {
            match pair {
                (&Argument::Fd(fd1), &Argument::Fd(fd2)) => assert!(same_file(fd1, fd2)),
                (arg1, arg2) => assert_eq!(arg1, arg2),
            }
        }
    }

    /// Builds a connected `(client, server)` pair of buffered sockets.
    fn buffered_pair() -> (BufferedSocket, BufferedSocket) {
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let wrap = |stream: ::std::os::unix::net::UnixStream| {
            BufferedSocket::new(unsafe { Socket::from_raw_fd(stream.into_raw_fd()) })
        };
        let client = wrap(client);
        let server = wrap(server);
        (client, server)
    }

    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Box::new(CString::new(&b"I like trains!"[..]).unwrap())),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (mut client, mut server) = buffered_pair();
        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str,
            ArgumentType::Array,
            ArgumentType::Object,
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        let dispatched = server
            .read_messages(
                |sender_id, opcode| match (sender_id, opcode) {
                    (42, 7) => Some(SIGNATURE),
                    _ => None,
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(dispatched, 1);
    }

    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![Argument::Fd(1), Argument::Fd(0)],
        };

        let (mut client, mut server) = buffered_pair();
        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        let dispatched = server
            .read_messages(
                |sender_id, opcode| match (sender_id, opcode) {
                    (42, 7) => Some(SIGNATURE),
                    _ => None,
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(dispatched, 1);
    }

    #[test]
    fn write_read_cycle_multiple() {
        let messages = [
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Box::new(CString::new(&b"I like trains"[..]).unwrap())),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![Argument::Fd(1), Argument::Fd(0)],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![Argument::Uint(3), Argument::Fd(2)],
            },
        ];

        // Signatures indexed by opcode.
        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let (mut client, mut server) = buffered_pair();
        for msg in messages.iter() {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();

        let mut received = Vec::new();
        let dispatched = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id != 42 {
                        return None;
                    }
                    Some(SIGNATURES[opcode as usize])
                },
                |message| {
                    received.push(message);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(dispatched, 3);
        assert_eq!(received.len(), 3);
        for (expected, actual) in messages.iter().zip(received.iter()) {
            assert_eq_msgs(expected, actual);
        }
    }

    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Box::new(CString::new(&b"wl_shell"[..]).unwrap())),
                Argument::Uint(1),
            ],
        };

        let (mut client, mut server) = buffered_pair();
        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str, ArgumentType::Uint];

        let dispatched = server
            .read_messages(
                |sender_id, opcode| match (sender_id, opcode) {
                    (2, 0) => Some(SIGNATURE),
                    _ => None,
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(dispatched, 1);
    }
}