netlink_packet_core/
message.rs1use std::fmt::Debug;
4
5use anyhow::Context;
6use netlink_packet_utils::DecodeError;
7
8use crate::{
9 payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
10 DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
11 NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload,
12 NetlinkSerializable, Parseable,
13};
14
15#[derive(Debug, PartialEq, Eq, Clone)]
17#[non_exhaustive]
18pub struct NetlinkMessage<I> {
19 pub header: NetlinkHeader,
21 pub payload: NetlinkPayload<I>,
23}
24
25impl<I> NetlinkMessage<I> {
26 pub fn new(header: NetlinkHeader, payload: NetlinkPayload<I>) -> Self {
28 NetlinkMessage { header, payload }
29 }
30
31 pub fn into_parts(self) -> (NetlinkHeader, NetlinkPayload<I>) {
33 (self.header, self.payload)
34 }
35}
36
37impl<I> NetlinkMessage<I>
38where
39 I: NetlinkDeserializable,
40{
41 pub fn deserialize(buffer: &[u8]) -> Result<Self, DecodeError> {
43 let netlink_buffer = NetlinkBuffer::new_checked(&buffer)?;
44 <Self as Parseable<NetlinkBuffer<&&[u8]>>>::parse(&netlink_buffer)
45 }
46}
47
48impl<I> NetlinkMessage<I>
49where
50 I: NetlinkSerializable,
51{
52 pub fn buffer_len(&self) -> usize {
54 <Self as Emitable>::buffer_len(self)
55 }
56
57 pub fn serialize(&self, buffer: &mut [u8]) {
66 self.emit(buffer)
67 }
68
69 pub fn finalize(&mut self) {
81 self.header.length = self.buffer_len() as u32;
82 self.header.message_type = self.payload.message_type();
83 }
84}
85
86impl<'buffer, B, I> Parseable<NetlinkBuffer<&'buffer B>> for NetlinkMessage<I>
87where
88 B: AsRef<[u8]> + 'buffer,
89 I: NetlinkDeserializable,
90{
91 fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result<Self, DecodeError> {
92 use self::NetlinkPayload::*;
93
94 let header =
95 <NetlinkHeader as Parseable<NetlinkBuffer<&'buffer B>>>::parse(buf)
96 .context("failed to parse netlink header")?;
97
98 let bytes = buf.payload();
99 let payload = match header.message_type {
100 NLMSG_ERROR => {
101 let msg = ErrorBuffer::new_checked(&bytes)
102 .and_then(|buf| ErrorMessage::parse(&buf))
103 .context("failed to parse NLMSG_ERROR")?;
104 Error(msg)
105 }
106 NLMSG_NOOP => Noop,
107 NLMSG_DONE => {
108 let msg = DoneBuffer::new_checked(&bytes)
109 .and_then(|buf| DoneMessage::parse(&buf))
110 .context("failed to parse NLMSG_DONE")?;
111 Done(msg)
112 }
113 NLMSG_OVERRUN => Overrun(bytes.to_vec()),
114 message_type => {
115 let inner_msg = I::deserialize(&header, bytes).context(
116 format!("Failed to parse message with type {message_type}"),
117 )?;
118 InnerMessage(inner_msg)
119 }
120 };
121 Ok(NetlinkMessage { header, payload })
122 }
123}
124
125impl<I> Emitable for NetlinkMessage<I>
126where
127 I: NetlinkSerializable,
128{
129 fn buffer_len(&self) -> usize {
130 use self::NetlinkPayload::*;
131
132 let payload_len = match self.payload {
133 Noop => 0,
134 Done(ref msg) => msg.buffer_len(),
135 Overrun(ref bytes) => bytes.len(),
136 Error(ref msg) => msg.buffer_len(),
137 InnerMessage(ref msg) => msg.buffer_len(),
138 };
139
140 self.header.buffer_len() + payload_len
141 }
142
143 fn emit(&self, buffer: &mut [u8]) {
144 use self::NetlinkPayload::*;
145
146 self.header.emit(buffer);
147
148 let buffer =
149 &mut buffer[self.header.buffer_len()..self.header.length as usize];
150 match self.payload {
151 Noop => {}
152 Done(ref msg) => msg.emit(buffer),
153 Overrun(ref bytes) => buffer.copy_from_slice(bytes),
154 Error(ref msg) => msg.emit(buffer),
155 InnerMessage(ref msg) => msg.serialize(buffer),
156 }
157 }
158}
159
160impl<T> From<T> for NetlinkMessage<T>
161where
162 T: Into<NetlinkPayload<T>>,
163{
164 fn from(inner_message: T) -> Self {
165 NetlinkMessage {
166 header: NetlinkHeader::default(),
167 payload: inner_message.into(),
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    use std::{convert::Infallible, mem::size_of, num::NonZeroI32};

    /// Stand-in inner message type. The tests below only exercise the
    /// built-in payload variants, so none of its methods are ever called.
    #[derive(Clone, Debug, Default, PartialEq)]
    struct FakeNetlinkInnerMessage;

    impl NetlinkSerializable for FakeNetlinkInnerMessage {
        fn message_type(&self) -> u16 {
            unimplemented!("unused by tests")
        }

        fn buffer_len(&self) -> usize {
            unimplemented!("unused by tests")
        }

        fn serialize(&self, _buffer: &mut [u8]) {
            unimplemented!("unused by tests")
        }
    }

    impl NetlinkDeserializable for FakeNetlinkInnerMessage {
        type Error = Infallible;

        fn deserialize(
            _header: &NetlinkHeader,
            _payload: &[u8],
        ) -> Result<Self, Self::Error> {
            unimplemented!("unused by tests")
        }
    }

    #[test]
    fn test_done() {
        let header = NetlinkHeader::default();
        let done_msg = DoneMessage {
            code: 0,
            extended_ack: vec![6, 7, 8, 9],
        };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(
            len,
            header.buffer_len()
                + size_of::<i32>()
                + done_msg.extended_ack.len()
        );

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(done_buf.code(), done_msg.code);
        assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_error() {
        // No `unsafe` needed: `NonZeroI32::new` is a `const fn`, and a zero
        // literal here would fail const evaluation via the `panic!`.
        const ERROR_CODE: NonZeroI32 = match NonZeroI32::new(-8765) {
            Some(code) => code,
            None => panic!("ERROR_CODE must be non-zero"),
        };

        let header = NetlinkHeader::default();
        let error_msg = ErrorMessage {
            code: Some(ERROR_CODE),
            header: vec![],
        };
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Error(error_msg.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + error_msg.buffer_len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]);
        assert_eq!(error_buf.code(), error_msg.code);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_noop() {
        let header = NetlinkHeader::default();
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Noop,
        );
        want.finalize();

        // A NOOP message is just the header.
        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_overrun() {
        let header = NetlinkHeader::default();
        let raw = vec![0xde, 0xad, 0xbe, 0xef];
        let mut want = NetlinkMessage::new(
            header,
            NetlinkPayload::<FakeNetlinkInnerMessage>::Overrun(raw.clone()),
        );
        want.finalize();

        let len = want.buffer_len();
        assert_eq!(len, header.buffer_len() + raw.len());

        let mut buf = vec![1; len];
        want.emit(&mut buf);
        // The overrun payload is copied byte-for-byte after the header.
        assert_eq!(&buf[header.buffer_len()..], raw.as_slice());

        let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
        assert_eq!(got, want);
    }
}