axum_extra/response/
file_stream.rs

1use axum::{
2    body,
3    response::{IntoResponse, Response},
4    BoxError,
5};
6use bytes::Bytes;
7use futures_util::TryStream;
8use http::{header, StatusCode};
9use std::{io, path::Path};
10use tokio::{
11    fs::File,
12    io::{AsyncReadExt, AsyncSeekExt},
13};
14use tokio_util::io::ReaderStream;
15
/// A streaming response body for file downloads.
///
/// Wraps any [`TryStream`] of byte chunks (for example a [`ReaderStream`] over a
/// [`File`]). Its [`IntoResponse`] implementation sends the stream as the body with
/// `Content-Type: application/octet-stream`. It can optionally attach:
///
/// * a file name, sent as an attachment `Content-Disposition` header, and
/// * a content size, sent as the `Content-Length` header.
///
/// # Examples
///
/// ```
/// use axum::{
///     http::StatusCode,
///     response::{IntoResponse, Response},
///     routing::get,
///     Router,
/// };
/// use axum_extra::response::file_stream::FileStream;
/// use tokio::fs::File;
/// use tokio_util::io::ReaderStream;
///
/// async fn file_stream() -> Result<Response, (StatusCode, String)> {
///     let file = File::open("test.txt")
///         .await
///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))?;
///
///     let stream = ReaderStream::new(file);
///     let file_stream_resp = FileStream::new(stream).file_name("test.txt");
///
///     Ok(file_stream_resp.into_response())
/// }
///
/// let app = Router::new().route("/file-stream", get(file_stream));
/// # let _: Router = app;
/// ```
#[derive(Debug)]
pub struct FileStream<S> {
    /// The stream of byte chunks that forms the response body.
    pub stream: S,
    /// The file name. When set, it is sent as
    /// `Content-Disposition: attachment; filename="..."`.
    pub file_name: Option<String>,
    /// The size of the file in bytes. When set, it is sent as `Content-Length`.
    pub content_size: Option<u64>,
}
56
57impl<S> FileStream<S>
58where
59    S: TryStream + Send + 'static,
60    S::Ok: Into<Bytes>,
61    S::Error: Into<BoxError>,
62{
63    /// Create a new [`FileStream`]
64    pub fn new(stream: S) -> Self {
65        Self {
66            stream,
67            file_name: None,
68            content_size: None,
69        }
70    }
71
72    /// Create a [`FileStream`] from a file path.
73    ///
74    /// # Examples
75    ///
76    /// ```
77    /// use axum::{
78    ///     http::StatusCode,
79    ///     response::IntoResponse,
80    ///     Router,
81    ///     routing::get
82    /// };
83    /// use axum_extra::response::file_stream::FileStream;
84    /// use tokio::fs::File;
85    /// use tokio_util::io::ReaderStream;
86    ///
87    /// async fn file_stream() -> impl IntoResponse {
88    ///     FileStream::<ReaderStream<File>>::from_path("test.txt")
89    ///         .await
90    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))
91    /// }
92    ///
93    /// let app = Router::new().route("/file-stream", get(file_stream));
94    /// # let _: Router = app;
95    /// ```
96    pub async fn from_path(path: impl AsRef<Path>) -> io::Result<FileStream<ReaderStream<File>>> {
97        let file = File::open(&path).await?;
98        let mut content_size = None;
99        let mut file_name = None;
100
101        if let Ok(metadata) = file.metadata().await {
102            content_size = Some(metadata.len());
103        }
104
105        if let Some(file_name_os) = path.as_ref().file_name() {
106            if let Some(file_name_str) = file_name_os.to_str() {
107                file_name = Some(file_name_str.to_owned());
108            }
109        }
110
111        Ok(FileStream {
112            stream: ReaderStream::new(file),
113            file_name,
114            content_size,
115        })
116    }
117
118    /// Set the file name of the [`FileStream`].
119    ///
120    /// This adds the attachment `Content-Disposition` header with the given `file_name`.
121    pub fn file_name(mut self, file_name: impl Into<String>) -> Self {
122        self.file_name = Some(file_name.into());
123        self
124    }
125
126    /// Set the size of the file.
127    pub fn content_size(mut self, len: u64) -> Self {
128        self.content_size = Some(len);
129        self
130    }
131
132    /// Return a range response.
133    ///
134    /// range: (start, end, total_size)
135    ///
136    /// # Examples
137    ///
138    /// ```
139    /// use axum::{
140    ///     http::StatusCode,
141    ///     response::IntoResponse,
142    ///     routing::get,
143    ///     Router,
144    /// };
145    /// use axum_extra::response::file_stream::FileStream;
146    /// use tokio::fs::File;
147    /// use tokio::io::AsyncSeekExt;
148    /// use tokio_util::io::ReaderStream;
149    ///
150    /// async fn range_response() -> Result<impl IntoResponse, (StatusCode, String)> {
151    ///     let mut file = File::open("test.txt")
152    ///         .await
153    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))?;
154    ///     let mut file_size = file
155    ///         .metadata()
156    ///         .await
157    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("Get file size: {e}")))?
158    ///         .len();
159    ///
160    ///     file.seek(std::io::SeekFrom::Start(10))
161    ///         .await
162    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File seek error: {e}")))?;
163    ///     let stream = ReaderStream::new(file);
164    ///
165    ///     Ok(FileStream::new(stream).into_range_response(10, file_size - 1, file_size))
166    /// }
167    ///
168    /// let app = Router::new().route("/file-stream", get(range_response));
169    /// # let _: Router = app;
170    /// ```
171    pub fn into_range_response(self, start: u64, end: u64, total_size: u64) -> Response {
172        let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
173        resp = resp.status(StatusCode::PARTIAL_CONTENT);
174
175        resp = resp.header(
176            header::CONTENT_RANGE,
177            format!("bytes {start}-{end}/{total_size}"),
178        );
179
180        resp.body(body::Body::from_stream(self.stream))
181            .unwrap_or_else(|e| {
182                (
183                    StatusCode::INTERNAL_SERVER_ERROR,
184                    format!("build FileStream response error: {e}"),
185                )
186                    .into_response()
187            })
188    }
189
190    /// Attempts to return RANGE requests directly from the file path.
191    ///
192    /// # Arguments
193    ///
194    /// * `file_path` - The path of the file to be streamed
195    /// * `start` - The start position of the range
196    /// * `end` - The end position of the range
197    ///
198    /// # Note
199    ///
200    /// * If `end` is 0, then it is used as `file_size - 1`
201    /// * If `start` > `file_size` or `start` > `end`, then `Range Not Satisfiable` is returned
202    ///
203    /// # Examples
204    ///
205    /// ```
206    /// use axum::{
207    ///     http::StatusCode,
208    ///     response::IntoResponse,
209    ///     Router,
210    ///     routing::get
211    /// };
212    /// use std::path::Path;
213    /// use axum_extra::response::file_stream::FileStream;
214    /// use tokio::fs::File;
215    /// use tokio_util::io::ReaderStream;
216    /// use tokio::io::AsyncSeekExt;
217    ///
218    /// async fn range_stream() -> impl IntoResponse {
219    ///     let range_start = 0;
220    ///     let range_end = 1024;
221    ///
222    ///     FileStream::<ReaderStream<File>>::try_range_response("CHANGELOG.md", range_start, range_end).await
223    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))
224    /// }
225    ///
226    /// let app = Router::new().route("/file-stream", get(range_stream));
227    /// # let _: Router = app;
228    /// ```
229    pub async fn try_range_response(
230        file_path: impl AsRef<Path>,
231        start: u64,
232        mut end: u64,
233    ) -> io::Result<Response> {
234        let mut file = File::open(file_path).await?;
235
236        let metadata = file.metadata().await?;
237        let total_size = metadata.len();
238
239        if end == 0 {
240            end = total_size - 1;
241        }
242
243        if start > total_size {
244            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
245        }
246        if start > end {
247            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
248        }
249        if end >= total_size {
250            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
251        }
252
253        file.seek(std::io::SeekFrom::Start(start)).await?;
254
255        let stream = ReaderStream::new(file.take(end - start + 1));
256
257        Ok(FileStream::new(stream).into_range_response(start, end, total_size))
258    }
259}
260
261impl<S> IntoResponse for FileStream<S>
262where
263    S: TryStream + Send + 'static,
264    S::Ok: Into<Bytes>,
265    S::Error: Into<BoxError>,
266{
267    fn into_response(self) -> Response {
268        let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
269
270        if let Some(file_name) = self.file_name {
271            resp = resp.header(
272                header::CONTENT_DISPOSITION,
273                format!("attachment; filename=\"{file_name}\""),
274            );
275        }
276
277        if let Some(content_size) = self.content_size {
278            resp = resp.header(header::CONTENT_LENGTH, content_size);
279        }
280
281        resp.body(body::Body::from_stream(self.stream))
282            .unwrap_or_else(|e| {
283                (
284                    StatusCode::INTERNAL_SERVER_ERROR,
285                    format!("build FileStream responsec error: {e}"),
286                )
287                    .into_response()
288            })
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{extract::Request, routing::get, Router};
    use body::Body;
    use http::HeaderMap;
    use http_body_util::BodyExt;
    use std::io::Cursor;
    use tokio_util::io::ReaderStream;
    use tower::ServiceExt;

    #[tokio::test]
    async fn response() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                // Simulating a file stream
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let reader = Cursor::new(file_content);

                // Response file stream
                // Content size and file name are not attached by default
                let stream = ReaderStream::new(reader);
                FileStream::new(stream).into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_filename() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                // Simulating a file stream
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let size = file_content.len() as u64;
                let reader = Cursor::new(file_content);

                // Response file stream
                let stream = ReaderStream::new(reader);
                FileStream::new(stream).content_size(size).into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_content_size() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                // Simulating a file stream
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let reader = Cursor::new(file_content);

                // Response file stream
                let stream = ReaderStream::new(reader);
                FileStream::new(stream).file_name("test").into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_with_content_size_and_filename() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                // Simulating a file stream
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let size = file_content.len() as u64;
                let reader = Cursor::new(file_content);

                // Response file stream
                let stream = ReaderStream::new(reader);
                FileStream::new(stream)
                    .file_name("test")
                    .content_size(size)
                    .into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_from_path() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/from_path",
            get(|| async {
                FileStream::<ReaderStream<File>>::from_path(Path::new("CHANGELOG.md"))
                    .await
                    .unwrap()
                    .into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/from_path").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"CHANGELOG.md\""
        );

        let file_content = tokio::fs::read("CHANGELOG.md").await?;
        assert_eq!(
            response.headers().get("content-length").unwrap().to_str()?,
            file_content.len().to_string()
        );

        // The streamed body must be the whole file, byte for byte.
        let body = response.into_body().collect().await?.to_bytes();
        assert_eq!(&body[..], &file_content[..]);
        Ok(())
    }

    #[tokio::test]
    async fn response_range_file() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route("/range_response", get(range_stream));

        // Simulating a GET request
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_response")
                    .header(header::RANGE, "bytes=20-1000")
                    .body(Body::empty())?,
            )
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        let file_content = tokio::fs::read("CHANGELOG.md").await?;
        let total_size = file_content.len();
        assert_eq!(
            response.headers().get("content-range").unwrap().to_str()?,
            format!("bytes 20-1000/{total_size}")
        );

        // Both ends of the range are inclusive: bytes 20 through 1000.
        let body = response.into_body().collect().await?.to_bytes();
        assert_eq!(&body[..], &file_content[20..=1000]);
        Ok(())
    }

    #[tokio::test]
    async fn response_range_open_ended() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route("/range_response", get(range_stream));

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_response")
                    .header(header::RANGE, "bytes=20-")
                    .body(Body::empty())?,
            )
            .await?;

        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);

        let file_content = tokio::fs::read("CHANGELOG.md").await?;
        let total_size = file_content.len();
        assert_eq!(
            response.headers().get("content-range").unwrap().to_str()?,
            format!("bytes 20-{}/{total_size}", total_size - 1)
        );

        // An open-ended range runs through the last byte of the file.
        let body = response.into_body().collect().await?.to_bytes();
        assert_eq!(&body[..], &file_content[20..]);
        Ok(())
    }

    #[tokio::test]
    async fn response_range_past_end() -> Result<(), Box<dyn std::error::Error>> {
        let total_size = tokio::fs::metadata("CHANGELOG.md").await?.len();
        let app = Router::new().route("/range_response", get(range_stream));

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_response")
                    // The last valid offset is `total_size - 1`, so this is one past it.
                    .header(header::RANGE, format!("bytes=0-{total_size}"))
                    .body(Body::empty())?,
            )
            .await?;

        assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE);
        Ok(())
    }

    async fn range_stream(headers: HeaderMap) -> Response {
        let range_header = headers
            .get(header::RANGE)
            .and_then(|value| value.to_str().ok());

        // Without a Range header, request (0, 0). `try_range_response` treats
        // `end == 0` as "through the last byte", so the whole file is served.
        let (start, end) = match range_header {
            None => (0, 0),
            Some(range) => match parse_range_header(range) {
                Some(range) => range,
                None => {
                    return (StatusCode::RANGE_NOT_SATISFIABLE, "Invalid Range").into_response()
                }
            },
        };

        FileStream::<ReaderStream<File>>::try_range_response(Path::new("CHANGELOG.md"), start, end)
            .await
            .unwrap()
    }

    /// Parses a single `bytes=start-end` or open-ended `bytes=start-` range.
    ///
    /// For an open-ended range the returned end is `0`, which `try_range_response`
    /// interprets as "through the last byte". Returns `None` for malformed input or
    /// `start > end`.
    fn parse_range_header(range: &str) -> Option<(u64, u64)> {
        let range = range.strip_prefix("bytes=")?;
        let (start, end) = range.split_once('-')?;
        let start = start.parse::<u64>().ok()?;
        match end {
            "" => Some((start, 0)),
            end => {
                let end = end.parse::<u64>().ok()?;
                (start <= end).then_some((start, end))
            }
        }
    }
}