use std::time::Duration;
use anyhow::{bail, ensure};
use bytes::{Buf, BufMut, Bytes, BytesMut};
#[cfg(feature = "iroh-relay")]
use futures_lite::{Stream, StreamExt};
use futures_sink::Sink;
use futures_util::SinkExt;
use iroh_base::key::{Signature, PUBLIC_KEY_LENGTH};
use postcard::experimental::max_size::MaxSize;
use serde::{Deserialize, Serialize};
use tokio_util::codec::{Decoder, Encoder};
use crate::key::{PublicKey, SecretKey};
/// Largest packet payload accepted in a `SendPacket`/`RecvPacket` frame.
///
/// `Frame::from_bytes` rejects frames whose payload (the bytes following the
/// public key) is longer than this.
pub const MAX_PACKET_SIZE: usize = 64 * 1024;
/// Largest frame body, in bytes, that `DerpCodec` will decode or encode.
const MAX_FRAME_SIZE: usize = 1024 * 1024;
/// Magic prefix written at the start of a `Frame::ClientInfo` body and
/// verified again when such a frame is parsed.
const MAGIC: &str = "RELAY🔑";
// NOTE(review): not referenced in this file; presumably the interval between
// keep-alive frames — confirm at the use site.
#[cfg(feature = "iroh-relay")]
#[cfg_attr(iroh_docsrs, doc(cfg(feature = "iroh-relay")))]
pub(super) const KEEP_ALIVE: Duration = Duration::from_secs(60);
// NOTE(review): not referenced in this file; presumably the capacity of a
// server-side channel — confirm at the use site.
#[cfg(feature = "iroh-relay")]
#[cfg_attr(iroh_docsrs, doc(cfg(feature = "iroh-relay")))]
pub(super) const SERVER_CHANNEL_SIZE: usize = 1024 * 100;
// Per-client queue depths for the send and read directions.
// NOTE(review): neither constant is used in this file; presumably they bound
// the number of buffered packets per client — confirm at the use site.
pub(super) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512; pub(super) const PER_CLIENT_READ_QUEUE_DEPTH: usize = 512;
/// Protocol version number; the tests in this file place it in
/// `ClientInfo::version`.
pub(super) const PROTOCOL_VERSION: usize = 3;
/// Wire byte of a `NotePreferred` frame whose flag is `true`.
const PREFERRED: u8 = 1u8;
/// Wire byte of a `NotePreferred` frame whose flag is `false`.
const NOT_PREFERRED: u8 = 0u8;
/// One-byte tag that identifies the kind of a relay frame on the wire.
///
/// The explicit discriminants are the wire values; the `num_enum` derives
/// provide the conversions to and from `u8`.
#[derive(Debug, PartialEq, Eq, num_enum::IntoPrimitive, num_enum::FromPrimitive, Clone, Copy)]
#[repr(u8)]
pub(crate) enum FrameType {
/// Tag of `Frame::ClientInfo`.
ClientInfo = 2,
/// Tag of `Frame::SendPacket`.
SendPacket = 4,
/// Tag of `Frame::RecvPacket`.
RecvPacket = 5,
/// Tag of `Frame::KeepAlive`.
KeepAlive = 6,
/// Tag of `Frame::NotePreferred`.
NotePreferred = 7,
/// Tag of `Frame::PeerGone`.
PeerGone = 8,
/// Tag of `Frame::Ping`.
Ping = 12,
/// Tag of `Frame::Pong`.
Pong = 13,
/// Tag of `Frame::Health`.
Health = 14,
/// Tag of `Frame::Restarting`.
Restarting = 15,
/// Fallback selected by `#[num_enum(default)]` when a byte matches no other
/// variant; `Frame::from_bytes` rejects it with an error.
#[num_enum(default)]
Unknown = 255,
}
impl std::fmt::Display for FrameType {
    /// Renders the frame type with its `Debug` representation (the variant name).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
/// Handshake payload a client sends about itself.
///
/// `send_client_key` serializes it with `postcard` and signs the resulting
/// bytes; `recv_client_key` verifies the signature and deserializes it again.
#[derive(Debug, Serialize, Deserialize, MaxSize, PartialEq, Eq)]
pub(crate) struct ClientInfo {
/// Protocol version announced by the client (the tests in this file use
/// `PROTOCOL_VERSION`).
pub(crate) version: usize,
}
/// Sends `frame` through `writer`, optionally bounded by `timeout`.
///
/// With `Some(duration)` the send is wrapped in `tokio::time::timeout`, so an
/// elapsed timer and a sink error are both reported as `Err`. With `None` the
/// call simply waits for the send to finish.
pub(super) async fn write_frame<S: Sink<Frame, Error = std::io::Error> + Unpin>(
    mut writer: S,
    frame: Frame,
    timeout: Option<Duration>,
) -> anyhow::Result<()> {
    match timeout {
        // Outer `?` handles the elapsed timer, inner `?` the sink error.
        Some(limit) => tokio::time::timeout(limit, writer.send(frame)).await??,
        None => writer.send(frame).await?,
    }
    Ok(())
}
/// Sends the client's identity as a `Frame::ClientInfo` and flushes the sink.
///
/// `client_info` is serialized with `postcard`, the serialized bytes are signed
/// with `client_secret_key`, and the frame carries the matching public key,
/// the serialized bytes and the signature.
///
/// # Errors
///
/// Fails if serialization fails or if the sink reports an error while sending
/// or flushing.
pub(crate) async fn send_client_key<S: Sink<Frame, Error = std::io::Error> + Unpin>(
    mut writer: S,
    client_secret_key: &SecretKey,
    client_info: &ClientInfo,
) -> anyhow::Result<()> {
    let payload = postcard::to_stdvec(client_info)?;
    // The signature borrows `payload`, so it is computed before the bytes are
    // moved into the frame.
    let frame = Frame::ClientInfo {
        client_public_key: client_secret_key.public(),
        signature: client_secret_key.sign(&payload),
        message: payload.into(),
    };
    writer.send(frame).await?;
    writer.flush().await?;
    Ok(())
}
/// Reads the client's `Frame::ClientInfo` from `stream` and validates it.
///
/// Waits at most 10 seconds for the frame, checks that the signature over the
/// message verifies against the public key carried in the frame, and then
/// deserializes the message into a [`ClientInfo`].
///
/// # Errors
///
/// Fails on timeout, on a stream error or unexpected frame type, on an invalid
/// signature, or when the message cannot be deserialized.
#[cfg(feature = "iroh-relay")]
#[cfg_attr(iroh_docsrs, doc(cfg(feature = "iroh-relay")))]
pub(super) async fn recv_client_key<S: Stream<Item = anyhow::Result<Frame>> + Unpin>(
    stream: S,
) -> anyhow::Result<(PublicKey, ClientInfo)> {
    use anyhow::Context;

    let pending = recv_frame(FrameType::ClientInfo, stream);
    let frame = tokio::time::timeout(Duration::from_secs(10), pending)
        .await
        .context("recv_frame timeout")?
        .context("recv_frame")?;

    let Frame::ClientInfo {
        client_public_key,
        message,
        signature,
    } = frame
    else {
        anyhow::bail!("expected FrameType::ClientInfo");
    };

    // Only trust the message once the signature checks out.
    client_public_key
        .verify(&message, &signature)
        .context("invalid signature")?;
    let info: ClientInfo = postcard::from_bytes(&message).context("deserialization")?;
    Ok((client_public_key, info))
}
/// Stateless codec for relay frames on a byte stream.
///
/// Its `Decoder`/`Encoder` implementations frame each `Frame` as a 1-byte
/// frame type, a 4-byte big-endian body length, and the body itself.
#[derive(Debug, Default, Clone)]
pub(crate) struct DerpCodec;
/// A decoded relay protocol frame.
///
/// Each variant corresponds to the `FrameType` of the same name; the body
/// layouts described below are the ones produced by `Frame::write_to` and
/// parsed by `Frame::from_bytes`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Frame {
/// Client identity: magic bytes, public key, signature, then the message.
ClientInfo {
/// Public key the signature is verified against.
client_public_key: PublicKey,
/// Signed bytes; `send_client_key` puts a postcard-encoded `ClientInfo` here.
message: Bytes,
/// Signature over `message`.
signature: Signature,
},
/// A packet together with a destination key: the key followed by the payload.
SendPacket {
/// Key of the destination.
dst_key: PublicKey,
/// Packet payload (at most `MAX_PACKET_SIZE` bytes when parsed).
packet: Bytes,
},
/// A packet together with a source key: the key followed by the payload.
RecvPacket {
/// Key of the source.
src_key: PublicKey,
/// Packet payload (at most `MAX_PACKET_SIZE` bytes when parsed).
content: Bytes,
},
/// Frame with an empty body.
KeepAlive,
/// Single-byte flag frame.
NotePreferred {
/// Encoded as `PREFERRED` (`true`) or `NOT_PREFERRED` (`false`).
preferred: bool,
},
/// Frame whose body is exactly one public key.
PeerGone {
/// The key carried in the frame.
peer: PublicKey,
},
/// Frame carrying 8 opaque bytes.
Ping {
/// The 8 payload bytes.
data: [u8; 8],
},
/// Frame carrying 8 opaque bytes.
// NOTE(review): presumably echoes the `data` of the `Ping` it answers — confirm.
Pong {
/// The 8 payload bytes.
data: [u8; 8],
},
/// Frame with an arbitrary-length body (an empty body is accepted).
Health {
/// Raw body bytes.
// NOTE(review): presumably a human-readable problem description — confirm.
problem: Bytes,
},
/// Frame carrying two big-endian `u32` values.
// NOTE(review): units are not visible in this file (presumably milliseconds) — confirm.
Restarting {
/// First `u32` of the body.
reconnect_in: u32,
/// Second `u32` of the body.
try_for: u32,
},
}
impl Frame {
pub(super) fn typ(&self) -> FrameType {
match self {
Frame::ClientInfo { .. } => FrameType::ClientInfo,
Frame::SendPacket { .. } => FrameType::SendPacket,
Frame::RecvPacket { .. } => FrameType::RecvPacket,
Frame::KeepAlive => FrameType::KeepAlive,
Frame::NotePreferred { .. } => FrameType::NotePreferred,
Frame::PeerGone { .. } => FrameType::PeerGone,
Frame::Ping { .. } => FrameType::Ping,
Frame::Pong { .. } => FrameType::Pong,
Frame::Health { .. } => FrameType::Health,
Frame::Restarting { .. } => FrameType::Restarting,
}
}
pub(super) fn len(&self) -> usize {
match self {
Frame::ClientInfo {
client_public_key: _,
message,
signature: _,
} => MAGIC.as_bytes().len() + PUBLIC_KEY_LENGTH + message.len() + Signature::BYTE_SIZE,
Frame::SendPacket { dst_key: _, packet } => PUBLIC_KEY_LENGTH + packet.len(),
Frame::RecvPacket {
src_key: _,
content,
} => PUBLIC_KEY_LENGTH + content.len(),
Frame::KeepAlive => 0,
Frame::NotePreferred { .. } => 1,
Frame::PeerGone { .. } => PUBLIC_KEY_LENGTH,
Frame::Ping { .. } => 8,
Frame::Pong { .. } => 8,
Frame::Health { problem } => problem.len(),
Frame::Restarting { .. } => 4 + 4,
}
}
pub(crate) fn decode_from_ws_msg(vec: Vec<u8>) -> anyhow::Result<Self> {
if vec.is_empty() {
bail!("error parsing relay::codec::Frame: too few bytes (0)");
}
let bytes = Bytes::from(vec);
let typ = FrameType::from(bytes[0]);
let frame = Self::from_bytes(typ, bytes.slice(1..))?;
Ok(frame)
}
pub(crate) fn encode_for_ws_msg(self) -> Vec<u8> {
let mut bytes = Vec::new();
bytes.put_u8(self.typ().into());
self.write_to(&mut bytes);
bytes
}
fn write_to(&self, dst: &mut impl BufMut) {
match self {
Frame::ClientInfo {
client_public_key,
message,
signature,
} => {
dst.put(MAGIC.as_bytes());
dst.put(client_public_key.as_ref());
dst.put(&signature.to_bytes()[..]);
dst.put(&message[..]);
}
Frame::SendPacket { dst_key, packet } => {
dst.put(dst_key.as_ref());
dst.put(packet.as_ref());
}
Frame::RecvPacket { src_key, content } => {
dst.put(src_key.as_ref());
dst.put(content.as_ref());
}
Frame::KeepAlive => {}
Frame::NotePreferred { preferred } => {
if *preferred {
dst.put_u8(PREFERRED);
} else {
dst.put_u8(NOT_PREFERRED);
}
}
Frame::PeerGone { peer } => {
dst.put(peer.as_ref());
}
Frame::Ping { data } => {
dst.put(&data[..]);
}
Frame::Pong { data } => {
dst.put(&data[..]);
}
Frame::Health { problem } => {
dst.put(problem.as_ref());
}
Frame::Restarting {
reconnect_in,
try_for,
} => {
dst.put_u32(*reconnect_in);
dst.put_u32(*try_for);
}
}
}
fn from_bytes(frame_type: FrameType, content: Bytes) -> anyhow::Result<Self> {
let res = match frame_type {
FrameType::ClientInfo => {
ensure!(
content.len()
>= PUBLIC_KEY_LENGTH + Signature::BYTE_SIZE + MAGIC.as_bytes().len(),
"invalid client info frame length: {}",
content.len()
);
ensure!(
&content[..MAGIC.as_bytes().len()] == MAGIC.as_bytes(),
"invalid client info frame magic"
);
let start = MAGIC.as_bytes().len();
let client_public_key =
PublicKey::try_from(&content[start..start + PUBLIC_KEY_LENGTH])?;
let start = start + PUBLIC_KEY_LENGTH;
let signature =
Signature::from_slice(&content[start..start + Signature::BYTE_SIZE])?;
let start = start + Signature::BYTE_SIZE;
let message = content.slice(start..);
Self::ClientInfo {
client_public_key,
message,
signature,
}
}
FrameType::SendPacket => {
ensure!(
content.len() >= PUBLIC_KEY_LENGTH,
"invalid send packet frame length: {}",
content.len()
);
let packet_len = content.len() - PUBLIC_KEY_LENGTH;
ensure!(
packet_len <= MAX_PACKET_SIZE,
"data packet longer ({packet_len}) than max of {MAX_PACKET_SIZE}"
);
let dst_key = PublicKey::try_from(&content[..PUBLIC_KEY_LENGTH])?;
let packet = content.slice(PUBLIC_KEY_LENGTH..);
Self::SendPacket { dst_key, packet }
}
FrameType::RecvPacket => {
ensure!(
content.len() >= PUBLIC_KEY_LENGTH,
"invalid recv packet frame length: {}",
content.len()
);
let packet_len = content.len() - PUBLIC_KEY_LENGTH;
ensure!(
packet_len <= MAX_PACKET_SIZE,
"data packet longer ({packet_len}) than max of {MAX_PACKET_SIZE}"
);
let src_key = PublicKey::try_from(&content[..PUBLIC_KEY_LENGTH])?;
let content = content.slice(PUBLIC_KEY_LENGTH..);
Self::RecvPacket { src_key, content }
}
FrameType::KeepAlive => {
anyhow::ensure!(content.is_empty(), "invalid keep alive frame length");
Self::KeepAlive
}
FrameType::NotePreferred => {
anyhow::ensure!(content.len() == 1, "invalid note preferred frame length");
let preferred = match content[0] {
PREFERRED => true,
NOT_PREFERRED => false,
_ => anyhow::bail!("invalid note preferred frame content"),
};
Self::NotePreferred { preferred }
}
FrameType::PeerGone => {
anyhow::ensure!(
content.len() == PUBLIC_KEY_LENGTH,
"invalid peer gone frame length"
);
let peer = PublicKey::try_from(&content[..32])?;
Self::PeerGone { peer }
}
FrameType::Ping => {
anyhow::ensure!(content.len() == 8, "invalid ping frame length");
let mut data = [0u8; 8];
data.copy_from_slice(&content[..8]);
Self::Ping { data }
}
FrameType::Pong => {
anyhow::ensure!(content.len() == 8, "invalid pong frame length");
let mut data = [0u8; 8];
data.copy_from_slice(&content[..8]);
Self::Pong { data }
}
FrameType::Health => Self::Health { problem: content },
FrameType::Restarting => {
ensure!(
content.len() == 4 + 4,
"invalid restarting frame length: {}",
content.len()
);
let reconnect_in = u32::from_be_bytes(content[..4].try_into()?);
let try_for = u32::from_be_bytes(content[4..].try_into()?);
Self::Restarting {
reconnect_in,
try_for,
}
}
_ => {
anyhow::bail!("invalid frame type: {:?}", frame_type);
}
};
Ok(res)
}
}
/// Size in bytes of the header `DerpCodec` puts in front of every frame body:
/// 1 frame type byte followed by a 4-byte big-endian body length.
const HEADER_LEN: usize = 5;
impl Decoder for DerpCodec {
type Item = Frame;
type Error = anyhow::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < HEADER_LEN {
return Ok(None);
}
let Some(frame_type) = src.first().map(|b| FrameType::from(*b)) else {
return Ok(None); };
let Some(frame_len) = src
.get(1..5)
.and_then(|s| TryInto::<[u8; 4]>::try_into(s).ok())
.map(u32::from_be_bytes)
.map(|l| l as usize)
else {
return Ok(None); };
if frame_len > MAX_FRAME_SIZE {
anyhow::bail!("Frame of length {} is too large.", frame_len);
}
if src.len() < HEADER_LEN + frame_len {
src.reserve(HEADER_LEN + frame_len - src.len());
return Ok(None);
}
src.advance(HEADER_LEN);
let content = src.split_to(frame_len).freeze();
let frame = Frame::from_bytes(frame_type, content)?;
Ok(Some(frame))
}
}
impl Encoder<Frame> for DerpCodec {
    type Error = std::io::Error;

