netlink_proto/framed.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
// SPDX-License-Identifier: MIT
use bytes::BytesMut;
use std::{
fmt::Debug,
io,
marker::PhantomData,
mem::size_of,
pin::Pin,
task::{Context, Poll},
};
use futures::{Sink, Stream};
use log::error;
use crate::{
codecs::NetlinkMessageCodec,
sys::{AsyncSocket, SocketAddr},
};
use netlink_packet_core::{
NetlinkDeserializable, NetlinkHeader, NetlinkMessage, NetlinkPayload,
NetlinkSerializable, NLMSG_OVERRUN,
};
/// Buffer overrun condition.
///
/// Raw OS error number that `poll_next` compares against
/// `io::Error::raw_os_error()` to tell a full receive buffer apart from any
/// other socket read failure.
///
/// NOTE(review): the number is hard-coded; 105 is presumably the Linux
/// `ENOBUFS` value on the common architectures — confirm for every target
/// this crate is built for.
const ENOBUFS: i32 = 105;
/// Adapter that exposes a netlink socket `S` as a `Stream` of decoded
/// `NetlinkMessage<T>` values and a `Sink` of messages to send.
///
/// The codec `C` is never instantiated: it is only used through its
/// associated `decode` / `encode` functions, which is why it is carried as a
/// `PhantomData` marker. Every item flowing through the stream or the sink is
/// paired with a `SocketAddr`.
pub struct NetlinkFramed<T, S, C> {
/// Underlying socket that datagrams are received from and sent to.
socket: S,
// see https://doc.rust-lang.org/nomicon/phantom-data.html
// "invariant" seems like the safe choice; using `fn(T) -> T`
// should make it invariant but still Send+Sync.
msg_type: PhantomData<fn(T) -> T>, // invariant
codec: PhantomData<fn(C) -> C>, // invariant
/// Receive buffer: socket reads land here and the codec decodes from it.
reader: BytesMut,
/// Send buffer: holds the encoded frame until it is flushed to the socket.
writer: BytesMut,
/// Address returned by the most recent successful socket read; it is
/// yielded alongside every message decoded from `reader`.
in_addr: SocketAddr,
/// Destination of the frame currently held in `writer`.
out_addr: SocketAddr,
/// `false` from `start_send` until `poll_flush` has handed the content of
/// `writer` to the socket; `true` otherwise (including right after `new`).
flushed: bool,
}
impl<T, S, C> Stream for NetlinkFramed<T, S, C>
where
T: NetlinkDeserializable + Debug,
S: AsyncSocket,
C: NetlinkMessageCodec,
{
type Item = (NetlinkMessage<T>, SocketAddr);
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let Self {
ref mut socket,
ref mut in_addr,
ref mut reader,
..
} = Pin::get_mut(self);
loop {
match C::decode::<T>(reader) {
Ok(Some(item)) => return Poll::Ready(Some((item, *in_addr))),
Ok(None) => {}
Err(e) => {
error!("unrecoverable error in decoder: {:?}", e);
return Poll::Ready(None);
}
}
reader.clear();
reader.reserve(INITIAL_READER_CAPACITY);
*in_addr = match ready!(socket.poll_recv_from(cx, reader)) {
Ok(addr) => addr,
// When receiving messages in multicast mode (i.e. we subscribed
// to notifications), the kernel will not wait
// for us to read datagrams before sending more.
// The receive buffer has a finite size, so once it is full (no
// more message can fit in), new messages will be dropped and
// recv calls will return `ENOBUFS`.
// This needs to be handled for applications to resynchronize
// with the contents of the kernel if necessary.
// We don't need to do anything special:
// - contents of the reader is still valid because we won't have
// partial messages in there anyways (large enough buffer)
// - contents of the socket's internal buffer is still valid
// because the kernel won't put partial data in it
Err(e) if e.raw_os_error() == Some(ENOBUFS) => {
warn!("netlink socket buffer full");
let mut hdr = NetlinkHeader::default();
hdr.length = size_of::<NetlinkHeader>() as u32;
hdr.message_type = NLMSG_OVERRUN;
let msg = NetlinkMessage::new(
hdr,
NetlinkPayload::Overrun(Vec::new()),
);
return Poll::Ready(Some((msg, SocketAddr::new(0, 0))));
}
Err(e) => {
error!("failed to read from netlink socket: {:?}", e);
return Poll::Ready(None);
}
};
}
}
}
/// Sending half of [`NetlinkFramed`]: accepts one `(message, destination)`
/// pair at a time, encodes it and writes it to the socket as one datagram.
impl<T, S, C> Sink<(NetlinkMessage<T>, SocketAddr)> for NetlinkFramed<T, S, C>
where
    T: NetlinkSerializable + Debug,
    S: AsyncSocket,
    C: NetlinkMessageCodec,
{
    type Error = io::Error;

    /// The sink is ready as soon as no encoded frame is waiting to be
    /// written; otherwise the pending frame is flushed first.
    fn poll_ready(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        if self.flushed {
            Poll::Ready(Ok(()))
        } else {
            self.poll_flush(cx)
        }
    }

    /// Encode `item.0` into the write buffer and remember `item.1` as the
    /// destination for the next flush.
    ///
    /// # Errors
    ///
    /// Returns whatever error the codec reports while encoding.
    fn start_send(
        self: Pin<&mut Self>,
        item: (NetlinkMessage<T>, SocketAddr),
    ) -> Result<(), Self::Error> {
        trace!("sending frame");
        let this = Pin::get_mut(self);
        let (message, destination) = item;

        C::encode(message, &mut this.writer)?;
        this.out_addr = destination;
        this.flushed = false;

        trace!("frame encoded; length={}", this.writer.len());
        Ok(())
    }

    /// Send the content of the write buffer to the stored destination.
    ///
    /// Returns immediately when there is nothing to flush. After the socket
    /// accepted the write, the buffer is emptied and the sink is marked as
    /// flushed even if fewer bytes than expected were written.
    ///
    /// # Errors
    ///
    /// Propagates the socket error (the buffered frame is kept in that
    /// case), or reports an `io::ErrorKind::Other` error when the socket
    /// accepted only part of the datagram.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        if self.flushed {
            return Poll::Ready(Ok(()));
        }

        let this = &mut *self;
        let pending = this.writer.len();
        trace!("flushing frame; length={}", pending);

        let sent = ready!(this.socket.poll_send_to(
            cx,
            &this.writer,
            &this.out_addr
        ))?;
        trace!("written {}", sent);

        this.writer.clear();
        this.flushed = true;

        if sent == pending {
            Poll::Ready(Ok(()))
        } else {
            Poll::Ready(Err(io::Error::new(
                io::ErrorKind::Other,
                "failed to write entire datagram to socket",
            )))
        }
    }

    /// Closing performs no work beyond flushing the pending frame.
    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        self.poll_flush(cx)
    }
}
// The theoretical max netlink packet size is 32KB for a netlink
// message since Linux 4.9 (16KB before). See:
// https://git.kernel.org/pub/scm/linux/kernel/git/davem/net-next.git/commit/?id=d35c99ff77ecb2eb239731b799386f3b3637a31e
//
// The read buffer is sized at twice that figure. This value is both the
// capacity allocated when the framed socket is created and the amount of
// room reserved before every socket read.
const INITIAL_READER_CAPACITY: usize = 64 * 1024;
// Capacity allocated up front for the buffer that holds the encoded
// outgoing frame.
const INITIAL_WRITER_CAPACITY: usize = 8 * 1024;
impl<T, S, C> NetlinkFramed<T, S, C> {
    /// Wrap `socket` in a new `NetlinkFramed`.
    ///
    /// The message type `T` and the codec `C` are selected through the type
    /// parameters only; no codec value is passed in. The read and write
    /// buffers are pre-allocated with `INITIAL_READER_CAPACITY` and
    /// `INITIAL_WRITER_CAPACITY` bytes, both addresses start out as
    /// `SocketAddr::new(0, 0)`, and the sink starts in the flushed state.
    pub fn new(socket: S) -> Self {
        Self {
            socket,
            // Type-level markers, no runtime data.
            msg_type: PhantomData,
            codec: PhantomData,
            // Receiving side.
            reader: BytesMut::with_capacity(INITIAL_READER_CAPACITY),
            in_addr: SocketAddr::new(0, 0),
            // Sending side.
            writer: BytesMut::with_capacity(INITIAL_WRITER_CAPACITY),
            out_addr: SocketAddr::new(0, 0),
            flushed: true,
        }
    }

    /// Borrow the socket wrapped by this `NetlinkFramed`.
    ///
    /// # Note
    ///
    /// Avoid consuming incoming data through the returned reference: doing
    /// so may corrupt the stream of frames this wrapper is decoding.
    pub fn get_ref(&self) -> &S {
        let Self { socket, .. } = self;
        socket
    }

    /// Mutably borrow the socket wrapped by this `NetlinkFramed`.
    ///
    /// # Note
    ///
    /// Avoid consuming incoming data through the returned reference: doing
    /// so may corrupt the stream of frames this wrapper is decoding.
    pub fn get_mut(&mut self) -> &mut S {
        let Self { socket, .. } = self;
        socket
    }

    /// Consume the `NetlinkFramed` and return the wrapped socket. The read
    /// and write buffers are dropped along with the wrapper.
    pub fn into_inner(self) -> S {
        let Self { socket, .. } = self;
        socket
    }
}