1use axum::{
4 body::Body,
5 extract::{FromRequest, Request},
6 response::{IntoResponse, Response},
7 BoxError,
8};
9use bytes::{BufMut, BytesMut};
10use futures_util::stream::{BoxStream, Stream, TryStream, TryStreamExt};
11use pin_project_lite::pin_project;
12use serde::{de::DeserializeOwned, Serialize};
13use std::{
14 convert::Infallible,
15 io::{self, Write},
16 marker::PhantomData,
17 pin::Pin,
18 task::{Context, Poll},
19};
20use tokio::io::AsyncBufReadExt;
21use tokio_stream::wrappers::LinesStream;
22use tokio_util::io::StreamReader;
23
pin_project! {
    /// A stream of newline-delimited JSON values ("JSON lines").
    ///
    /// The second type parameter selects which way the type is used:
    ///
    /// - [`AsExtractor`] (the default): obtained through `FromRequest`. Every line of the
    ///   request body is deserialized into `S`. The values are consumed as a
    ///   `Stream<Item = Result<S, axum::Error>>`. Body I/O errors and JSON errors are
    ///   yielded as `Err` items; they do not reject the request.
    /// - [`AsResponse`]: built with [`JsonLines::new`] from a `TryStream`. Every item is
    ///   serialized to JSON and written to the response body, followed by a `\n`.
    #[must_use]
    pub struct JsonLines<S, T = AsExtractor> {
        // Holds either the caller's stream (response mode) or the boxed decoding
        // stream (extractor mode).
        #[pin]
        inner: Inner<S>,
        // Zero-sized tag recording which mode (`AsExtractor` / `AsResponse`) this value is in.
        _marker: PhantomData<T>,
    }
}
66
pin_project! {
    // Backing storage for the two modes of `JsonLines`. `pin_project!` generates the
    // pinned projection type `InnerProj`, which `poll_next` matches on.
    #[project = InnerProj]
    enum Inner<S> {
        // Response mode: the caller-supplied stream whose items get serialized.
        Response {
            #[pin]
            stream: S,
        },
        // Extractor mode: a boxed `Send + 'static` stream that decodes request-body
        // lines into `S`.
        Extractor {
            #[pin]
            stream: BoxStream<'static, Result<S, axum::Error>>,
        },
    }
}
80
/// Marker type selecting the extractor mode of [`JsonLines`].
///
/// `JsonLines<T, AsExtractor>` implements `FromRequest` and
/// `Stream<Item = Result<T, axum::Error>>`. Because it is `#[non_exhaustive]`, this
/// marker cannot be constructed outside the defining crate.
#[derive(Debug)]
#[non_exhaustive]
pub struct AsExtractor;
85
/// Marker type selecting the response mode of [`JsonLines`].
///
/// `JsonLines<S, AsResponse>` is created with [`JsonLines::new`] and implements
/// `IntoResponse`. Because it is `#[non_exhaustive]`, this marker cannot be constructed
/// outside the defining crate.
#[derive(Debug)]
#[non_exhaustive]
pub struct AsResponse;
90
91impl<S> JsonLines<S, AsResponse> {
92 pub fn new(stream: S) -> Self {
94 Self {
95 inner: Inner::Response { stream },
96 _marker: PhantomData,
97 }
98 }
99}
100
101impl<S, T> FromRequest<S> for JsonLines<T, AsExtractor>
102where
103 T: DeserializeOwned,
104 S: Send + Sync,
105{
106 type Rejection = Infallible;
107
108 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
109 let body = req.into_body();
112 let stream = body.into_data_stream();
113 let stream = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
114 let read = StreamReader::new(stream);
115 let lines_stream = LinesStream::new(read.lines());
116
117 let deserialized_stream =
118 lines_stream
119 .map_err(axum::Error::new)
120 .and_then(|value| async move {
121 serde_json::from_str::<T>(&value).map_err(axum::Error::new)
122 });
123
124 Ok(Self {
125 inner: Inner::Extractor {
126 stream: Box::pin(deserialized_stream),
127 },
128 _marker: PhantomData,
129 })
130 }
131}
132
133impl<T> Stream for JsonLines<T, AsExtractor> {
134 type Item = Result<T, axum::Error>;
135
136 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
137 match self.project().inner.project() {
138 InnerProj::Extractor { stream } => stream.poll_next(cx),
139 InnerProj::Response { .. } => unreachable!(),
142 }
143 }
144}
145
146impl<S> IntoResponse for JsonLines<S, AsResponse>
147where
148 S: TryStream + Send + 'static,
149 S::Ok: Serialize + Send,
150 S::Error: Into<BoxError>,
151{
152 fn into_response(self) -> Response {
153 let inner = match self.inner {
154 Inner::Response { stream } => stream,
155 Inner::Extractor { .. } => unreachable!(),
158 };
159
160 let stream = inner.map_err(Into::into).and_then(|value| async move {
161 let mut buf = BytesMut::new().writer();
162 serde_json::to_writer(&mut buf, &value)?;
163 buf.write_all(b"\n")?;
164 Ok::<_, BoxError>(buf.into_inner().freeze())
165 });
166 let stream = Body::from_stream(stream);
167
168 stream.into_response()
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{
        routing::{get, post},
        Router,
    };
    use futures_util::StreamExt;
    use http::StatusCode;
    use serde::Deserialize;
    use std::error::Error;

    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
    struct User {
        id: i32,
    }

    #[tokio::test]
    async fn extractor() {
        async fn handler(mut stream: JsonLines<User>) {
            for expected in 1..=3 {
                let user = stream.next().await.unwrap().unwrap();
                assert_eq!(user, User { id: expected });
            }

            // The last line has a non-integer `id`. It must come out as an error whose
            // `source()` is the underlying `serde_json::Error`.
            let err = stream.next().await.unwrap().unwrap_err();
            let source = err.source().unwrap();
            let _: &serde_json::Error = source.downcast_ref::<serde_json::Error>().unwrap();
        }

        let client = TestClient::new(Router::new().route("/", post(handler)));

        let body = ["{\"id\":1}", "{\"id\":2}", "{\"id\":3}", "{\"id\":false}"].join("\n");
        let res = client.post("/").body(body).await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn response() {
        let handler = || async {
            let users = (1..=3).map(|id| Ok::<_, Infallible>(User { id }));
            JsonLines::new(futures_util::stream::iter(users))
        };

        let client = TestClient::new(Router::new().route("/", get(handler)));

        let text = client.get("/").await.text().await;
        let users: Vec<User> = text
            .lines()
            .map(|line| serde_json::from_str::<User>(line).unwrap())
            .collect();

        assert_eq!(users, vec![User { id: 1 }, User { id: 2 }, User { id: 3 }]);
    }
}