    /// Appends `frame` to `dst` as type byte, 4-byte big-endian body length,
    /// and body.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` I/O error when the body would exceed
    /// `MAX_FRAME_SIZE`; nothing is written in that case.
    fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let body_len = frame.len();
        if body_len > MAX_FRAME_SIZE {
            let reason = format!("Frame of length {} is too large.", body_len);
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                reason,
            ));
        }
        // Cannot fail: `body_len` was just bounded by `MAX_FRAME_SIZE`.
        let wire_len = u32::try_from(body_len).expect("just checked");

        dst.reserve(HEADER_LEN + body_len);
        dst.put_u8(frame.typ().into());
        dst.put_u32(wire_len);
        frame.write_to(dst);
        Ok(())
    }
}
/// Reads the next frame from `stream` and checks that it has type `frame_type`.
///
/// # Errors
///
/// Fails when the stream has ended, when the stream yields an error, or when
/// the received frame is of a different type.
#[cfg(feature = "iroh-relay")]
#[cfg_attr(iroh_docsrs, doc(cfg(feature = "iroh-relay")))]
pub(super) async fn recv_frame<S: Stream<Item = anyhow::Result<Frame>> + Unpin>(
    frame_type: FrameType,
    mut stream: S,
) -> anyhow::Result<Frame> {
    let Some(received) = stream.next().await else {
        bail!("EOF: unexpected stream end, expected frame {}", frame_type);
    };
    let frame = received?;
    let found = frame.typ();
    ensure!(
        frame_type == found,
        "expected frame {}, found {}",
        frame_type,
        found
    );
    Ok(frame)
}
// Unit tests for the codec: stream round trips over an in-memory duplex pipe
// and byte-level snapshots of the websocket encoding.
#[cfg(test)]
mod tests {
use tokio_util::codec::{FramedRead, FramedWrite};
use super::*;
// Writes one `Health` frame through `DerpCodec` into a duplex pipe and reads
// it back with `recv_frame`, expecting an identical frame.
#[tokio::test]
async fn test_basic_read_write() -> anyhow::Result<()> {
let (reader, writer) = tokio::io::duplex(1024);
let mut reader = FramedRead::new(reader, DerpCodec);
let mut writer = FramedWrite::new(writer, DerpCodec);
let expect_buf = b"hello world!";
let expected_frame = Frame::Health {
problem: expect_buf.to_vec().into(),
};
write_frame(&mut writer, expected_frame.clone(), None).await?;
writer.flush().await?;
println!("{:?}", reader);
let buf = recv_frame(FrameType::Health, &mut reader).await?;
// For a `Health` frame, `len()` is the length of the problem bytes.
assert_eq!(expect_buf.len(), buf.len());
assert_eq!(expected_frame, buf);
Ok(())
}
// Runs the client-info handshake over a duplex pipe: `send_client_key` on one
// end, `recv_client_key` on the other, then compares key and info.
#[tokio::test]
async fn test_send_recv_client_key() -> anyhow::Result<()> {
let (reader, writer) = tokio::io::duplex(1024);
let mut reader = FramedRead::new(reader, DerpCodec);
let mut writer = FramedWrite::new(writer, DerpCodec);
let client_key = SecretKey::generate();
let client_info = ClientInfo {
version: PROTOCOL_VERSION,
};
println!("client_key pub {:?}", client_key.public());
send_client_key(&mut writer, &client_key, &client_info).await?;
let (client_pub_key, got_client_info) = recv_client_key(&mut reader).await?;
assert_eq!(client_key.public(), client_pub_key);
assert_eq!(client_info, got_client_info);
Ok(())
}
// Pins the exact bytes produced by `encode_for_ws_msg` for one frame of each
// kind. The expected hexdumps start with the frame type byte and contain no
// length header. A fixed secret key keeps the key and signature bytes stable.
#[test]
fn test_frame_snapshot() -> anyhow::Result<()> {
let client_key = SecretKey::from_bytes(&[42u8; 32]);
let client_info = ClientInfo {
version: PROTOCOL_VERSION,
};
let message = postcard::to_stdvec(&client_info)?;
let signature = client_key.sign(&message);
// Pairs of (frame, expected hexdump of its websocket encoding).
let frames = vec![
(
Frame::ClientInfo {
client_public_key: client_key.public(),
message: Bytes::from(message),
signature,
},
"02 52 45 4c 41 59 f0 9f 94 91 19 7f 6b 23 e1 6c
85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92
03 34 03 9b fa 8b 3d 36 8d 61 88 e7 7b 22 f2 92
ab 37 43 5d a8 de 0b c8 cb 84 e2 88 f4 e7 3b 35
82 a5 27 31 e9 ff 98 65 46 5c 87 e0 5e 8d 42 7d
f4 22 bb 6e 85 e1 c0 5f 6f 74 98 37 ba a4 a5 c7
eb a3 23 0d 77 56 99 10 43 0e 03",
),
(
Frame::Health {
problem: "Hello? Yes this is dog.".into(),
},
"0e 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73
20 69 73 20 64 6f 67 2e",
),
(Frame::KeepAlive, "06"),
(Frame::NotePreferred { preferred: true }, "07 01"),
(
Frame::PeerGone {
peer: client_key.public(),
},
"08 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e
a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d
61",
),
(
Frame::Ping { data: [42u8; 8] },
"0c 2a 2a 2a 2a 2a 2a 2a 2a",
),
(
Frame::Pong { data: [42u8; 8] },
"0d 2a 2a 2a 2a 2a 2a 2a 2a",
),
(
Frame::RecvPacket {
src_key: client_key.public(),
content: "Hello World!".into(),
},
"05 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e
a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d
61 48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
),
(
Frame::SendPacket {
dst_key: client_key.public(),
packet: "Goodbye!".into(),
},
"04 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e
a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d
61 47 6f 6f 64 62 79 65 21",
),
(
Frame::Restarting {
reconnect_in: 10,
try_for: 20,
},
"0f 00 00 00 0a 00 00 00 14",
),
];
for (frame, expected_hex) in frames {
let bytes = frame.encode_for_ws_msg();
let expected_bytes = iroh_test::hexdump::parse_hexdump(expected_hex)?;
assert_eq!(bytes, expected_bytes);
}
Ok(())
}
}
#[cfg(test)]
mod proptests {
    use proptest::prelude::*;

