noodles_bgzf/
multithreaded_reader.rs

1use std::{
2    io::{self, BufRead, Read, Seek, SeekFrom},
3    mem,
4    num::NonZeroUsize,
5    thread::{self, JoinHandle},
6};
7
8use crossbeam_channel::{Receiver, Sender};
9
10use crate::{gzi, Block, VirtualPosition};
11
12type BufferedTx = Sender<io::Result<Buffer>>;
13type BufferedRx = Receiver<io::Result<Buffer>>;
14type InflateTx = Sender<(Buffer, BufferedTx)>;
15type InflateRx = Receiver<(Buffer, BufferedTx)>;
16type ReadTx = Sender<BufferedRx>;
17type ReadRx = Receiver<BufferedRx>;
18type RecycleTx = Sender<Buffer>;
19type RecycleRx = Receiver<Buffer>;
20
21enum State<R> {
22    Paused(R),
23    Running {
24        reader_handle: JoinHandle<Result<R, ReadError<R>>>,
25        inflater_handles: Vec<JoinHandle<()>>,
26        read_rx: ReadRx,
27        recycle_tx: RecycleTx,
28    },
29    Done,
30}
31
32#[derive(Debug, Default)]
33struct Buffer {
34    buf: Vec<u8>,
35    block: Block,
36}
37
38/// A multithreaded BGZF reader.
39///
40/// This is a multithreaded BGZF reader that uses a thread pool to decompress block data. It places
41/// the inner reader on its own thread to read raw frames asynchronously.
42pub struct MultithreadedReader<R> {
43    state: State<R>,
44    worker_count: NonZeroUsize,
45    position: u64,
46    buffer: Buffer,
47}
48
49impl<R> MultithreadedReader<R> {
50    /// Returns the current position of the stream.
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// # use std::io;
56    /// use noodles_bgzf as bgzf;
57    /// let reader = bgzf::MultithreadedReader::new(io::empty());
58    /// assert_eq!(reader.position(), 0);
59    /// ```
60    pub fn position(&self) -> u64 {
61        self.position
62    }
63
64    /// Returns the current virtual position of the stream.
65    ///
66    /// # Examples
67    ///
68    /// ```
69    /// # use std::io;
70    /// use noodles_bgzf as bgzf;
71    /// let reader = bgzf::MultithreadedReader::new(io::empty());
72    /// assert_eq!(reader.virtual_position(), bgzf::VirtualPosition::MIN);
73    /// ```
74    pub fn virtual_position(&self) -> VirtualPosition {
75        self.buffer.block.virtual_position()
76    }
77
78    /// Shuts down the reader.
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// # use std::io;
84    /// use noodles_bgzf as bgzf;
85    /// let mut reader = bgzf::MultithreadedReader::new(io::empty());
86    /// reader.finish()?;
87    /// # Ok::<_, io::Error>(())
88    /// ```
89    pub fn finish(&mut self) -> io::Result<R> {
90        let state = mem::replace(&mut self.state, State::Done);
91
92        match state {
93            State::Paused(inner) => Ok(inner),
94            State::Running {
95                reader_handle,
96                mut inflater_handles,
97                recycle_tx,
98                ..
99            } => {
100                drop(recycle_tx);
101
102                for handle in inflater_handles.drain(..) {
103                    handle.join().unwrap();
104                }
105
106                reader_handle.join().unwrap().map_err(|e| e.1)
107            }
108            State::Done => panic!("invalid state"),
109        }
110    }
111}
112
113impl<R> MultithreadedReader<R>
114where
115    R: Read + Send + 'static,
116{
117    /// Creates a multithreaded BGZF reader with a worker count of 1.
118    ///
119    /// # Examples
120    ///
121    /// ```
122    /// # use std::io;
123    /// use noodles_bgzf as bgzf;
124    /// let reader = bgzf::MultithreadedReader::new(io::empty());
125    /// ```
126    pub fn new(inner: R) -> Self {
127        Self::with_worker_count(NonZeroUsize::MIN, inner)
128    }
129
130    /// Creates a multithreaded BGZF reader with a worker count.
131    ///
132    /// # Examples
133    ///
134    /// ```
135    /// # use std::io;
136    /// use std::num::NonZeroUsize;
137    /// use noodles_bgzf as bgzf;
138    /// let reader = bgzf::MultithreadedReader::with_worker_count(NonZeroUsize::MIN, io::empty());
139    /// ```
140    pub fn with_worker_count(worker_count: NonZeroUsize, inner: R) -> Self {
141        Self {
142            state: State::Paused(inner),
143            worker_count,
144            position: 0,
145            buffer: Buffer::default(),
146        }
147    }
148
149    /// Returns a mutable reference to the underlying reader.
150    ///
151    /// # Examples
152    ///
153    /// ```
154    /// # use std::io;
155    /// use noodles_bgzf as bgzf;
156    /// let mut reader = bgzf::MultithreadedReader::new(io::empty());
157    /// let _inner = reader.get_mut();
158    /// ```
159    pub fn get_mut(&mut self) -> &mut R {
160        self.pause();
161
162        match &mut self.state {
163            State::Paused(inner) => inner,
164            _ => panic!("invalid state"),
165        }
166    }
167
168    fn resume(&mut self) {
169        if matches!(self.state, State::Running { .. }) {
170            return;
171        }
172
173        let state = mem::replace(&mut self.state, State::Done);
174
175        let State::Paused(inner) = state else {
176            panic!("invalid state");
177        };
178
179        let worker_count = self.worker_count.get();
180
181        let (inflate_tx, inflate_rx) = crossbeam_channel::bounded(worker_count);
182        let (read_tx, read_rx) = crossbeam_channel::bounded(worker_count);
183        let (recycle_tx, recycle_rx) = crossbeam_channel::bounded(worker_count);
184
185        for _ in 0..worker_count {
186            recycle_tx.send(Buffer::default()).unwrap();
187        }
188
189        let reader_handle = spawn_reader(inner, inflate_tx, read_tx, recycle_rx);
190        let inflater_handles = spawn_inflaters(self.worker_count, inflate_rx);
191
192        self.state = State::Running {
193            reader_handle,
194            inflater_handles,
195            read_rx,
196            recycle_tx,
197        };
198    }
199
200    fn pause(&mut self) {
201        if matches!(self.state, State::Paused(_)) {
202            return;
203        }
204
205        let state = mem::replace(&mut self.state, State::Done);
206
207        let State::Running {
208            reader_handle,
209            mut inflater_handles,
210            recycle_tx,
211            ..
212        } = state
213        else {
214            panic!("invalid state");
215        };
216
217        drop(recycle_tx);
218
219        for handle in inflater_handles.drain(..) {
220            handle.join().unwrap();
221        }
222
223        // Discard read errors.
224        let inner = match reader_handle.join().unwrap() {
225            Ok(inner) => inner,
226            Err(ReadError(inner, _)) => inner,
227        };
228
229        self.state = State::Paused(inner);
230    }
231
232    fn read_block(&mut self) -> io::Result<()> {
233        self.resume();
234
235        let State::Running {
236            read_rx,
237            recycle_tx,
238            ..
239        } = &self.state
240        else {
241            panic!("invalid state");
242        };
243
244        while let Some(mut buffer) = recv_buffer(read_rx)? {
245            buffer.block.set_position(self.position);
246            self.position += buffer.block.size();
247
248            let prev_buffer = mem::replace(&mut self.buffer, buffer);
249            recycle_tx.send(prev_buffer).ok();
250
251            if self.buffer.block.data().len() > 0 {
252                break;
253            }
254        }
255
256        Ok(())
257    }
258}
259
260impl<R> Drop for MultithreadedReader<R> {
261    fn drop(&mut self) {
262        if !matches!(self.state, State::Done) {
263            let _ = self.finish();
264        }
265    }
266}
267
268impl<R> Read for MultithreadedReader<R>
269where
270    R: Read + Send + 'static,
271{
272    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
273        let mut src = self.fill_buf()?;
274        let amt = src.read(buf)?;
275        self.consume(amt);
276        Ok(amt)
277    }
278
279    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
280        use super::reader::default_read_exact;
281
282        if let Some(src) = self.buffer.block.data().as_ref().get(..buf.len()) {
283            buf.copy_from_slice(src);
284            self.consume(src.len());
285            Ok(())
286        } else {
287            default_read_exact(self, buf)
288        }
289    }
290}
291
292impl<R> BufRead for MultithreadedReader<R>
293where
294    R: Read + Send + 'static,
295{
296    fn fill_buf(&mut self) -> io::Result<&[u8]> {
297        if !self.buffer.block.data().has_remaining() {
298            self.read_block()?;
299        }
300
301        Ok(self.buffer.block.data().as_ref())
302    }
303
304    fn consume(&mut self, amt: usize) {
305        self.buffer.block.data_mut().consume(amt);
306    }
307}
308
309impl<R> crate::io::Read for MultithreadedReader<R>
310where
311    R: Read + Send + 'static,
312{
313    fn virtual_position(&self) -> VirtualPosition {
314        self.buffer.block.virtual_position()
315    }
316}
317
318impl<R> crate::io::BufRead for MultithreadedReader<R> where R: Read + Send + 'static {}
319
320impl<R> crate::io::Seek for MultithreadedReader<R>
321where
322    R: Read + Send + Seek + 'static,
323{
324    fn seek_to_virtual_position(&mut self, pos: VirtualPosition) -> io::Result<VirtualPosition> {
325        let (cpos, upos) = pos.into();
326
327        self.get_mut().seek(SeekFrom::Start(cpos))?;
328        self.position = cpos;
329
330        self.read_block()?;
331
332        self.buffer.block.data_mut().set_position(usize::from(upos));
333
334        Ok(pos)
335    }
336
337    fn seek_with_index(&mut self, index: &gzi::Index, pos: SeekFrom) -> io::Result<u64> {
338        let SeekFrom::Start(pos) = pos else {
339            unimplemented!();
340        };
341
342        let virtual_position = index.query(pos)?;
343        self.seek_to_virtual_position(virtual_position)?;
344        Ok(pos)
345    }
346}
347
348fn recv_buffer(read_rx: &ReadRx) -> io::Result<Option<Buffer>> {
349    if let Ok(buffered_rx) = read_rx.recv() {
350        if let Ok(buffer) = buffered_rx.recv() {
351            return buffer.map(Some);
352        }
353    }
354
355    Ok(None)
356}
357
/// An error that stopped the reader thread, paired with the inner reader so it can be recovered.
struct ReadError<R>(R, io::Error);
359
360fn spawn_reader<R>(
361    mut reader: R,
362    inflate_tx: InflateTx,
363    read_tx: ReadTx,
364    recycle_rx: RecycleRx,
365) -> JoinHandle<Result<R, ReadError<R>>>
366where
367    R: Read + Send + 'static,
368{
369    use super::reader::frame::read_frame_into;
370
371    thread::spawn(move || {
372        while let Ok(mut buffer) = recycle_rx.recv() {
373            match read_frame_into(&mut reader, &mut buffer.buf) {
374                Ok(result) if result.is_none() => break,
375                Ok(_) => {}
376                Err(e) => return Err(ReadError(reader, e)),
377            }
378
379            let (buffered_tx, buffered_rx) = crossbeam_channel::bounded(1);
380
381            inflate_tx.send((buffer, buffered_tx)).unwrap();
382            read_tx.send(buffered_rx).unwrap();
383        }
384
385        Ok(reader)
386    })
387}
388
389fn spawn_inflaters(worker_count: NonZeroUsize, inflate_rx: InflateRx) -> Vec<JoinHandle<()>> {
390    use super::reader::frame::parse_block;
391
392    (0..worker_count.get())
393        .map(|_| {
394            let inflate_rx = inflate_rx.clone();
395
396            thread::spawn(move || {
397                while let Ok((mut buffer, buffered_tx)) = inflate_rx.recv() {
398                    let result = parse_block(&buffer.buf, &mut buffer.block).map(|_| buffer);
399                    buffered_tx.send(result).unwrap();
400                }
401            })
402        })
403        .collect()
404}
405
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    #[test]
    fn test_seek_to_virtual_position() -> Result<(), Box<dyn std::error::Error>> {
        use crate::io::Seek;

        #[rustfmt::skip]
        static DATA: &[u8] = &[
            // block 0 (b"noodles")
            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
            0x02, 0x00, 0x22, 0x00, 0xcb, 0xcb, 0xcf, 0x4f, 0xc9, 0x49, 0x2d, 0x06, 0x00, 0xa1,
            0x58, 0x2a, 0x80, 0x07, 0x00, 0x00, 0x00,
            // EOF block
            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
            0x02, 0x00, 0x1b, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];

        const EOF_VIRTUAL_POSITION: VirtualPosition = match VirtualPosition::new(63, 0) {
            Some(pos) => pos,
            None => unreachable!(),
        };

        // Offset 3 into block 0, i.e., between "noo" and "dles".
        const VIRTUAL_POSITION: VirtualPosition = match VirtualPosition::new(0, 3) {
            Some(pos) => pos,
            None => unreachable!(),
        };

        let mut reader =
            MultithreadedReader::with_worker_count(NonZeroUsize::MIN, Cursor::new(DATA));

        // Drain the stream once so the reader sits at the end.
        let mut whole = Vec::new();
        reader.read_to_end(&mut whole)?;
        assert_eq!(reader.virtual_position(), EOF_VIRTUAL_POSITION);

        // Seek back into the first block and read what remains.
        reader.seek_to_virtual_position(VIRTUAL_POSITION)?;

        let mut tail = Vec::new();
        reader.read_to_end(&mut tail)?;

        assert_eq!(tail, b"dles");
        assert_eq!(reader.virtual_position(), EOF_VIRTUAL_POSITION);

        Ok(())
    }
}