// axum_extra/json_lines.rs

//! Newline delimited JSON extractor and response.

3use axum::{
4    body::Body,
5    extract::{FromRequest, Request},
6    response::{IntoResponse, Response},
7    BoxError,
8};
9use bytes::{BufMut, BytesMut};
10use futures_util::stream::{BoxStream, Stream, TryStream, TryStreamExt};
11use pin_project_lite::pin_project;
12use serde::{de::DeserializeOwned, Serialize};
13use std::{
14    convert::Infallible,
15    io::{self, Write},
16    marker::PhantomData,
17    pin::Pin,
18    task::{Context, Poll},
19};
20use tokio::io::AsyncBufReadExt;
21use tokio_stream::wrappers::LinesStream;
22use tokio_util::io::StreamReader;
23
24pin_project! {
25    /// A stream of newline delimited JSON.
26    ///
27    /// This can be used both as an extractor and as a response.
28    ///
29    /// # As extractor
30    ///
31    /// ```rust
32    /// use axum_extra::json_lines::JsonLines;
33    /// use futures_util::stream::StreamExt;
34    ///
35    /// async fn handler(mut stream: JsonLines<serde_json::Value>) {
36    ///     while let Some(value) = stream.next().await {
37    ///         // ...
38    ///     }
39    /// }
40    /// ```
41    ///
42    /// # As response
43    ///
44    /// ```rust
45    /// use axum::{BoxError, response::{IntoResponse, Response}};
46    /// use axum_extra::json_lines::JsonLines;
47    /// use futures_util::stream::Stream;
48    ///
49    /// fn stream_of_values() -> impl Stream<Item = Result<serde_json::Value, BoxError>> {
50    ///     # futures_util::stream::empty()
51    /// }
52    ///
53    /// async fn handler() -> Response {
54    ///     JsonLines::new(stream_of_values()).into_response()
55    /// }
56    /// ```
57    // we use `AsExtractor` as the default because you're more likely to name this type if it's used
58    // as an extractor
59    #[must_use]
60    pub struct JsonLines<S, T = AsExtractor> {
61        #[pin]
62        inner: Inner<S>,
63        _marker: PhantomData<T>,
64    }
65}
66
67pin_project! {
68    #[project = InnerProj]
69    enum Inner<S> {
70        Response {
71            #[pin]
72            stream: S,
73        },
74        Extractor {
75            #[pin]
76            stream: BoxStream<'static, Result<S, axum::Error>>,
77        },
78    }
79}
80
/// Type-level marker for a [`JsonLines`] that was produced by its `FromRequest` implementation.
///
/// Only `JsonLines` values carrying this marker implement `Stream`.
#[non_exhaustive]
#[derive(Debug)]
pub struct AsExtractor;
85
/// Type-level marker for a [`JsonLines`] that was produced by `JsonLines::new`.
///
/// Only `JsonLines` values carrying this marker implement `IntoResponse`.
#[non_exhaustive]
#[derive(Debug)]
pub struct AsResponse;
90
91impl<S> JsonLines<S, AsResponse> {
92    /// Create a new `JsonLines` from a stream of items.
93    pub fn new(stream: S) -> Self {
94        Self {
95            inner: Inner::Response { stream },
96            _marker: PhantomData,
97        }
98    }
99}
100
101impl<S, T> FromRequest<S> for JsonLines<T, AsExtractor>
102where
103    T: DeserializeOwned,
104    S: Send + Sync,
105{
106    type Rejection = Infallible;
107
108    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
109        // `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead`
110        // so we can call `AsyncRead::lines` and then convert it back to a `Stream`
111        let body = req.into_body();
112        let stream = body.into_data_stream();
113        let stream = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
114        let read = StreamReader::new(stream);
115        let lines_stream = LinesStream::new(read.lines());
116
117        let deserialized_stream =
118            lines_stream
119                .map_err(axum::Error::new)
120                .and_then(|value| async move {
121                    serde_json::from_str::<T>(&value).map_err(axum::Error::new)
122                });
123
124        Ok(Self {
125            inner: Inner::Extractor {
126                stream: Box::pin(deserialized_stream),
127            },
128            _marker: PhantomData,
129        })
130    }
131}
132
133impl<T> Stream for JsonLines<T, AsExtractor> {
134    type Item = Result<T, axum::Error>;
135
136    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
137        match self.project().inner.project() {
138            InnerProj::Extractor { stream } => stream.poll_next(cx),
139            // `JsonLines<_, AsExtractor>` can only be constructed via `FromRequest`
140            // which doesn't use this variant
141            InnerProj::Response { .. } => unreachable!(),
142        }
143    }
144}
145
146impl<S> IntoResponse for JsonLines<S, AsResponse>
147where
148    S: TryStream + Send + 'static,
149    S::Ok: Serialize + Send,
150    S::Error: Into<BoxError>,
151{
152    fn into_response(self) -> Response {
153        let inner = match self.inner {
154            Inner::Response { stream } => stream,
155            // `JsonLines<_, AsResponse>` can only be constructed via `JsonLines::new`
156            // which doesn't use this variant
157            Inner::Extractor { .. } => unreachable!(),
158        };
159
160        let stream = inner.map_err(Into::into).and_then(|value| async move {
161            let mut buf = BytesMut::new().writer();
162            serde_json::to_writer(&mut buf, &value)?;
163            buf.write_all(b"\n")?;
164            Ok::<_, BoxError>(buf.into_inner().freeze())
165        });
166        let stream = Body::from_stream(stream);
167
168        // there is no consensus around mime type yet
169        // https://github.com/wardi/jsonlines/issues/36
170        stream.into_response()
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{
        routing::{get, post},
        Router,
    };
    use futures_util::StreamExt;
    use http::StatusCode;
    use serde::Deserialize;
    use std::error::Error;

    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
    struct User {
        id: i32,
    }

    #[tokio::test]
    async fn extractor() {
        let app = Router::new().route(
            "/",
            post(|mut stream: JsonLines<User>| async move {
                assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 1 });
                assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 2 });
                assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 3 });

                // sources are downcastable to `serde_json::Error`
                let err = stream.next().await.unwrap().unwrap_err();
                let _: &serde_json::Error = err
                    .source()
                    .unwrap()
                    .downcast_ref::<serde_json::Error>()
                    .unwrap();
            }),
        );

        let client = TestClient::new(app);

        let res = client
            .post("/")
            .body(
                [
                    "{\"id\":1}",
                    "{\"id\":2}",
                    "{\"id\":3}",
                    // to trigger an error for source downcasting
                    "{\"id\":false}",
                ]
                .join("\n"),
            )
            .await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    // CRLF terminators and a trailing newline must be accepted, and a line that fails to
    // deserialize must not end the stream: later lines are still yielded.
    #[tokio::test]
    async fn extractor_handles_crlf_trailing_newline_and_bad_lines() {
        let app = Router::new().route(
            "/",
            post(|mut stream: JsonLines<User>| async move {
                assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 1 });
                assert!(stream.next().await.unwrap().is_err());
                assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 2 });
                // the trailing newline must not produce an extra (empty, failing) item
                assert!(stream.next().await.is_none());
            }),
        );

        let client = TestClient::new(app);

        let res = client
            .post("/")
            .body(String::from("{\"id\":1}\r\n{\"id\":false}\n{\"id\":2}\n"))
            .await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn response() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let values = futures_util::stream::iter(vec![
                    Ok::<_, Infallible>(User { id: 1 }),
                    Ok::<_, Infallible>(User { id: 2 }),
                    Ok::<_, Infallible>(User { id: 3 }),
                ]);
                JsonLines::new(values)
            }),
        );

        let client = TestClient::new(app);

        let res = client.get("/").await;

        let body = res.text().await;
        // every record, including the last, is newline terminated
        assert_eq!(body, "{\"id\":1}\n{\"id\":2}\n{\"id\":3}\n");

        let values = body
            .lines()
            .map(|line| serde_json::from_str::<User>(line).unwrap())
            .collect::<Vec<_>>();

        assert_eq!(
            values,
            vec![User { id: 1 }, User { id: 2 }, User { id: 3 }]
        );
    }
}
259}