    use super::*;

    /// Strategy producing secret keys from 32 arbitrary bytes.
    fn secret_key() -> impl Strategy<Value = SecretKey> {
        prop::array::uniform32(any::<u8>()).prop_map(SecretKey::from)
    }

    /// Strategy producing the public key of an arbitrary secret key.
    fn key() -> impl Strategy<Value = PublicKey> {
        secret_key().prop_map(|secret| secret.public())
    }

    /// Strategy producing byte strings shorter than `MAX_PACKET_SIZE - consumed`.
    fn data(consumed: usize) -> impl Strategy<Value = Bytes> {
        let upper = MAX_PACKET_SIZE - consumed;
        prop::collection::vec(any::<u8>(), 0..upper).prop_map(Bytes::from)
    }

    /// Strategy producing a frame of any kind with valid contents.
    fn frame() -> impl Strategy<Value = Frame> {
        let client_info = secret_key().prop_map(|secret| {
            let info = ClientInfo {
                version: PROTOCOL_VERSION,
            };
            let payload = postcard::to_stdvec(&info).expect("using default ClientInfo");
            Frame::ClientInfo {
                client_public_key: secret.public(),
                signature: secret.sign(&payload),
                message: payload.into(),
            }
        });
        let send_packet =
            (key(), data(32)).prop_map(|(dst_key, packet)| Frame::SendPacket { dst_key, packet });
        let recv_packet =
            (key(), data(32)).prop_map(|(src_key, content)| Frame::RecvPacket { src_key, content });
        let keep_alive = Just(Frame::KeepAlive);
        let note_preferred = any::<bool>().prop_map(|preferred| Frame::NotePreferred { preferred });
        let peer_gone = key().prop_map(|peer| Frame::PeerGone { peer });
        let ping = prop::array::uniform8(any::<u8>()).prop_map(|data| Frame::Ping { data });
        let pong = prop::array::uniform8(any::<u8>()).prop_map(|data| Frame::Pong { data });
        let health = data(0).prop_map(|problem| Frame::Health { problem });
        let restarting =
            (any::<u32>(), any::<u32>()).prop_map(|(reconnect_in, try_for)| Frame::Restarting {
                reconnect_in,
                try_for,
            });
        prop_oneof![
            client_info,
            send_packet,
            recv_packet,
            keep_alive,
            note_preferred,
            peer_gone,
            ping,
            pong,
            health,
            restarting,
        ]
    }

