kinesin_rdt/stream/
inbound.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
//! Stream inbound implementation

use std::collections::BTreeMap;
use std::ops::Range;

use tracing::trace;

use crate::common::range_set::RangeSet;
use crate::common::ring_buffer::{RingBuf, RingBufSlice};

/// Receive-side state of a single stream.
///
/// Holds the reassembly buffer together with the bookkeeping used by the
/// methods of this type: which stream offsets have been received, where
/// messages begin, the flow control limit, and the final length once known.
pub struct StreamInboundState {
    /// buffer for received data; index 0 corresponds to stream offset
    /// `buffer_offset`
    pub buffer: RingBuf<u8>,
    /// stream offset at which buffer starts; only moves forward, via
    /// `advance_buffer`
    pub buffer_offset: u64,

    /// received segments, as ranges of stream offsets
    pub received: RangeSet,
    /// offsets into the stream where messages begin, if applicable
    ///
    /// NOTE(review): markers are inserted with a `None` value here; the
    /// meaning of the `u32` payload is not visible in this file -- confirm
    /// against the code that fills it in.
    pub message_offsets: BTreeMap<u64, Option<u32>>,
    /// whether stream is operating in reliable mode
    pub is_reliable: bool,
    /// flow control limit: segments ending past this stream offset are
    /// rejected by `receive_segment`
    pub window_limit: u64,
    /// final length of stream (offset of final byte + 1), `None` until set
    /// by `set_final_offset`
    pub final_offset: Option<u64>,
}

/// Outcome of `StreamInboundState::receive_segment`.
///
/// The enum carries no data, so it is cheap to copy and supports full
/// equality comparison.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ReceiveSegmentResult {
    /// some or all of the segment is new and has been processed
    Received,
    /// all of the segment has already been received
    Duplicate,
    /// segment extends past the window limit; it was rejected and nothing
    /// from it was stored
    ExceedsWindow,
}

// Invariants:
// - `window_limit - buffer_offset <= isize::MAX` to ensure `buffer` remains
//   within capacity limits
// - `received` must contain the range 0..buffer_offset
// - `received` must not contain segments past `buffer_offset + buffer.len()`

impl StreamInboundState {
    /// Create a new inbound state with an empty buffer starting at stream
    /// offset 0 and no final offset.
    ///
    /// # Panics
    ///
    /// Panics if `initial_window_limit` is greater than `isize::MAX`.
    pub fn new(initial_window_limit: u64, is_reliable: bool) -> StreamInboundState {
        assert!(
            initial_window_limit <= isize::MAX as u64,
            "initial window limit out of range"
        );
        StreamInboundState {
            buffer: RingBuf::new(),
            buffer_offset: 0,
            received: RangeSet::unlimited(),
            message_offsets: BTreeMap::new(),
            is_reliable,
            window_limit: initial_window_limit,
            final_offset: None,
        }
    }

    /// Process an incoming segment carrying `data` at stream offset `offset`.
    ///
    /// Returns:
    /// - `ExceedsWindow` if the segment ends past `window_limit` (or its end
    ///   does not fit in a `u64`); nothing is stored in that case.
    /// - `Duplicate` if every byte of the segment was already received, or if
    ///   the segment is empty.
    /// - `Received` otherwise, after copying the not-yet-received parts into
    ///   the buffer and recording the segment as received.
    ///
    /// # Panics
    ///
    /// Panics if the `received` set and the buffer bounds are inconsistent
    /// with each other (see the invariants listed above this `impl`).
    #[must_use = "must check if segment exceeds window limit"]
    pub fn receive_segment(&mut self, offset: u64, data: &[u8]) -> ReceiveSegmentResult {
        // Checked addition: an end offset that overflows u64 is necessarily
        // past any window limit and must not wrap around into a small value
        // that passes the comparison.
        let tail = match offset.checked_add(data.len() as u64) {
            Some(tail) if tail <= self.window_limit => tail,
            _ => return ReceiveSegmentResult::ExceedsWindow,
        };

        // An empty segment carries no bytes, so there is nothing new to
        // store; avoid growing the buffer or recording an empty range.
        if data.is_empty() {
            return ReceiveSegmentResult::Duplicate;
        }

        let segment = offset..tail;
        if self.received.has_range(segment.clone()) {
            return ReceiveSegmentResult::Duplicate;
        }

        // ensure buffer is long enough
        let buffer_end: usize = (segment.end - self.buffer_offset)
            .try_into()
            .expect("window limit invalid");
        if buffer_end > self.buffer.len() {
            self.buffer.fill_at_back(buffer_end - self.buffer.len(), 0);
        }

        // copy only the sub-ranges that have not been received yet, so data
        // already in the buffer is never overwritten
        for to_copy in self.received.range_complement(segment.clone()) {
            let len: usize = (to_copy.end - to_copy.start).try_into().unwrap();
            let buffer_index: usize = to_copy
                .start
                .checked_sub(self.buffer_offset)
                .expect("received set inconsistent with buffer")
                .try_into()
                .unwrap();

            // position of this sub-range within `data`
            let slice_start = (to_copy.start - offset) as usize;
            let data_slice = &data[slice_start..slice_start + len];
            trace!("copy {} bytes to offset {}", len, to_copy.start);
            self.buffer
                .range_mut(buffer_index..buffer_index + len)
                .copy_from_slice(data_slice);
        }

        self.received.insert_range(segment);

        ReceiveSegmentResult::Received
    }

