netlink_packet_core/
error.rs

1// SPDX-License-Identifier: MIT
2
3use std::{fmt, io, mem::size_of, num::NonZeroI32};
4
5use byteorder::{ByteOrder, NativeEndian};
6use netlink_packet_utils::DecodeError;
7
8use crate::{Emitable, Field, Parseable, Rest};
9
/// Byte range holding the error code at the start of the buffer.
const CODE: Field = 0..4;
/// Byte range holding everything that follows the error code.
const PAYLOAD: Rest = 4..;
/// Minimum length of a usable buffer: the offset at which the payload starts.
const ERROR_HEADER_LEN: usize = PAYLOAD.start;
13
/// A thin wrapper around a byte buffer that holds an error message.
///
/// The accessors work on the wrapped buffer in place: the error code lives in
/// the `CODE` byte range and everything from `PAYLOAD` onwards is the payload.
#[derive(Debug, PartialEq, Eq, Clone)]
#[non_exhaustive]
pub struct ErrorBuffer<T> {
    /// The wrapped storage.
    buffer: T,
}
19
20impl<T: AsRef<[u8]>> ErrorBuffer<T> {
21    pub fn new(buffer: T) -> ErrorBuffer<T> {
22        ErrorBuffer { buffer }
23    }
24
25    /// Consume the packet, returning the underlying buffer.
26    pub fn into_inner(self) -> T {
27        self.buffer
28    }
29
30    pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
31        let packet = Self::new(buffer);
32        packet.check_buffer_length()?;
33        Ok(packet)
34    }
35
36    fn check_buffer_length(&self) -> Result<(), DecodeError> {
37        let len = self.buffer.as_ref().len();
38        if len < ERROR_HEADER_LEN {
39            Err(format!(
40                "invalid ErrorBuffer: length is {len} but ErrorBuffer are \
41                at least {ERROR_HEADER_LEN} bytes"
42            )
43            .into())
44        } else {
45            Ok(())
46        }
47    }
48
49    /// Return the error code.
50    ///
51    /// Returns `None` when there is no error to report (the message is an ACK),
52    /// or a `Some(e)` if there is a non-zero error code `e` to report (the
53    /// message is a NACK).
54    pub fn code(&self) -> Option<NonZeroI32> {
55        let data = self.buffer.as_ref();
56        NonZeroI32::new(NativeEndian::read_i32(&data[CODE]))
57    }
58}
59
60impl<'a, T: AsRef<[u8]> + ?Sized> ErrorBuffer<&'a T> {
61    /// Return a pointer to the payload.
62    pub fn payload(&self) -> &'a [u8] {
63        let data = self.buffer.as_ref();
64        &data[PAYLOAD]
65    }
66}
67
68impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> ErrorBuffer<&'a mut T> {
69    /// Return a mutable pointer to the payload.
70    pub fn payload_mut(&mut self) -> &mut [u8] {
71        let data = self.buffer.as_mut();
72        &mut data[PAYLOAD]
73    }
74}
75
76impl<T: AsRef<[u8]> + AsMut<[u8]>> ErrorBuffer<T> {
77    /// set the error code field
78    pub fn set_code(&mut self, value: i32) {
79        let data = self.buffer.as_mut();
80        NativeEndian::write_i32(&mut data[CODE], value)
81    }
82}
83
/// An `NLMSG_ERROR` message.
///
/// Per [RFC 3549 section 2.3.2.2], this message carries the return code for a
/// request, which indicates either success (an ACK) or failure (a NACK).
///
/// [RFC 3549 section 2.3.2.2]: https://datatracker.ietf.org/doc/html/rfc3549#section-2.3.2.2
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct ErrorMessage {
    /// The error code.
    ///
    /// Holds `None` when there is no error to report (the message is an ACK),
    /// or `Some(e)` when there is a non-zero error code `e` to report (the
    /// message is a NACK).
    ///
    /// See [Netlink message types] for details.
    ///
    /// [Netlink message types]: https://kernel.org/doc/html/next/userspace-api/netlink/intro.html#netlink-message-types
    pub code: Option<NonZeroI32>,
    /// The original request's header.
    ///
    /// Kept as raw bytes: parsing stores everything that follows the error
    /// code here without interpreting it.
    pub header: Vec<u8>,
}
106
107impl Emitable for ErrorMessage {
108    fn buffer_len(&self) -> usize {
109        size_of::<i32>() + self.header.len()
110    }
111    fn emit(&self, buffer: &mut [u8]) {
112        let mut buffer = ErrorBuffer::new(buffer);
113        buffer.set_code(self.raw_code());
114        buffer.payload_mut().copy_from_slice(&self.header)
115    }
116}
117
118impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable<ErrorBuffer<&'buffer T>>
119    for ErrorMessage
120{
121    fn parse(
122        buf: &ErrorBuffer<&'buffer T>,
123    ) -> Result<ErrorMessage, DecodeError> {
124        // FIXME: The payload of an error is basically a truncated packet, which
125        // requires custom logic to parse correctly. For now we just
126        // return it as a Vec<u8> let header: NetlinkHeader = {
127        //     NetlinkBuffer::new_checked(self.payload())
128        //         .context("failed to parse netlink header")?
129        //         .parse()
130        //         .context("failed to parse nelink header")?
131        // };
132        Ok(ErrorMessage {
133            code: buf.code(),
134            header: buf.payload().to_vec(),
135        })
136    }
137}
138
139impl ErrorMessage {
140    /// Returns the raw error code.
141    pub fn raw_code(&self) -> i32 {
142        self.code.map_or(0, NonZeroI32::get)
143    }
144
145    /// According to [`netlink(7)`](https://linux.die.net/man/7/netlink)
146    /// the `NLMSG_ERROR` return Negative errno or 0 for acknowledgements.
147    ///
148    /// convert into [`std::io::Error`](https://doc.rust-lang.org/std/io/struct.Error.html)
149    /// using the absolute value from errno code
150    pub fn to_io(&self) -> io::Error {
151        io::Error::from_raw_os_error(self.raw_code().abs())
152    }
153}
154
155impl fmt::Display for ErrorMessage {
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        fmt::Display::fmt(&self.to_io(), f)
158    }
159}
160
161impl From<ErrorMessage> for io::Error {
162    fn from(e: ErrorMessage) -> io::Error {
163        e.to_io()
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// A negative errno converts to the matching positive OS error.
    #[test]
    fn into_io_error() {
        let io_err = io::Error::from_raw_os_error(95);
        let err_msg = ErrorMessage {
            code: NonZeroI32::new(-95),
            header: vec![],
        };

        let to_io: io::Error = err_msg.to_io();

        assert_eq!(err_msg.to_string(), io_err.to_string());
        assert_eq!(to_io.raw_os_error(), io_err.raw_os_error());
    }

    /// A zero code parses as an ACK with an empty header.
    #[test]
    fn parse_ack() {
        let bytes = vec![0, 0, 0, 0];
        let msg = ErrorBuffer::new_checked(&bytes)
            .and_then(|buf| ErrorMessage::parse(&buf))
            .expect("failed to parse NLMSG_ERROR");
        assert_eq!(
            ErrorMessage {
                code: None,
                header: Vec::new()
            },
            msg
        );
        assert_eq!(msg.raw_code(), 0);
    }

    /// A non-zero code parses as a NACK carrying that code.
    #[test]
    fn parse_nack() {
        // The safe constructor is enough here; no `unsafe` is needed to
        // build a known non-zero value.
        let error_code =
            NonZeroI32::new(-1234).expect("-1234 is a non-zero value");
        let mut bytes = vec![0, 0, 0, 0];
        NativeEndian::write_i32(&mut bytes, error_code.get());
        let msg = ErrorBuffer::new_checked(&bytes)
            .and_then(|buf| ErrorMessage::parse(&buf))
            .expect("failed to parse NLMSG_ERROR");
        assert_eq!(
            ErrorMessage {
                code: Some(error_code),
                header: Vec::new()
            },
            msg
        );
        assert_eq!(msg.raw_code(), error_code.get());
    }

    /// A buffer too short to hold the error code is rejected.
    #[test]
    fn new_checked_rejects_short_buffer() {
        let bytes: Vec<u8> = vec![0, 0, 0];
        assert!(ErrorBuffer::new_checked(&bytes).is_err());
    }

    /// Emitting into an exactly-sized buffer and parsing it back yields the
    /// original message.
    #[test]
    fn emit_then_parse_round_trips() {
        let msg = ErrorMessage {
            code: NonZeroI32::new(-95),
            header: vec![1, 2, 3, 4],
        };
        let mut bytes = vec![0u8; msg.buffer_len()];
        msg.emit(&mut bytes);
        let parsed = ErrorBuffer::new_checked(&bytes)
            .and_then(|buf| ErrorMessage::parse(&buf))
            .expect("failed to parse NLMSG_ERROR");
        assert_eq!(msg, parsed);
    }
}