    /// Corrupts an encoded frame in `buf` so that decoding it must fail.
    ///
    /// Fixed-size frames get one extra body byte (and a matching length field);
    /// variable-size frames are blown up past `MAX_FRAME_SIZE`.
    fn inject_error(buf: &mut BytesMut) {
        /// Whether frames of this type always have the same body length.
        fn is_fixed_size(tpe: FrameType) -> bool {
            match tpe {
                FrameType::ClientInfo
                | FrameType::Health
                | FrameType::SendPacket
                | FrameType::RecvPacket
                | FrameType::Unknown => false,
                FrameType::KeepAlive
                | FrameType::NotePreferred
                | FrameType::Ping
                | FrameType::Pong
                | FrameType::Restarting
                | FrameType::PeerGone => true,
            }
        }

        let frame_type = FrameType::from(buf[0]);
        let declared = u32::from_be_bytes(buf[1..5].try_into().unwrap()) as usize;
        let corrupted = if is_fixed_size(frame_type) {
            buf.put_u8(0);
            declared + 1
        } else {
            buf.resize(MAX_FRAME_SIZE + 1, 0);
            MAX_FRAME_SIZE + 1
        };
        // Rewrite the length field so the header agrees with the corruption.
        buf[1..5].copy_from_slice(&(corrupted as u32).to_be_bytes());
    }

    proptest! {
        // Encoding then decoding through `DerpCodec` yields the original frame.
        #[test]
        fn frame_roundtrip(frame in frame()) {
            let mut wire = BytesMut::new();
            DerpCodec.encode(frame.clone(), &mut wire).unwrap();
            let decoded = DerpCodec.decode(&mut wire).unwrap().unwrap();
            prop_assert_eq!(frame, decoded);
        }

        // Encoding then decoding through the websocket helpers yields the original frame.
        #[test]
        fn frame_ws_roundtrip(frame in frame()) {
            let wire = frame.clone().encode_for_ws_msg();
            let decoded = Frame::decode_from_ws_msg(wire).unwrap();
            prop_assert_eq!(frame, decoded);
        }

        // A corrupted frame is rejected by the decoder.
        #[test]
        fn broken_frame_handling(frame in frame()) {
            let mut wire = BytesMut::new();
            DerpCodec.encode(frame.clone(), &mut wire).unwrap();
            inject_error(&mut wire);
            let outcome = DerpCodec.decode(&mut wire);
            prop_assert!(outcome.is_err());
        }
    }
}