    /// Advance the window limit to `new_limit`.
    ///
    /// # Panics
    ///
    /// Panics if `new_limit` is lower than the current limit, or if the
    /// distance from `buffer_offset` to `new_limit` exceeds `isize::MAX`.
    pub fn set_limit(&mut self, new_limit: u64) {
        assert!(new_limit >= self.window_limit, "limit cannot go backwards");

        // ensure buffer size is within limits
        if new_limit - self.buffer_offset > isize::MAX as u64 {
            panic!("new window limit exceeds maximum buffer capacity");
        }

        trace!(
            "advance window limit by {} bytes (window_limit = {})",
            new_limit - self.window_limit,
            new_limit
        );

        self.window_limit = new_limit;
    }

    /// Record that a message begins at stream offset `offset`.
    ///
    /// Offsets below `buffer_offset` are ignored, since that part of the
    /// stream has already been discarded.
    pub fn set_message_marker(&mut self, offset: u64) {
        if offset < self.buffer_offset {
            return;
        }

        trace!("message at offset {}", offset);
        self.message_offsets.insert(offset, None);
    }

    /// Set the final offset announced by the sender.
    ///
    /// Returns `true` if the offset was recorded. Returns `false`, leaving
    /// the stored value untouched, if a final offset had already been set.
    pub fn set_final_offset(&mut self, offset: u64) -> bool {
        if self.final_offset.is_some() {
            return false;
        }
        self.final_offset = Some(offset);
        true
    }

    /// Advance the buffer, discarding data lower than the new base offset.
    ///
    /// Message markers below `new_base` are dropped and everything before
    /// `new_base` is marked as received. Calling this with the current
    /// `buffer_offset` is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if `new_base` is lower than the current `buffer_offset`.
    pub fn advance_buffer(&mut self, new_base: u64) {
        if new_base < self.buffer_offset {
            panic!("cannot advance buffer backwards");
        }

        let delta = new_base - self.buffer_offset;
        if delta == 0 {
            return;
        }

        // shift buffer forward
        if (self.buffer.len() as u64) < delta {
            // new base lies past everything buffered
            self.buffer.clear();
        } else {
            // cast safety: checked by branch
            self.buffer.drain(..(delta as usize));
        }
        self.buffer_offset += delta;

        trace!(delta, "advance buffer");

        // discard old message offsets
        if !self.message_offsets.is_empty() {
            self.message_offsets = self.message_offsets.split_off(&new_base);
        }

        // mark everything prior as received
        self.received.insert_range(0..new_base);
    }

    /// Read a segment from the buffer, if available.
    ///
    /// Returns `None` if the segment has not been completely received, if it
    /// starts before `buffer_offset` (already discarded), or if it extends
    /// past the end of the buffer.
    ///
    /// # Panics
    ///
    /// Panics if `segment.end < segment.start`, or if the segment length does
    /// not fit in `usize`.
    pub fn read_segment(&self, segment: Range<u64>) -> Option<RingBufSlice<'_, u8>> {
        let len: usize = segment
            .end
            .checked_sub(segment.start)
            .expect("range cannot be reverse")
            .try_into()
            .expect("range out of bounds");

