use std::{
    cmp::Ordering,
    collections::{binary_heap::PeekMut, BinaryHeap},
    mem,
};

use bytes::{Buf, Bytes, BytesMut};

use crate::range_set::RangeSet;
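
/// Helper to assemble stream frames that may arrive out of order into a
/// contiguous, readable stream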
#[derive(Debug, Default)]
pub(super) struct Assembler {
state: State,
data: BinaryHeap<Buffer>,
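    /// Total number of buffered bytes, including duplicates in ordered mode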
buffered: usize,
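    /// Estimated number of bytes allocated to back the buffered data; never
    /// less than `buffered`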
allocated: usize,
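    /// Number of bytes read by the application. When only ordered reads have
    /// been used, this is the length of the contiguous prefix of the stream
    /// which has been consumed by the application, aka the stream offset.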
bytes_read: u64,
end: u64,
}
impl Assembler {
pub(super) fn new() -> Self {
Self::default()
}
pub(super) fn ensure_ordering(&mut self, ordered: bool) -> Result<(), IllegalOrderedRead> {
if ordered && !self.state.is_ordered() {
return Err(IllegalOrderedRead);
} else if !ordered && self.state.is_ordered() {
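            // Enter unordered mode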
if !self.data.is_empty() {
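                // Get rid of possible future overlapping data now, so that
                // reads never have to resolve overlaps themselves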
self.defragment();
}
let mut recvd = RangeSet::new();
recvd.insert(0..self.bytes_read);
for chunk in &self.data {
recvd.insert(chunk.offset..chunk.offset + chunk.bytes.len() as u64);
}
self.state = State::Unordered { recvd };
}
Ok(())
}
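
    /// Pop the next chunk of data, up to `max_length` bytes
    ///
    /// Ordered reads only return data contiguous with everything read so far;
    /// unordered reads may return any buffered chunk.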
pub(super) fn read(&mut self, max_length: usize, ordered: bool) -> Option<Chunk> {
loop {
let mut chunk = self.data.peek_mut()?;
if ordered {
if chunk.offset > self.bytes_read {
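                    // The next chunk starts past the read offset; the stream
                    // has a gap, so no contiguous data is available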
return None;
} else if (chunk.offset + chunk.bytes.len() as u64) <= self.bytes_read {
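                    // The chunk ends at or before the read offset and is
                    // entirely stale; drop it and try the next one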
self.buffered -= chunk.bytes.len();
self.allocated -= chunk.allocation_size;
PeekMut::pop(chunk);
continue;
}
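                // Discard any prefix of the chunk that has already been read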
let start = (self.bytes_read - chunk.offset) as usize;
if start > 0 {
chunk.bytes.advance(start);
chunk.offset += start as u64;
self.buffered -= start;
}
}
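            // Yield at most `max_length` bytes, splitting the chunk if it
            // holds more than that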
return Some(if max_length < chunk.bytes.len() {
self.bytes_read += max_length as u64;
let offset = chunk.offset;
chunk.offset += max_length as u64;
self.buffered -= max_length;
Chunk::new(offset, chunk.bytes.split_to(max_length))
} else {
self.bytes_read += chunk.bytes.len() as u64;
self.buffered -= chunk.bytes.len();
self.allocated -= chunk.allocation_size;
let chunk = PeekMut::pop(chunk);
Chunk::new(chunk.offset, chunk.bytes)
});
}
}
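
    /// Copy fragmented chunk data into new chunks backed by a single buffer
    ///
    /// This makes sure we're not unnecessarily holding on to many larger
    /// allocations, and merges contiguous chunks in the process.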
fn defragment(&mut self) {
let new = BinaryHeap::with_capacity(self.data.len());
let old = mem::replace(&mut self.data, new);
let mut buffers = old.into_sorted_vec();
self.buffered = 0;
let mut fragmented_buffered = 0;
let mut offset = 0;
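        // First pass, in ascending offset order: trim overlapping prefixes
        // and count how many bytes still need to be copied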
for chunk in buffers.iter_mut().rev() {
chunk.try_mark_defragment(offset);
let size = chunk.bytes.len();
offset = chunk.offset + size as u64;
self.buffered += size;
if !chunk.defragmented {
fragmented_buffered += size;
}
}
self.allocated = self.buffered;
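        // Second pass: copy the still-fragmented chunks into one contiguous
        // allocation, merging adjacent runs as we go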
let mut buffer = BytesMut::with_capacity(fragmented_buffered);
let mut offset = 0;
for chunk in buffers.into_iter().rev() {
if chunk.defragmented {
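                // Already compact; reinsert it without copying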
if !chunk.bytes.is_empty() {
self.data.push(chunk);
}
continue;
}
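            // The chunk is not contiguous with the accumulated run: flush the
            // run, if any, and start a new one at this chunk's offset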
if chunk.offset != offset + (buffer.len() as u64) {
if !buffer.is_empty() {
self.data
.push(Buffer::new_defragmented(offset, buffer.split().freeze()));
}
offset = chunk.offset;
}
buffer.extend_from_slice(&chunk.bytes);
}
if !buffer.is_empty() {
self.data
.push(Buffer::new_defragmented(offset, buffer.split().freeze()));
}
}
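
    /// Buffer a chunk of received stream data starting at `offset`
    ///
    /// `allocation_size` is the size of the allocation backing `bytes`, and is
    /// used to track how much memory buffered chunks are really holding on to.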
pub(super) fn insert(&mut self, mut offset: u64, mut bytes: Bytes, allocation_size: usize) {
debug_assert!(
bytes.len() <= allocation_size,
"allocation_size less than bytes.len(): {:?} < {:?}",
allocation_size,
bytes.len()
);
self.end = self.end.max(offset + bytes.len() as u64);
if let State::Unordered { ref mut recvd } = self.state {
for duplicate in recvd.replace(offset..offset + bytes.len() as u64) {
if duplicate.start > offset {
let buffer = Buffer::new(
offset,
bytes.split_to((duplicate.start - offset) as usize),
allocation_size,
);
self.buffered += buffer.bytes.len();
self.allocated += buffer.allocation_size;
self.data.push(buffer);
offset = duplicate.start;
}
bytes.advance((duplicate.end - offset) as usize);
offset = duplicate.end;
}
} else if offset < self.bytes_read {
if (offset + bytes.len() as u64) <= self.bytes_read {
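                // Entirely duplicate of data that was already read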
return;
} else {
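                // Partially duplicate; skip the prefix that was already read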
let diff = self.bytes_read - offset;
offset += diff;
bytes.advance(diff as usize);
}
}
if bytes.is_empty() {
return;
}
let buffer = Buffer::new(offset, bytes, allocation_size);
self.buffered += buffer.bytes.len();
self.allocated += buffer.allocation_size;
self.data.push(buffer);
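        // Defragment when over-allocation grows too large. We want to
        // defragment rarely, ideally never in non-pathological scenarios, but
        // a malicious peer could send many tiny frames whose reference-counted
        // buffers each pin a large allocation. Bounding over-allocation in
        // proportion to the buffered data keeps that memory in check; the
        // constant is somewhat arbitrary.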
let buffered = self.buffered.min((self.end - self.bytes_read) as usize);
let over_allocation = self.allocated - buffered;
let threshold = 32768.max(buffered * 3 / 2);
if over_allocation > threshold {
self.defragment()
}
}
pub(super) fn bytes_read(&self) -> u64 {
self.bytes_read
}
pub(super) fn clear(&mut self) {
self.data.clear();
self.buffered = 0;
self.allocated = 0;
}
}
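
/// A chunk of data from the receive stream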
#[derive(Debug, PartialEq, Eq)]
pub struct Chunk {
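    /// The offset in the stream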
pub offset: u64,
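    /// The contents of the chunk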
pub bytes: Bytes,
}
impl Chunk {
fn new(offset: u64, bytes: Bytes) -> Self {
Self { offset, bytes }
}
}
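
/// A buffered chunk of stream data, tracking the size of its backing allocation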
#[derive(Debug, Eq)]
struct Buffer {
offset: u64,
bytes: Bytes,
allocation_size: usize,
defragmented: bool,
}
impl Buffer {
fn new(offset: u64, bytes: Bytes, allocation_size: usize) -> Self {
Self {
offset,
bytes,
allocation_size,
defragmented: false,
}
}
fn new_defragmented(offset: u64, bytes: Bytes) -> Self {
let allocation_size = bytes.len();
Self {
offset,
bytes,
allocation_size,
defragmented: true,
}
}
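
    /// Discard data before `offset` and mark the buffer as defragmented if it
    /// has good utilization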
fn try_mark_defragment(&mut self, offset: u64) {
let duplicate = offset.saturating_sub(self.offset) as usize;
self.offset = self.offset.max(offset);
if duplicate >= self.bytes.len() {
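            // Everything is a duplicate: free the data and stop charging the
            // allocation against the assembler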
self.bytes = Bytes::new();
self.defragmented = true;
self.allocation_size = 0;
return;
}
self.bytes.advance(duplicate);
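        // Buffers using at least ~5/6 of their allocation already count as
        // defragmented, so highly utilized chunks are never copied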
self.defragmented = self.defragmented || self.bytes.len() * 6 / 5 >= self.allocation_size;
if self.defragmented {
self.allocation_size = self.bytes.len();
}
}
}
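
// Ordered by offset, reversed, so the max-heap in `Assembler::data` pops the
// chunk with the lowest offset first; ties go to the longer chunk.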
impl Ord for Buffer {
fn cmp(&self, other: &Self) -> Ordering {
self.offset
.cmp(&other.offset)
.reverse()
.then(self.bytes.len().cmp(&other.bytes.len()))
}
}
impl PartialOrd for Buffer {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl PartialEq for Buffer {
fn eq(&self, other: &Self) -> bool {
(self.offset, self.bytes.len()) == (other.offset, other.bytes.len())
}
}
#[derive(Debug)]
enum State {
Ordered,
Unordered {
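        /// The set of offsets that have been received from the peer, including
        /// portions not yet read by the application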
recvd: RangeSet,
},
}
impl State {
fn is_ordered(&self) -> bool {
matches!(self, Self::Ordered)
}
}
impl Default for State {
fn default() -> Self {
Self::Ordered
}
}
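
/// Error indicating that an ordered read was performed on a stream after an
/// unordered read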
#[derive(Debug)]
pub struct IllegalOrderedRead;
#[cfg(test)]
mod test {
use super::*;
use assert_matches::assert_matches;
#[test]
fn assemble_ordered() {
let mut x = Assembler::new();
assert_matches!(next(&mut x, 32), None);
x.insert(0, Bytes::from_static(b"123"), 3);
assert_matches!(next(&mut x, 1), Some(ref y) if &y[..] == b"1");
assert_matches!(next(&mut x, 3), Some(ref y) if &y[..] == b"23");
x.insert(3, Bytes::from_static(b"456"), 3);
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456");
x.insert(6, Bytes::from_static(b"789"), 3);
x.insert(9, Bytes::from_static(b"10"), 2);
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"789");
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"10");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_unordered() {
let mut x = Assembler::new();
x.ensure_ordering(false).unwrap();
x.insert(3, Bytes::from_static(b"456"), 3);
assert_matches!(next(&mut x, 32), None);
x.insert(0, Bytes::from_static(b"123"), 3);
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_duplicate() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"123"), 3);
x.insert(0, Bytes::from_static(b"123"), 3);
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_duplicate_compact() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"123"), 3);
x.insert(0, Bytes::from_static(b"123"), 3);
x.defragment();
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_contained() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"12345"), 5);
x.insert(1, Bytes::from_static(b"234"), 3);
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_contained_compact() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"12345"), 5);
x.insert(1, Bytes::from_static(b"234"), 3);
x.defragment();
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_contains() {
let mut x = Assembler::new();
x.insert(1, Bytes::from_static(b"234"), 3);
x.insert(0, Bytes::from_static(b"12345"), 5);
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_contains_compact() {
let mut x = Assembler::new();
x.insert(1, Bytes::from_static(b"234"), 3);
x.insert(0, Bytes::from_static(b"12345"), 5);
x.defragment();
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_overlapping() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"123"), 3);
x.insert(1, Bytes::from_static(b"234"), 3);
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"4");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_overlapping_compact() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"123"), 4);
x.insert(1, Bytes::from_static(b"234"), 4);
x.defragment();
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_complex() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"1"), 1);
x.insert(2, Bytes::from_static(b"3"), 1);
x.insert(4, Bytes::from_static(b"5"), 1);
x.insert(0, Bytes::from_static(b"123456"), 6);
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_complex_compact() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"1"), 1);
x.insert(2, Bytes::from_static(b"3"), 1);
x.insert(4, Bytes::from_static(b"5"), 1);
x.insert(0, Bytes::from_static(b"123456"), 6);
x.defragment();
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456");
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn assemble_old() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"1234"), 4);
assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234");
x.insert(0, Bytes::from_static(b"1234"), 4);
assert_matches!(next(&mut x, 32), None);
}
#[test]
fn compact() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"abc"), 4);
x.insert(3, Bytes::from_static(b"def"), 4);
x.insert(9, Bytes::from_static(b"jkl"), 4);
x.insert(12, Bytes::from_static(b"mno"), 4);
x.defragment();
assert_eq!(
next_unordered(&mut x),
Chunk::new(0, Bytes::from_static(b"abcdef"))
);
assert_eq!(
next_unordered(&mut x),
Chunk::new(9, Bytes::from_static(b"jklmno"))
);
}
#[test]
fn defrag_with_missing_prefix() {
let mut x = Assembler::new();
x.insert(3, Bytes::from_static(b"def"), 3);
x.defragment();
assert_eq!(
next_unordered(&mut x),
Chunk::new(3, Bytes::from_static(b"def"))
);
}
#[test]
fn defrag_read_chunk() {
let mut x = Assembler::new();
x.insert(3, Bytes::from_static(b"def"), 4);
x.insert(0, Bytes::from_static(b"abc"), 4);
x.insert(7, Bytes::from_static(b"hij"), 4);
x.insert(11, Bytes::from_static(b"lmn"), 4);
x.defragment();
assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"abcdef");
x.insert(5, Bytes::from_static(b"fghijklmn"), 9);
assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"ghijklmn");
x.insert(13, Bytes::from_static(b"nopq"), 4);
assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"opq");
x.insert(15, Bytes::from_static(b"pqrs"), 4);
assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"rs");
assert_matches!(x.read(usize::MAX, true), None);
}
#[test]
fn unordered_happy_path() {
let mut x = Assembler::new();
x.ensure_ordering(false).unwrap();
x.insert(0, Bytes::from_static(b"abc"), 3);
assert_eq!(
next_unordered(&mut x),
Chunk::new(0, Bytes::from_static(b"abc"))
);
assert_eq!(x.read(usize::MAX, false), None);
x.insert(3, Bytes::from_static(b"def"), 3);
assert_eq!(
next_unordered(&mut x),
Chunk::new(3, Bytes::from_static(b"def"))
);
assert_eq!(x.read(usize::MAX, false), None);
}
#[test]
fn unordered_dedup() {
let mut x = Assembler::new();
x.ensure_ordering(false).unwrap();
x.insert(3, Bytes::from_static(b"def"), 3);
assert_eq!(
next_unordered(&mut x),
Chunk::new(3, Bytes::from_static(b"def"))
);
assert_eq!(x.read(usize::MAX, false), None);
x.insert(0, Bytes::from_static(b"a"), 1);
x.insert(0, Bytes::from_static(b"abcdefghi"), 9);
x.insert(0, Bytes::from_static(b"abcd"), 4);
assert_eq!(
next_unordered(&mut x),
Chunk::new(0, Bytes::from_static(b"a"))
);
assert_eq!(
next_unordered(&mut x),
Chunk::new(1, Bytes::from_static(b"bc"))
);
assert_eq!(
next_unordered(&mut x),
Chunk::new(6, Bytes::from_static(b"ghi"))
);
assert_eq!(x.read(usize::MAX, false), None);
x.insert(8, Bytes::from_static(b"ijkl"), 4);
assert_eq!(
next_unordered(&mut x),
Chunk::new(9, Bytes::from_static(b"jkl"))
);
assert_eq!(x.read(usize::MAX, false), None);
x.insert(12, Bytes::from_static(b"mno"), 3);
assert_eq!(
next_unordered(&mut x),
Chunk::new(12, Bytes::from_static(b"mno"))
);
assert_eq!(x.read(usize::MAX, false), None);
x.insert(2, Bytes::from_static(b"cde"), 3);
assert_eq!(x.read(usize::MAX, false), None);
}
#[test]
fn chunks_dedup() {
let mut x = Assembler::new();
x.insert(3, Bytes::from_static(b"def"), 3);
assert_eq!(x.read(usize::MAX, true), None);
x.insert(0, Bytes::from_static(b"a"), 1);
x.insert(1, Bytes::from_static(b"bcdefghi"), 9);
x.insert(0, Bytes::from_static(b"abcd"), 4);
assert_eq!(
x.read(usize::MAX, true),
Some(Chunk::new(0, Bytes::from_static(b"abcd")))
);
assert_eq!(
x.read(usize::MAX, true),
Some(Chunk::new(4, Bytes::from_static(b"efghi")))
);
assert_eq!(x.read(usize::MAX, true), None);
x.insert(8, Bytes::from_static(b"ijkl"), 4);
assert_eq!(
x.read(usize::MAX, true),
Some(Chunk::new(9, Bytes::from_static(b"jkl")))
);
assert_eq!(x.read(usize::MAX, true), None);
x.insert(12, Bytes::from_static(b"mno"), 3);
assert_eq!(
x.read(usize::MAX, true),
Some(Chunk::new(12, Bytes::from_static(b"mno")))
);
assert_eq!(x.read(usize::MAX, true), None);
x.insert(2, Bytes::from_static(b"cde"), 3);
assert_eq!(x.read(usize::MAX, true), None);
}
#[test]
fn ordered_eager_discard() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"abc"), 3);
assert_eq!(x.data.len(), 1);
assert_eq!(
x.read(usize::MAX, true),
Some(Chunk::new(0, Bytes::from_static(b"abc")))
);
x.insert(0, Bytes::from_static(b"ab"), 2);
assert_eq!(x.data.len(), 0);
x.insert(2, Bytes::from_static(b"cd"), 2);
assert_eq!(
x.data.peek(),
Some(&Buffer::new(3, Bytes::from_static(b"d"), 2))
);
}
#[test]
fn ordered_insert_unordered_read() {
let mut x = Assembler::new();
x.insert(0, Bytes::from_static(b"abc"), 3);
x.insert(0, Bytes::from_static(b"abc"), 3);
x.ensure_ordering(false).unwrap();
assert_eq!(
x.read(3, false),
Some(Chunk::new(0, Bytes::from_static(b"abc")))
);
assert_eq!(x.read(3, false), None);
}
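    // The two tests below are illustrative sketches added alongside the
    // original suite; they exercise documented invariants rather than
    // upstream cases.
    //
    // Once a stream has been read unordered, switching back to ordered reads
    // must be rejected with `IllegalOrderedRead`.
    #[test]
    fn ordered_after_unordered_rejected() {
        let mut x = Assembler::new();
        x.ensure_ordering(false).unwrap();
        assert!(x.ensure_ordering(true).is_err());
    }
    // `bytes_read()` reports how many bytes reads have consumed, including a
    // partial read bounded by `max_length`.
    #[test]
    fn bytes_read_tracks_consumption() {
        let mut x = Assembler::new();
        x.insert(0, Bytes::from_static(b"abcdef"), 6);
        assert_matches!(next(&mut x, 4), Some(ref y) if &y[..] == b"abcd");
        assert_eq!(x.bytes_read(), 4);
    }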
fn next_unordered(x: &mut Assembler) -> Chunk {
x.read(usize::MAX, false).unwrap()
}
fn next(x: &mut Assembler, size: usize) -> Option<Bytes> {
x.read(size, true).map(|chunk| chunk.bytes)
}
}