1use axum::{
2 body,
3 response::{IntoResponse, Response},
4 BoxError,
5};
6use bytes::Bytes;
7use futures_util::TryStream;
8use http::{header, StatusCode};
9use std::{io, path::Path};
10use tokio::{
11 fs::File,
12 io::{AsyncReadExt, AsyncSeekExt},
13};
14use tokio_util::io::ReaderStream;
15
/// A response body that streams a file (or any fallible byte stream) to the client.
///
/// It is turned into a `200 OK` response through its [`IntoResponse`] impl, or into a
/// `206 Partial Content` response through [`FileStream::into_range_response`]. Either way
/// the content type is `application/octet-stream`.
#[derive(Debug)]
pub struct FileStream<S> {
    /// The stream whose items become the response body.
    pub stream: S,
    /// Download name. When set, the [`IntoResponse`] impl sends it as
    /// `Content-Disposition: attachment; filename="…"`.
    pub file_name: Option<String>,
    /// Body length in bytes. When set, the [`IntoResponse`] impl sends it as `Content-Length`.
    pub content_size: Option<u64>,
}
56
57impl<S> FileStream<S>
58where
59 S: TryStream + Send + 'static,
60 S::Ok: Into<Bytes>,
61 S::Error: Into<BoxError>,
62{
63 pub fn new(stream: S) -> Self {
65 Self {
66 stream,
67 file_name: None,
68 content_size: None,
69 }
70 }
71
72 pub async fn from_path(path: impl AsRef<Path>) -> io::Result<FileStream<ReaderStream<File>>> {
97 let file = File::open(&path).await?;
98 let mut content_size = None;
99 let mut file_name = None;
100
101 if let Ok(metadata) = file.metadata().await {
102 content_size = Some(metadata.len());
103 }
104
105 if let Some(file_name_os) = path.as_ref().file_name() {
106 if let Some(file_name_str) = file_name_os.to_str() {
107 file_name = Some(file_name_str.to_owned());
108 }
109 }
110
111 Ok(FileStream {
112 stream: ReaderStream::new(file),
113 file_name,
114 content_size,
115 })
116 }
117
118 pub fn file_name(mut self, file_name: impl Into<String>) -> Self {
122 self.file_name = Some(file_name.into());
123 self
124 }
125
126 pub fn content_size(mut self, len: u64) -> Self {
128 self.content_size = Some(len);
129 self
130 }
131
132 pub fn into_range_response(self, start: u64, end: u64, total_size: u64) -> Response {
172 let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
173 resp = resp.status(StatusCode::PARTIAL_CONTENT);
174
175 resp = resp.header(
176 header::CONTENT_RANGE,
177 format!("bytes {start}-{end}/{total_size}"),
178 );
179
180 resp.body(body::Body::from_stream(self.stream))
181 .unwrap_or_else(|e| {
182 (
183 StatusCode::INTERNAL_SERVER_ERROR,
184 format!("build FileStream response error: {e}"),
185 )
186 .into_response()
187 })
188 }
189
190 pub async fn try_range_response(
230 file_path: impl AsRef<Path>,
231 start: u64,
232 mut end: u64,
233 ) -> io::Result<Response> {
234 let mut file = File::open(file_path).await?;
235
236 let metadata = file.metadata().await?;
237 let total_size = metadata.len();
238
239 if end == 0 {
240 end = total_size - 1;
241 }
242
243 if start > total_size {
244 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
245 }
246 if start > end {
247 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
248 }
249 if end >= total_size {
250 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
251 }
252
253 file.seek(std::io::SeekFrom::Start(start)).await?;
254
255 let stream = ReaderStream::new(file.take(end - start + 1));
256
257 Ok(FileStream::new(stream).into_range_response(start, end, total_size))
258 }
259}
260
261impl<S> IntoResponse for FileStream<S>
262where
263 S: TryStream + Send + 'static,
264 S::Ok: Into<Bytes>,
265 S::Error: Into<BoxError>,
266{
267 fn into_response(self) -> Response {
268 let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
269
270 if let Some(file_name) = self.file_name {
271 resp = resp.header(
272 header::CONTENT_DISPOSITION,
273 format!("attachment; filename=\"{file_name}\""),
274 );
275 }
276
277 if let Some(content_size) = self.content_size {
278 resp = resp.header(header::CONTENT_LENGTH, content_size);
279 }
280
281 resp.body(body::Body::from_stream(self.stream))
282 .unwrap_or_else(|e| {
283 (
284 StatusCode::INTERNAL_SERVER_ERROR,
285 format!("build FileStream responsec error: {e}"),
286 )
287 .into_response()
288 })
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{extract::Request, routing::get, Router};
    use body::Body;
    use http::HeaderMap;
    use http_body_util::BodyExt;
    use std::io::Cursor;
    use tokio_util::io::ReaderStream;
    use tower::ServiceExt;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Payload served by the in-memory stream tests (42 bytes long).
    const FILE_CONTENT: &str = "Hello, this is the simulated file content!";

    /// On-disk file served by the path and range tests, relative to the test's
    /// working directory.
    const TEST_FILE: &str = "CHANGELOG.md";

    /// A fresh stream over [`FILE_CONTENT`].
    fn memory_stream() -> ReaderStream<Cursor<&'static [u8]>> {
        ReaderStream::new(Cursor::new(FILE_CONTENT.as_bytes()))
    }

    /// Sends `GET uri` to `app`, adding a `Range` header when `range` is given.
    async fn send(
        app: Router,
        uri: &str,
        range: Option<&str>,
    ) -> Result<Response, Box<dyn std::error::Error>> {
        let mut request = Request::builder().uri(uri);
        if let Some(range) = range {
            request = request.header(header::RANGE, range);
        }
        Ok(app.oneshot(request.body(Body::empty())?).await?)
    }

    /// Collects the whole response body into memory.
    async fn body_bytes(response: Response) -> Result<Bytes, Box<dyn std::error::Error>> {
        Ok(response.into_body().collect().await?.to_bytes())
    }

    #[tokio::test]
    async fn response() -> TestResult {
        let app = Router::new().route(
            "/file",
            get(|| async { FileStream::new(memory_stream()).into_response() }),
        );

        let response = send(app, "/file", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert!(response.headers().get("content-disposition").is_none());

        let body = body_bytes(response).await?;
        assert_eq!(&body[..], FILE_CONTENT.as_bytes());
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_filename() -> TestResult {
        let app = Router::new().route(
            "/file",
            get(|| async {
                FileStream::new(memory_stream())
                    .content_size(FILE_CONTENT.len() as u64)
                    .into_response()
            }),
        );

        let response = send(app, "/file", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");
        assert!(response.headers().get("content-disposition").is_none());

        let body = body_bytes(response).await?;
        assert_eq!(&body[..], FILE_CONTENT.as_bytes());
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_content_size() -> TestResult {
        let app = Router::new().route(
            "/file",
            get(|| async { FileStream::new(memory_stream()).file_name("test").into_response() }),
        );

        let response = send(app, "/file", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );

        let body = body_bytes(response).await?;
        assert_eq!(&body[..], FILE_CONTENT.as_bytes());
        Ok(())
    }

    #[tokio::test]
    async fn response_with_content_size_and_filename() -> TestResult {
        let app = Router::new().route(
            "/file",
            get(|| async {
                FileStream::new(memory_stream())
                    .file_name("test")
                    .content_size(FILE_CONTENT.len() as u64)
                    .into_response()
            }),
        );

        let response = send(app, "/file", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        let body = body_bytes(response).await?;
        assert_eq!(&body[..], FILE_CONTENT.as_bytes());
        Ok(())
    }

    #[tokio::test]
    async fn response_from_path() -> TestResult {
        let app = Router::new().route(
            "/from_path",
            get(|| async {
                FileStream::<ReaderStream<File>>::from_path(Path::new(TEST_FILE))
                    .await
                    .unwrap()
                    .into_response()
            }),
        );

        let response = send(app, "/from_path", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"CHANGELOG.md\""
        );

        let file_bytes = tokio::fs::read(TEST_FILE).await?;
        assert_eq!(
            response.headers().get("content-length").unwrap().to_str()?,
            file_bytes.len().to_string()
        );

        let body = body_bytes(response).await?;
        assert_eq!(&body[..], &file_bytes[..]);
        Ok(())
    }

    #[tokio::test]
    async fn response_range_file() -> TestResult {
        let app = Router::new().route("/range_response", get(range_stream));

        let response = send(app, "/range_response", Some("bytes=20-1000")).await?;

        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        let file_bytes = tokio::fs::read(TEST_FILE).await?;
        let content_length = file_bytes.len();
        assert_eq!(
            response.headers().get("content-range").unwrap().to_str()?,
            format!("bytes 20-1000/{content_length}")
        );

        let body = body_bytes(response).await?;
        assert_eq!(&body[..], &file_bytes[20..=1000]);
        Ok(())
    }

    #[tokio::test]
    async fn response_range_open_ended() -> TestResult {
        let app = Router::new().route("/range_response", get(range_stream));

        let response = send(app, "/range_response", Some("bytes=20-")).await?;

        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);

        let file_bytes = tokio::fs::read(TEST_FILE).await?;
        let total = file_bytes.len();
        assert_eq!(
            response.headers().get("content-range").unwrap().to_str()?,
            format!("bytes 20-{}/{total}", total - 1)
        );

        let body = body_bytes(response).await?;
        assert_eq!(&body[..], &file_bytes[20..]);
        Ok(())
    }

    /// Handler serving [`TEST_FILE`], honouring an optional `Range` header.
    async fn range_stream(headers: HeaderMap) -> Response {
        let range = headers
            .get(header::RANGE)
            .and_then(|value| value.to_str().ok());

        let (start, end) = match range {
            // No usable `Range` header: `(0, 0)` asks for the whole file.
            None => (0, 0),
            Some(range) => match parse_range_header(range) {
                Some(bounds) => bounds,
                None => {
                    return (StatusCode::RANGE_NOT_SATISFIABLE, "Invalid Range").into_response()
                }
            },
        };

        FileStream::<ReaderStream<File>>::try_range_response(Path::new(TEST_FILE), start, end)
            .await
            .unwrap()
    }

    /// Parses a single `bytes=start-end` range.
    ///
    /// An empty end (`bytes=20-`) maps to `0`, which `try_range_response` reads as
    /// "through the end of the file". Suffix ranges (`bytes=-500`), multi-range lists,
    /// and explicit ranges with `start > end` are rejected.
    fn parse_range_header(range: &str) -> Option<(u64, u64)> {
        let (start, end) = range.strip_prefix("bytes=")?.split_once('-')?;
        let start = start.parse::<u64>().ok()?;
        let end = match end {
            "" => return Some((start, 0)),
            end => end.parse::<u64>().ok()?,
        };
        (start <= end).then_some((start, end))
    }
}