        if !self.received.has_range(segment.clone()) {
            // requested segment not complete
            return None;
        }
        if segment.start < self.buffer_offset {
            // requested segment no longer present
            return None;
        }

        // checked by len calculation
        let start = (segment.start - self.buffer_offset) as usize;
        if start + len > self.buffer.len() {
            return None;
        }
        Some(self.buffer.range(start..start + len))
    }

    /// Return the highest offset into the stream for which no gaps exist
    /// between it and `buffer_offset`.
    ///
    /// Returns `None` when nothing contiguous with `buffer_offset` has been
    /// received: either the received set is empty, or its first range starts
    /// after `buffer_offset`, leaving a gap in front of it.
    pub fn max_contiguous_offset(&self) -> Option<u64> {
        self.received
            .peek_first()
            // a first range that begins past `buffer_offset` sits behind a gap
            .filter(|first| first.start <= self.buffer_offset)
            .map(|first| first.end)
    }

    /// Read available bytes from the start of the buffer.
    ///
    /// Returns at most `limit` bytes of the contiguous data beginning at
    /// `buffer_offset`, or `None` if no such data is available.
    ///
    /// Only really makes sense when `is_reliable = true`.
    pub fn read_next(&self, limit: usize) -> Option<RingBufSlice<'_, u8>> {
        let available = self.max_contiguous_offset()?;
        debug_assert!(available >= self.buffer_offset);

        // `available` is a stream offset; the readable byte count is its
        // distance from the start of the buffer.
        let readable = available.saturating_sub(self.buffer_offset);
        if readable == 0 {
            return None;
        }

        // cast safety: result is at most `limit`, which is a usize
        let len = u64::min(readable, limit as u64) as usize;
        Some(self.buffer.range(0..len))
    }

    /// Check if the stream is fully received.
    ///
    /// Always `false` until a final offset has been set. If unreliable, will
    /// return true as soon as a final offset is received, even if more
    /// segments are in transit. If reliable, returns true once the contiguous
    /// received data reaches the final offset.
    pub fn finished(&self) -> bool {
        match self.final_offset {
            None => false,
            Some(_) if !self.is_reliable => true,
            Some(final_offset) => {
                // With nothing contiguous received, the stream is complete up
                // to `buffer_offset`; this lets a zero-length stream finish.
                let contiguous = self.max_contiguous_offset().unwrap_or(self.buffer_offset);
                contiguous >= final_offset
            }
        }
    }
}

#[cfg(test)]
pub mod test {
    use super::{ReceiveSegmentResult, StreamInboundState};

    /// Out-of-order delivery, window enforcement, duplicate detection and
    /// in-order readout on a reliable stream.
    #[test]
    fn receive() {
        let head = "Hello, ";
        let rest = "world!";
        let total_len = (head.len() + rest.len()) as u64;
        let mut state = StreamInboundState::new(4096, true);

        // second half arrives first
        let outcome = state.receive_segment(head.len() as u64, rest.as_bytes());
        assert_eq!(outcome, ReceiveSegmentResult::Received);

        // first half fills the gap
        let outcome = state.receive_segment(0, head.as_bytes());
        assert_eq!(outcome, ReceiveSegmentResult::Received);

        // segment past the window limit is rejected
        let outcome = state.receive_segment(8192, &[3, 4, 5, 6]);
        assert_eq!(outcome, ReceiveSegmentResult::ExceedsWindow);

        // byte inside an already received range
        let outcome = state.receive_segment(3, &[3]);
        assert_eq!(outcome, ReceiveSegmentResult::Duplicate);

        assert!(state.set_final_offset(total_len));

        // read everything back in order
        let available = state.read_next(64).unwrap();
        let mut bytes = vec![0; available.len()];
        available.copy_to_slice(&mut bytes);
        let text = String::from_utf8(bytes).unwrap();
        assert_eq!(text, [head, rest].concat());

        assert!(state.finished());
    }
}