wayland_commons/
socket.rs

1//! Wayland socket manipulation
2
3use std::io::{IoSlice, IoSliceMut};
4use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
5
6use nix::{sys::socket, Result as NixResult};
7
8use crate::wire::{ArgumentType, Message, MessageParseError, MessageWriteError};
9
/// Upper bound on how many file descriptors a single socket message may carry.
pub const MAX_FDS_OUT: usize = 28;
/// Upper bound on how many payload bytes a single socket message may carry.
pub const MAX_BYTES_OUT: usize = 4 * 1024;
14
15/*
16 * Socket
17 */
18
/// A wayland socket, owning a raw file descriptor.
///
/// The descriptor is closed when the `Socket` is dropped.
#[derive(Debug)]
pub struct Socket {
    /// The underlying socket descriptor.
    fd: RawFd,
}
24
25impl Socket {
26    /// Send a single message to the socket
27    ///
28    /// A single socket message can contain several wayland messages
29    ///
30    /// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes`
31    /// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
32    /// end may lose some data.
33    pub fn send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> NixResult<()> {
34        let flags = socket::MsgFlags::MSG_DONTWAIT | socket::MsgFlags::MSG_NOSIGNAL;
35        let iov = [IoSlice::new(bytes)];
36
37        if !fds.is_empty() {
38            let cmsgs = [socket::ControlMessage::ScmRights(fds)];
39            socket::sendmsg::<()>(self.fd, &iov, &cmsgs, flags, None)?;
40        } else {
41            socket::sendmsg::<()>(self.fd, &iov, &[], flags, None)?;
42        };
43        Ok(())
44    }
45
46    /// Receive a single message from the socket
47    ///
48    /// Return the number of bytes received and the number of Fds received.
49    ///
50    /// Errors with `WouldBlock` is no message is available.
51    ///
52    /// A single socket message can contain several wayland messages.
53    ///
54    /// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds`
55    /// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
56    /// be lost.
57    pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> NixResult<(usize, usize)> {
58        let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
59        let mut iov = [IoSliceMut::new(buffer)];
60
61        let msg = socket::recvmsg::<()>(
62            self.fd,
63            &mut iov[..],
64            Some(&mut cmsg),
65            socket::MsgFlags::MSG_DONTWAIT
66                | socket::MsgFlags::MSG_CMSG_CLOEXEC
67                | socket::MsgFlags::MSG_NOSIGNAL,
68        )?;
69
70        let mut fd_count = 0;
71        let received_fds = msg.cmsgs().flat_map(|cmsg| match cmsg {
72            socket::ControlMessageOwned::ScmRights(s) => s,
73            _ => Vec::new(),
74        });
75        for (fd, place) in received_fds.zip(fds.iter_mut()) {
76            fd_count += 1;
77            *place = fd;
78        }
79        Ok((msg.bytes, fd_count))
80    }
81
82    /// Retrieve the current value of the requested [socket::GetSockOpt]
83    pub fn opt<O: socket::GetSockOpt>(&self, opt: O) -> NixResult<O::Val> {
84        socket::getsockopt(self.fd, opt)
85    }
86}
87
88impl FromRawFd for Socket {
89    unsafe fn from_raw_fd(fd: RawFd) -> Socket {
90        Socket { fd }
91    }
92}
93
94impl AsRawFd for Socket {
95    fn as_raw_fd(&self) -> RawFd {
96        self.fd
97    }
98}
99
100impl IntoRawFd for Socket {
101    fn into_raw_fd(self) -> RawFd {
102        self.fd
103    }
104}
105
106impl Drop for Socket {
107    fn drop(&mut self) {
108        let _ = ::nix::unistd::close(self.fd);
109    }
110}
111
112/*
113 * BufferedSocket
114 */
115
116/// An adapter around a raw Socket that directly handles buffering and
117/// conversion from/to wayland messages
118#[derive(Debug)]
119pub struct BufferedSocket {
120    socket: Socket,
121    in_data: Buffer<u32>,
122    in_fds: Buffer<RawFd>,
123    out_data: Buffer<u32>,
124    out_fds: Buffer<RawFd>,
125}
126
127impl BufferedSocket {
128    /// Wrap a Socket into a Buffered Socket
129    pub fn new(socket: Socket) -> BufferedSocket {
130        BufferedSocket {
131            socket,
132            in_data: Buffer::new(2 * MAX_BYTES_OUT / 4), // Incoming buffers are twice as big in order to be
133            in_fds: Buffer::new(2 * MAX_FDS_OUT),        // able to store leftover data if needed
134            out_data: Buffer::new(MAX_BYTES_OUT / 4),
135            out_fds: Buffer::new(MAX_FDS_OUT),
136        }
137    }
138
139    /// Get direct access to the underlying socket
140    pub fn get_socket(&mut self) -> &mut Socket {
141        &mut self.socket
142    }
143
144    /// Retrieve ownership of the underlying Socket
145    ///
146    /// Any leftover content in the internal buffers will be lost
147    pub fn into_socket(self) -> Socket {
148        self.socket
149    }
150
151    /// Flush the contents of the outgoing buffer into the socket
152    pub fn flush(&mut self) -> NixResult<()> {
153        {
154            let words = self.out_data.get_contents();
155            if words.is_empty() {
156                return Ok(());
157            }
158            let bytes = unsafe {
159                ::std::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 4)
160            };
161            let fds = self.out_fds.get_contents();
162            self.socket.send_msg(bytes, fds)?;
163            for &fd in fds {
164                // once the fds are sent, we can close them
165                let _ = ::nix::unistd::close(fd);
166            }
167        }
168        self.out_data.clear();
169        self.out_fds.clear();
170        Ok(())
171    }
172
173    // internal method
174    //
175    // attempts to write a message in the internal out buffers,
176    // returns true if successful
177    //
178    // if false is returned, it means there is not enough space
179    // in the buffer
180    fn attempt_write_message(&mut self, msg: &Message) -> NixResult<bool> {
181        match msg.write_to_buffers(
182            self.out_data.get_writable_storage(),
183            self.out_fds.get_writable_storage(),
184        ) {
185            Ok((bytes_out, fds_out)) => {
186                self.out_data.advance(bytes_out);
187                self.out_fds.advance(fds_out);
188                Ok(true)
189            }
190            Err(MessageWriteError::BufferTooSmall) => Ok(false),
191            Err(MessageWriteError::DupFdFailed(e)) => Err(e),
192        }
193    }
194
195    /// Write a message to the outgoing buffer
196    ///
197    /// This method may flush the internal buffer if necessary (if it is full).
198    ///
199    /// If the message is too big to fit in the buffer, the error `Error::Sys(E2BIG)`
200    /// will be returned.
201    pub fn write_message(&mut self, msg: &Message) -> NixResult<()> {
202        if !self.attempt_write_message(msg)? {
203            // the attempt failed, there is not enough space in the buffer
204            // we need to flush it
205            self.flush()?;
206            if !self.attempt_write_message(msg)? {
207                // If this fails again, this means the message is too big
208                // to be transmitted at all
209                return Err(::nix::Error::E2BIG);
210            }
211        }
212        Ok(())
213    }
214
215    /// Try to fill the incoming buffers of this socket, to prepare
216    /// a new round of parsing.
217    pub fn fill_incoming_buffers(&mut self) -> NixResult<()> {
218        // clear the buffers if they have no content
219        if !self.in_data.has_content() {
220            self.in_data.clear();
221        }
222        if !self.in_fds.has_content() {
223            self.in_fds.clear();
224        }
225        // receive a message
226        let (in_bytes, in_fds) = {
227            let words = self.in_data.get_writable_storage();
228            let bytes = unsafe {
229                ::std::slice::from_raw_parts_mut(words.as_ptr() as *mut u8, words.len() * 4)
230            };
231            let fds = self.in_fds.get_writable_storage();
232            self.socket.rcv_msg(bytes, fds)?
233        };
234        if in_bytes == 0 {
235            // the other end of the socket was closed
236            return Err(::nix::Error::EPIPE);
237        }
238        // advance the storage
239        self.in_data.advance(in_bytes / 4 + if in_bytes % 4 > 0 { 1 } else { 0 });
240        self.in_fds.advance(in_fds);
241        Ok(())
242    }
243
244    /// Read and deserialize a single message from the incoming buffers socket
245    ///
246    /// This method requires one closure that given an object id and an opcode,
247    /// must provide the signature of the associated request/event, in the form of
248    /// a `&'static [ArgumentType]`. If it returns `None`, meaning that
249    /// the couple object/opcode does not exist, an error will be returned.
250    ///
251    /// There are 3 possibilities of return value:
252    ///
253    /// - `Ok(Ok(msg))`: no error occurred, this is the message
254    /// - `Ok(Err(e))`: either a malformed message was encountered or we need more data,
255    ///    in the latter case you need to try calling `fill_incoming_buffers()`.
256    /// - `Err(e)`: an I/O error occurred reading from the socked, details are in `e`
257    ///   (this can be a "wouldblock" error, which just means that no message is available
258    ///   to read)
259    pub fn read_one_message<F>(&mut self, mut signature: F) -> Result<Message, MessageParseError>
260    where
261        F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
262    {
263        let (msg, read_data, read_fd) = {
264            let data = self.in_data.get_contents();
265            let fds = self.in_fds.get_contents();
266            if data.len() < 2 {
267                return Err(MessageParseError::MissingData);
268            }
269            let object_id = data[0];
270            let opcode = (data[1] & 0x0000_FFFF) as u16;
271            if let Some(sig) = signature(object_id, opcode) {
272                match Message::from_raw(data, sig, fds) {
273                    Ok((msg, rest_data, rest_fds)) => {
274                        (msg, data.len() - rest_data.len(), fds.len() - rest_fds.len())
275                    }
276                    // TODO: gracefully handle wayland messages split across unix messages ?
277                    Err(e) => return Err(e),
278                }
279            } else {
280                // no signature found ?
281                return Err(MessageParseError::Malformed);
282            }
283        };
284
285        self.in_data.offset(read_data);
286        self.in_fds.offset(read_fd);
287
288        Ok(msg)
289    }
290
291    /// Read and deserialize messages from the socket
292    ///
293    /// This method requires two closures:
294    ///
295    /// - The first one, given an object id and an opcode, must provide
296    ///   the signature of the associated request/event, in the form of
297    ///   a `&'static [ArgumentType]`. If it returns `None`, meaning that
298    ///   the couple object/opcode does not exist, the parsing will be
299    ///   prematurely interrupted and this method will return a
300    ///   `MessageParseError::Malformed` error.
301    /// - The second closure is charged to process the parsed message. If it
302    ///   returns `false`, the iteration will be prematurely stopped.
303    ///
304    /// In both cases of early stopping, the remaining unused data will be left
305    /// in the buffers, and will start to be processed at the next call of this
306    /// method.
307    ///
308    /// There are 3 possibilities of return value:
309    ///
310    /// - `Ok(Ok(n))`: no error occurred, `n` messages where processed
311    /// - `Ok(Err(MessageParseError::Malformed))`: a malformed message was encountered
312    ///   (this is a protocol error and is supposed to be fatal to the connection).
313    /// - `Err(e)`: an I/O error occurred reading from the socked, details are in `e`
314    ///   (this can be a "wouldblock" error, which just means that no message is available
315    ///   to read)
316    pub fn read_messages<F1, F2>(
317        &mut self,
318        mut signature: F1,
319        mut callback: F2,
320    ) -> NixResult<Result<usize, MessageParseError>>
321    where
322        F1: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
323        F2: FnMut(Message) -> bool,
324    {
325        // message parsing
326        let mut dispatched = 0;
327
328        loop {
329            let mut err = None;
330            // first parse any leftover messages
331            loop {
332                match self.read_one_message(&mut signature) {
333                    Ok(msg) => {
334                        let keep_going = callback(msg);
335                        dispatched += 1;
336                        if !keep_going {
337                            break;
338                        }
339                    }
340                    Err(e) => {
341                        err = Some(e);
342                        break;
343                    }
344                }
345            }
346
347            // copy back any leftover content to the front of the buffer
348            self.in_data.move_to_front();
349            self.in_fds.move_to_front();
350
351            if let Some(MessageParseError::Malformed) = err {
352                // early stop here
353                return Ok(Err(MessageParseError::Malformed));
354            }
355
356            if err.is_none() && self.in_data.has_content() {
357                // we stopped reading without error while there is content? That means
358                // the user requested an early stopping
359                return Ok(Ok(dispatched));
360            }
361
362            // now, try to get more data
363            match self.fill_incoming_buffers() {
364                Ok(()) => (),
365                Err(e @ ::nix::Error::EAGAIN) => {
366                    // stop looping, returning Ok() or EAGAIN depending on whether messages
367                    // were dispatched
368                    if dispatched == 0 {
369                        return Err(e);
370                    } else {
371                        break;
372                    }
373                }
374                Err(e) => return Err(e),
375            }
376        }
377
378        Ok(Ok(dispatched))
379    }
380}
381
382/*
383 * Buffer
384 */
/// Fixed-capacity buffer with a read cursor (`offset`) and a write cursor
/// (`occupied`).
///
/// Elements in `offset..occupied` are unread content. Elements from `occupied`
/// onward are free space.
#[derive(Debug)]
struct Buffer<T: Copy> {
    /// Backing storage; its length is the buffer capacity.
    storage: Vec<T>,
    /// Index of the first free element (end of the written region).
    occupied: usize,
    /// Index of the first unread element (start of the unread region).
    offset: usize,
}
391
392impl<T: Copy + Default> Buffer<T> {
393    fn new(size: usize) -> Buffer<T> {
394        Buffer { storage: vec![T::default(); size], occupied: 0, offset: 0 }
395    }
396
397    /// Check if this buffer has content to read
398    fn has_content(&self) -> bool {
399        self.occupied > self.offset
400    }
401
402    /// Advance the internal counter of occupied space
403    fn advance(&mut self, bytes: usize) {
404        self.occupied += bytes;
405    }
406
407    /// Advance the read offset of current occupied space
408    fn offset(&mut self, bytes: usize) {
409        self.offset += bytes;
410    }
411
412    /// Clears the contents of the buffer
413    ///
414    /// This only sets the counter of occupied space back to zero,
415    /// allowing previous content to be overwritten.
416    fn clear(&mut self) {
417        self.occupied = 0;
418        self.offset = 0;
419    }
420
421    /// Get the current contents of the occupied space of the buffer
422    fn get_contents(&self) -> &[T] {
423        &self.storage[(self.offset)..(self.occupied)]
424    }
425
426    /// Get mutable access to the unoccupied space of the buffer
427    fn get_writable_storage(&mut self) -> &mut [T] {
428        &mut self.storage[(self.occupied)..]
429    }
430
431    /// Move the unread contents of the buffer to the front, to ensure
432    /// maximal write space availability
433    fn move_to_front(&mut self) {
434        unsafe {
435            ::std::ptr::copy(
436                &self.storage[self.offset] as *const T,
437                &mut self.storage[0] as *mut T,
438                self.occupied - self.offset,
439            );
440        }
441        self.occupied -= self.offset;
442        self.offset = 0;
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wire::{Argument, ArgumentType, Message};

    use std::ffi::CString;

    use smallvec::smallvec;

    /// Create a connected (client, server) pair of buffered sockets.
    fn buffered_pair() -> (BufferedSocket, BufferedSocket) {
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        // SAFETY: `into_raw_fd` releases each descriptor, which the new
        // `Socket` then owns exclusively.
        let client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
        let server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });
        (client, server)
    }

    fn same_file(a: RawFd, b: RawFd) -> bool {
        let stat1 = ::nix::sys::stat::fstat(a).unwrap();
        let stat2 = ::nix::sys::stat::fstat(b).unwrap();
        stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
    }

    // check if two messages are equal
    //
    // if arguments contain FDs, check that the fd point to
    // the same file, rather than are the same number.
    fn assert_eq_msgs(msg1: &Message, msg2: &Message) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            if let (&Argument::Fd(fd1), &Argument::Fd(fd2)) = (arg1, arg2) {
                assert!(same_file(fd1, fd2));
            } else {
                assert_eq!(arg1, arg2);
            }
        }
    }

    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Box::new(CString::new(&b"I like trains!"[..]).unwrap())),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str,
            ArgumentType::Array,
            ArgumentType::Object,
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 1);
    }

    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Fd(1), // stdout
                Argument::Fd(0), // stdin
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 1);
    }

    #[test]
    fn write_read_cycle_multiple() {
        let messages = [
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Box::new(CString::new(&b"I like trains"[..]).unwrap())),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![
                    Argument::Fd(1), // stdout
                    Argument::Fd(0), // stdin
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![
                    Argument::Uint(3),
                    Argument::Fd(2), // stderr
                ],
            },
        ];

        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let (mut client, mut server) = buffered_pair();

        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();

        let mut recv_msgs = Vec::new();
        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 {
                        SIGNATURES.get(usize::from(opcode)).copied()
                    } else {
                        None
                    }
                },
                |message| {
                    recv_msgs.push(message);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 3);
        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.iter().zip(recv_msgs.iter()) {
            assert_eq_msgs(msg1, msg2);
        }
    }

    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Box::new(CString::new(&b"wl_shell"[..]).unwrap())),
                Argument::Uint(1),
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str, ArgumentType::Uint];

        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 2 && opcode == 0 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();

        assert_eq!(ret, 1);
    }
}