netlink_packet_core/
message.rs

1// SPDX-License-Identifier: MIT
2
3use std::fmt::Debug;
4
5use anyhow::Context;
6use netlink_packet_utils::DecodeError;
7
8use crate::{
9    payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
10    DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
11    NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload,
12    NetlinkSerializable, Parseable,
13};
14
15/// Represent a netlink message.
16#[derive(Debug, PartialEq, Eq, Clone)]
17#[non_exhaustive]
18pub struct NetlinkMessage<I> {
19    /// Message header (this is common to all the netlink protocols)
20    pub header: NetlinkHeader,
21    /// Inner message, which depends on the netlink protocol being used.
22    pub payload: NetlinkPayload<I>,
23}
24
25impl<I> NetlinkMessage<I> {
26    /// Create a new netlink message from the given header and payload
27    pub fn new(header: NetlinkHeader, payload: NetlinkPayload<I>) -> Self {
28        NetlinkMessage { header, payload }
29    }
30
31    /// Consume this message and return its header and payload
32    pub fn into_parts(self) -> (NetlinkHeader, NetlinkPayload<I>) {
33        (self.header, self.payload)
34    }
35}
36
37impl<I> NetlinkMessage<I>
38where
39    I: NetlinkDeserializable,
40{
41    /// Parse the given buffer as a netlink message
42    pub fn deserialize(buffer: &[u8]) -> Result<Self, DecodeError> {
43        let netlink_buffer = NetlinkBuffer::new_checked(&buffer)?;
44        <Self as Parseable<NetlinkBuffer<&&[u8]>>>::parse(&netlink_buffer)
45    }
46}
47
48impl<I> NetlinkMessage<I>
49where
50    I: NetlinkSerializable,
51{
52    /// Return the length of this message in bytes
53    pub fn buffer_len(&self) -> usize {
54        <Self as Emitable>::buffer_len(self)
55    }
56
57    /// Serialize this message and write the serialized data into the
58    /// given buffer. `buffer` must big large enough for the whole
59    /// message to fit, otherwise, this method will panic. To know how
60    /// big the serialized message is, call `buffer_len()`.
61    ///
62    /// # Panic
63    ///
64    /// This method panics if the buffer is not big enough.
65    pub fn serialize(&self, buffer: &mut [u8]) {
66        self.emit(buffer)
67    }
68
69    /// Ensure the header (`NetlinkHeader`) is consistent with the payload
70    /// (`NetlinkPayload`):
71    ///
72    /// - compute the payload length and set the header's length field
73    /// - check the payload type and set the header's message type field
74    ///   accordingly
75    ///
76    /// If you are not 100% sure the header is correct, this method should be
77    /// called before calling [`Emitable::emit()`](trait.Emitable.html#
78    /// tymethod.emit), as it could panic if the header is inconsistent with
79    /// the rest of the message.
80    pub fn finalize(&mut self) {
81        self.header.length = self.buffer_len() as u32;
82        self.header.message_type = self.payload.message_type();
83    }
84}
85
86impl<'buffer, B, I> Parseable<NetlinkBuffer<&'buffer B>> for NetlinkMessage<I>
87where
88    B: AsRef<[u8]> + 'buffer,
89    I: NetlinkDeserializable,
90{
91    fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result<Self, DecodeError> {
92        use self::NetlinkPayload::*;
93
94        let header =
95            <NetlinkHeader as Parseable<NetlinkBuffer<&'buffer B>>>::parse(buf)
96                .context("failed to parse netlink header")?;
97
98        let bytes = buf.payload();
99        let payload = match header.message_type {
100            NLMSG_ERROR => {
101                let msg = ErrorBuffer::new_checked(&bytes)
102                    .and_then(|buf| ErrorMessage::parse(&buf))
103                    .context("failed to parse NLMSG_ERROR")?;
104                Error(msg)
105            }
106            NLMSG_NOOP => Noop,
107            NLMSG_DONE => {
108                let msg = DoneBuffer::new_checked(&bytes)
109                    .and_then(|buf| DoneMessage::parse(&buf))
110                    .context("failed to parse NLMSG_DONE")?;
111                Done(msg)
112            }
113            NLMSG_OVERRUN => Overrun(bytes.to_vec()),
114            message_type => {
115                let inner_msg = I::deserialize(&header, bytes).context(
116                    format!("Failed to parse message with type {message_type}"),
117                )?;
118                InnerMessage(inner_msg)
119            }
120        };
121        Ok(NetlinkMessage { header, payload })
122    }
123}
124
125impl<I> Emitable for NetlinkMessage<I>
126where
127    I: NetlinkSerializable,
128{
129    fn buffer_len(&self) -> usize {
130        use self::NetlinkPayload::*;
131
132        let payload_len = match self.payload {
133            Noop => 0,
134            Done(ref msg) => msg.buffer_len(),
135            Overrun(ref bytes) => bytes.len(),
136            Error(ref msg) => msg.buffer_len(),
137            InnerMessage(ref msg) => msg.buffer_len(),
138        };
139
140        self.header.buffer_len() + payload_len
141    }
142
143    fn emit(&self, buffer: &mut [u8]) {
144        use self::NetlinkPayload::*;
145
146        self.header.emit(buffer);
147
148        let buffer =
149            &mut buffer[self.header.buffer_len()..self.header.length as usize];
150        match self.payload {
151            Noop => {}
152            Done(ref msg) => msg.emit(buffer),
153            Overrun(ref bytes) => buffer.copy_from_slice(bytes),
154            Error(ref msg) => msg.emit(buffer),
155            InnerMessage(ref msg) => msg.serialize(buffer),
156        }
157    }
158}
159
160impl<T> From<T> for NetlinkMessage<T>
161where
162    T: Into<NetlinkPayload<T>>,
163{
164    fn from(inner_message: T) -> Self {
165        NetlinkMessage {
166            header: NetlinkHeader::default(),
167            payload: inner_message.into(),
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    //! Round-trip tests: emit a finalized message, inspect the raw payload
    //! bytes, then parse it back and compare with the original.

    use super::*;

    use std::{convert::Infallible, mem::size_of, num::NonZeroI32};

    /// Placeholder inner message type; the tests only exercise the
    /// `Done`/`Error` payloads, so its methods are never called.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FakeNetlinkInnerMessage;

    impl NetlinkSerializable for FakeNetlinkInnerMessage {
        fn message_type(&self) -> u16 {
            unimplemented!("unused by tests")
        }

        fn buffer_len(&self) -> usize {
            unimplemented!("unused by tests")
        }

        fn serialize(&self, _buffer: &mut [u8]) {
            unimplemented!("unused by tests")
        }
    }

    impl NetlinkDeserializable for FakeNetlinkInnerMessage {
        type Error = Infallible;

        fn deserialize(
            _header: &NetlinkHeader,
            _payload: &[u8],
        ) -> Result<Self, Self::Error> {
            unimplemented!("unused by tests")
        }
    }

    #[test]
    fn test_done() {
        let header = NetlinkHeader::default();
        let done_msg = DoneMessage {
            code: 0,
            extended_ack: vec![6, 7, 8, 9],
        };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(
            len,
            header.buffer_len()
                + size_of::<i32>()
                + done_msg.extended_ack.len()
        );

        // Pre-fill with a non-zero byte so fields that are not written
        // would be noticed.
        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(done_buf.code(), done_msg.code);
        assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_error() {
        // Built without `unsafe`: `NonZeroI32::new` is a `const fn`, and a
        // zero value would make this constant fail to evaluate at compile
        // time instead of being undefined behavior.
        const ERROR_CODE: NonZeroI32 = match NonZeroI32::new(-8765) {
            Some(code) => code,
            None => panic!("ERROR_CODE must be non-zero"),
        };

        let header = NetlinkHeader::default();
        let error_msg = ErrorMessage {
            code: Some(ERROR_CODE),
            header: vec![],
        };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Error(error_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + error_msg.buffer_len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(error_buf.code(), error_msg.code);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }
}