1use std::{
2 io::{self, BufRead, Read, Seek, SeekFrom},
3 mem,
4 num::NonZeroUsize,
5 thread::{self, JoinHandle},
6};
7
8use crossbeam_channel::{Receiver, Sender};
9
10use crate::{gzi, Block, VirtualPosition};
11
// Channel endpoints connecting the consumer, the reader thread, and the inflater threads.

// Per-frame reply channel: an inflater sends the outcome of parsing one buffer over it.
type BufferedTx = Sender<io::Result<Buffer>>;
type BufferedRx = Receiver<io::Result<Buffer>>;
// Reader thread -> inflaters: a filled buffer together with the sender to reply on.
type InflateTx = Sender<(Buffer, BufferedTx)>;
type InflateRx = Receiver<(Buffer, BufferedTx)>;
// Reader thread -> consumer: the receiving half of each per-frame reply channel, sent in the
// order the frames were read.
type ReadTx = Sender<BufferedRx>;
type ReadRx = Receiver<BufferedRx>;
// Consumer -> reader thread: buffers that are free to be refilled.
type RecycleTx = Sender<Buffer>;
type RecycleRx = Receiver<Buffer>;
20
/// The lifecycle of the inner reader and its worker threads.
enum State<R> {
    /// No worker threads are running; the inner reader is held directly.
    Paused(R),
    /// Worker threads are running and the reader thread owns the inner reader.
    Running {
        /// Handle of the reader thread. Joining it hands the inner reader back, either directly
        /// or wrapped in a `ReadError`.
        reader_handle: JoinHandle<Result<R, ReadError<R>>>,
        /// Handles of the inflater threads.
        inflater_handles: Vec<JoinHandle<()>>,
        /// Receives, once per frame, the channel on which that frame's parsed buffer arrives.
        read_rx: ReadRx,
        /// Returns consumed buffers to the reader thread; dropping it lets the reader thread
        /// stop once no buffers are left to receive.
        recycle_tx: RecycleTx,
    },
    /// The inner reader has been moved out: permanently after `finish`, or momentarily while a
    /// state transition is in progress.
    Done,
}
31
/// A reusable unit of work that travels between the reader thread, the inflaters, and the
/// consumer.
#[derive(Debug, Default)]
struct Buffer {
    /// Raw frame bytes, filled by the reader thread via `read_frame_into`.
    buf: Vec<u8>,
    /// The block parsed from `buf` by an inflater thread via `parse_block`.
    block: Block,
}
37
/// A reader that reads frames on a dedicated thread and parses them into blocks on a pool of
/// inflater threads.
///
/// The worker threads are started on the first read and are stopped whenever the inner reader
/// has to be accessed directly (`get_mut`, seeking) or handed back (`finish`).
pub struct MultithreadedReader<R> {
    /// Whether the worker threads are running and who currently owns the inner reader.
    state: State<R>,
    /// Number of inflater threads to spawn; also the capacity of the internal channels.
    worker_count: NonZeroUsize,
    /// Position assigned to the next block received: advanced by each received block's size and
    /// overwritten when seeking.
    position: u64,
    /// The buffer whose block is currently being served to the caller.
    buffer: Buffer,
}
48
49impl<R> MultithreadedReader<R> {
50 pub fn position(&self) -> u64 {
61 self.position
62 }
63
64 pub fn virtual_position(&self) -> VirtualPosition {
75 self.buffer.block.virtual_position()
76 }
77
78 pub fn finish(&mut self) -> io::Result<R> {
90 let state = mem::replace(&mut self.state, State::Done);
91
92 match state {
93 State::Paused(inner) => Ok(inner),
94 State::Running {
95 reader_handle,
96 mut inflater_handles,
97 recycle_tx,
98 ..
99 } => {
100 drop(recycle_tx);
101
102 for handle in inflater_handles.drain(..) {
103 handle.join().unwrap();
104 }
105
106 reader_handle.join().unwrap().map_err(|e| e.1)
107 }
108 State::Done => panic!("invalid state"),
109 }
110 }
111}
112
113impl<R> MultithreadedReader<R>
114where
115 R: Read + Send + 'static,
116{
117 pub fn new(inner: R) -> Self {
127 Self::with_worker_count(NonZeroUsize::MIN, inner)
128 }
129
130 pub fn with_worker_count(worker_count: NonZeroUsize, inner: R) -> Self {
141 Self {
142 state: State::Paused(inner),
143 worker_count,
144 position: 0,
145 buffer: Buffer::default(),
146 }
147 }
148
149 pub fn get_mut(&mut self) -> &mut R {
160 self.pause();
161
162 match &mut self.state {
163 State::Paused(inner) => inner,
164 _ => panic!("invalid state"),
165 }
166 }
167
168 fn resume(&mut self) {
169 if matches!(self.state, State::Running { .. }) {
170 return;
171 }
172
173 let state = mem::replace(&mut self.state, State::Done);
174
175 let State::Paused(inner) = state else {
176 panic!("invalid state");
177 };
178
179 let worker_count = self.worker_count.get();
180
181 let (inflate_tx, inflate_rx) = crossbeam_channel::bounded(worker_count);
182 let (read_tx, read_rx) = crossbeam_channel::bounded(worker_count);
183 let (recycle_tx, recycle_rx) = crossbeam_channel::bounded(worker_count);
184
185 for _ in 0..worker_count {
186 recycle_tx.send(Buffer::default()).unwrap();
187 }
188
189 let reader_handle = spawn_reader(inner, inflate_tx, read_tx, recycle_rx);
190 let inflater_handles = spawn_inflaters(self.worker_count, inflate_rx);
191
192 self.state = State::Running {
193 reader_handle,
194 inflater_handles,
195 read_rx,
196 recycle_tx,
197 };
198 }
199
200 fn pause(&mut self) {
201 if matches!(self.state, State::Paused(_)) {
202 return;
203 }
204
205 let state = mem::replace(&mut self.state, State::Done);
206
207 let State::Running {
208 reader_handle,
209 mut inflater_handles,
210 recycle_tx,
211 ..
212 } = state
213 else {
214 panic!("invalid state");
215 };
216
217 drop(recycle_tx);
218
219 for handle in inflater_handles.drain(..) {
220 handle.join().unwrap();
221 }
222
223 let inner = match reader_handle.join().unwrap() {
225 Ok(inner) => inner,
226 Err(ReadError(inner, _)) => inner,
227 };
228
229 self.state = State::Paused(inner);
230 }
231
232 fn read_block(&mut self) -> io::Result<()> {
233 self.resume();
234
235 let State::Running {
236 read_rx,
237 recycle_tx,
238 ..
239 } = &self.state
240 else {
241 panic!("invalid state");
242 };
243
244 while let Some(mut buffer) = recv_buffer(read_rx)? {
245 buffer.block.set_position(self.position);
246 self.position += buffer.block.size();
247
248 let prev_buffer = mem::replace(&mut self.buffer, buffer);
249 recycle_tx.send(prev_buffer).ok();
250
251 if self.buffer.block.data().len() > 0 {
252 break;
253 }
254 }
255
256 Ok(())
257 }
258}
259
260impl<R> Drop for MultithreadedReader<R> {
261 fn drop(&mut self) {
262 if !matches!(self.state, State::Done) {
263 let _ = self.finish();
264 }
265 }
266}
267
268impl<R> Read for MultithreadedReader<R>
269where
270 R: Read + Send + 'static,
271{
272 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
273 let mut src = self.fill_buf()?;
274 let amt = src.read(buf)?;
275 self.consume(amt);
276 Ok(amt)
277 }
278
279 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
280 use super::reader::default_read_exact;
281
282 if let Some(src) = self.buffer.block.data().as_ref().get(..buf.len()) {
283 buf.copy_from_slice(src);
284 self.consume(src.len());
285 Ok(())
286 } else {
287 default_read_exact(self, buf)
288 }
289 }
290}
291
292impl<R> BufRead for MultithreadedReader<R>
293where
294 R: Read + Send + 'static,
295{
296 fn fill_buf(&mut self) -> io::Result<&[u8]> {
297 if !self.buffer.block.data().has_remaining() {
298 self.read_block()?;
299 }
300
301 Ok(self.buffer.block.data().as_ref())
302 }
303
304 fn consume(&mut self, amt: usize) {
305 self.buffer.block.data_mut().consume(amt);
306 }
307}
308
309impl<R> crate::io::Read for MultithreadedReader<R>
310where
311 R: Read + Send + 'static,
312{
313 fn virtual_position(&self) -> VirtualPosition {
314 self.buffer.block.virtual_position()
315 }
316}
317
// `crate::io::BufRead` is implemented with an empty body: no item is defined or overridden here.
impl<R> crate::io::BufRead for MultithreadedReader<R> where R: Read + Send + 'static {}
319
320impl<R> crate::io::Seek for MultithreadedReader<R>
321where
322 R: Read + Send + Seek + 'static,
323{
324 fn seek_to_virtual_position(&mut self, pos: VirtualPosition) -> io::Result<VirtualPosition> {
325 let (cpos, upos) = pos.into();
326
327 self.get_mut().seek(SeekFrom::Start(cpos))?;
328 self.position = cpos;
329
330 self.read_block()?;
331
332 self.buffer.block.data_mut().set_position(usize::from(upos));
333
334 Ok(pos)
335 }
336
337 fn seek_with_index(&mut self, index: &gzi::Index, pos: SeekFrom) -> io::Result<u64> {
338 let SeekFrom::Start(pos) = pos else {
339 unimplemented!();
340 };
341
342 let virtual_position = index.query(pos)?;
343 self.seek_to_virtual_position(virtual_position)?;
344 Ok(pos)
345 }
346}
347
348fn recv_buffer(read_rx: &ReadRx) -> io::Result<Option<Buffer>> {
349 if let Ok(buffered_rx) = read_rx.recv() {
350 if let Ok(buffer) = buffered_rx.recv() {
351 return buffer.map(Some);
352 }
353 }
354
355 Ok(None)
356}
357
/// An I/O error raised on the reader thread, paired with the inner reader so that the reader can
/// still be recovered by whoever joins the thread.
struct ReadError<R>(R, io::Error);
359
360fn spawn_reader<R>(
361 mut reader: R,
362 inflate_tx: InflateTx,
363 read_tx: ReadTx,
364 recycle_rx: RecycleRx,
365) -> JoinHandle<Result<R, ReadError<R>>>
366where
367 R: Read + Send + 'static,
368{
369 use super::reader::frame::read_frame_into;
370
371 thread::spawn(move || {
372 while let Ok(mut buffer) = recycle_rx.recv() {
373 match read_frame_into(&mut reader, &mut buffer.buf) {
374 Ok(result) if result.is_none() => break,
375 Ok(_) => {}
376 Err(e) => return Err(ReadError(reader, e)),
377 }
378
379 let (buffered_tx, buffered_rx) = crossbeam_channel::bounded(1);
380
381 inflate_tx.send((buffer, buffered_tx)).unwrap();
382 read_tx.send(buffered_rx).unwrap();
383 }
384
385 Ok(reader)
386 })
387}
388
389fn spawn_inflaters(worker_count: NonZeroUsize, inflate_rx: InflateRx) -> Vec<JoinHandle<()>> {
390 use super::reader::frame::parse_block;
391
392 (0..worker_count.get())
393 .map(|_| {
394 let inflate_rx = inflate_rx.clone();
395
396 thread::spawn(move || {
397 while let Ok((mut buffer, buffered_tx)) = inflate_rx.recv() {
398 let result = parse_block(&buffer.buf, &mut buffer.block).map(|_| buffer);
399 buffered_tx.send(result).unwrap();
400 }
401 })
402 })
403 .collect()
404}
405
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// Reads the stream to its end, seeks back into the first block, and reads to the end again.
    #[test]
    fn test_seek_to_virtual_position() -> Result<(), Box<dyn std::error::Error>> {
        use crate::io::Seek;

        // Two frames: 35 bytes followed by 28 bytes, 63 bytes in total.
        #[rustfmt::skip]
        static DATA: &[u8] = &[
            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
            0x02, 0x00, 0x22, 0x00, 0xcb, 0xcb, 0xcf, 0x4f, 0xc9, 0x49, 0x2d, 0x06, 0x00, 0xa1,
            0x58, 0x2a, 0x80, 0x07, 0x00, 0x00, 0x00,
            0x1f, 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x06, 0x00, 0x42, 0x43,
            0x02, 0x00, 0x1b, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];

        // Three bytes into the data of the first block.
        const VIRTUAL_POSITION: VirtualPosition = match VirtualPosition::new(0, 3) {
            Some(pos) => pos,
            None => unreachable!(),
        };

        // Compressed offset 63 is the end of `DATA`.
        const EOF_VIRTUAL_POSITION: VirtualPosition = match VirtualPosition::new(63, 0) {
            Some(pos) => pos,
            None => unreachable!(),
        };

        let worker_count = NonZeroUsize::MIN;
        let mut reader = MultithreadedReader::with_worker_count(worker_count, Cursor::new(DATA));

        let mut actual = Vec::new();
        reader.read_to_end(&mut actual)?;

        assert_eq!(reader.virtual_position(), EOF_VIRTUAL_POSITION);

        reader.seek_to_virtual_position(VIRTUAL_POSITION)?;

        actual.clear();
        reader.read_to_end(&mut actual)?;

        assert_eq!(actual, b"dles");
        assert_eq!(reader.virtual_position(), EOF_VIRTUAL_POSITION);

        Ok(())
    }
}
455}