1use std::future::{poll_fn, IntoFuture};
2
3use futures_util::StreamExt as _;
4use titan_core::{Respondable, Service};
5
6use titan_http::{
7 body::Body,
8 header::{HeaderValue, CONTENT_LENGTH},
9 HttpRequestExt, HttpResponseExt, Request,
10};
11use tokio::{
12 io::{self, AsyncWriteExt as _, BufReader},
13 net::{TcpListener, TcpStream},
14};
15
16use crate::utils::{self};
17
18pub fn serve<S>(listener: TcpListener, service: S) -> Serve<S>
87where
88 S: titan_core::Service<Request> + Send + Clone + 'static,
89 S::Future: Send,
90 S::Response: Respondable,
91 S::Error: Respondable,
92{
93 Serve { listener, service }
94}
95
/// A listener/service pair that has not started serving yet.
///
/// Values are produced by [`serve`]. Nothing happens until the value is
/// turned into a future (for example with `.await`), hence the `must_use`.
///
/// `Debug` is derived so the type can appear in diagnostics; the derived
/// impl is only available when `S` itself is `Debug`.
#[derive(Debug)]
#[must_use = "`Serve` does nothing unless it is awaited or polled"]
pub struct Serve<S> {
  /// Socket new connections are accepted from.
  listener: TcpListener,
  /// Service used to answer requests.
  service: S,
}
100
/// Returns `true` when `e` has one of the three connection-level kinds
/// (`ConnectionRefused`, `ConnectionAborted`, `ConnectionReset`), and
/// `false` for every other kind.
fn is_connection_error(e: &io::Error) -> bool {
  const CONNECTION_KINDS: [io::ErrorKind; 3] = [
    io::ErrorKind::ConnectionRefused,
    io::ErrorKind::ConnectionAborted,
    io::ErrorKind::ConnectionReset,
  ];

  CONNECTION_KINDS.contains(&e.kind())
}
109
110impl<S> Serve<S> {
111 async fn tcp_accept(listener: &TcpListener) -> Option<TcpStream> {
112 match listener.accept().await {
113 Ok(conn) => Some(conn.0),
114 Err(e) => {
115 if !is_connection_error(&e) {
116 eprintln!("Accept error: {e}");
117 }
118 None
119 }
120 }
121 }
122}
123
124impl<S> IntoFuture for Serve<S>
125where
126 S: Service<Request> + 'static + Send + Clone,
127 S::Future: Send,
128 S::Error: Respondable,
129 S::Response: Respondable,
130{
131 type Output = io::Result<()>;
132 type IntoFuture = private::ServeFuture;
133
134 fn into_future(self) -> Self::IntoFuture {
135 private::ServeFuture(Box::pin(async move {
136 let Self { mut service, listener } = self;
137 loop {
138 let mut tcp_stream = match Self::tcp_accept(&listener).await {
139 Some(conn) => conn,
140 None => continue,
141 };
142
143 if poll_fn(|cx| service.poll_ready(cx)).await.is_err() {
144 eprintln!("Skipping running because poll_ready failed");
145 continue;
146 }
147
148 let mut buf_reader = BufReader::new(&mut tcp_stream);
149 let http_headers =
150 utils::read_request(&mut buf_reader).await.join("\n");
151
152 let req_empty_body = HttpRequestExt::from(http_headers).0;
153 let body_length = req_empty_body
154 .headers()
155 .get(CONTENT_LENGTH)
156 .unwrap_or(&HeaderValue::from(0))
157 .to_str()
158 .unwrap()
159 .parse()
160 .unwrap();
161
162 let req =
163 utils::fill_req_body(req_empty_body, body_length, buf_reader).await;
164 let nice_service = service.clone();
165 let mut nice_service = std::mem::replace(&mut service, nice_service);
166 tokio::spawn(async move {
167 #[allow(unused_mut)]
168 let mut response = match nice_service.call(req).await {
169 Ok(result) => result.respond(),
170 Err(result) => result.respond(),
171 };
172
173 #[cfg(feature = "date-header")]
174 {
175 use titan_http::header::HeaderName;
176 response.headers_mut().extend([(
177 HeaderName::from_static("date"),
178 HeaderValue::from_str(&chrono::Utc::now()
179 .format("%a, %d %b %Y %H:%M:%S GMT")
180 .to_string()())
181 .unwrap(),
182 )]);
183 }
184
185 let (parts, body) = HttpResponseExt(response).parse_parts();
186
187 tcp_stream.write_all(parts.as_bytes()).await.unwrap();
188 tcp_stream.write_all(b"\r\n").await.unwrap();
189
190 match body {
191 Body::Full(body) => {
192 tcp_stream.write_all(&body).await.unwrap();
193 }
194 Body::Stream(stream) => {
195 futures_util::pin_mut!(stream);
196
197 while let Some(chunk) = stream.next().await {
198 tcp_stream.write_all(&chunk).await.unwrap();
199 tcp_stream.flush().await.unwrap();
200 }
201 tcp_stream.shutdown().await.unwrap();
202 }
203 }
204 });
205 }
206 }))
207 }
208}
209
mod private {
  use std::{
    fmt,
    future::Future,
    io,
    pin::Pin,
    task::{Context, Poll},
  };

  /// Boxed, type-erased future handed out by the server's `IntoFuture`
  /// impl.
  ///
  /// The type is `pub` but sits in a non-`pub` module, and its only field
  /// is `pub(super)`, so only the parent module can construct it.
  pub struct ServeFuture(
    pub(super) Pin<Box<dyn Future<Output = io::Result<()>> + 'static>>,
  );

  impl fmt::Debug for ServeFuture {
    /// Prints just the type name; the boxed future is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      let mut out = f.debug_struct("ServeFuture");
      out.finish_non_exhaustive()
    }
  }

  impl Future for ServeFuture {
    type Output = io::Result<()>;

    /// Forwards the poll to the boxed future.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
      // `Pin<Box<_>>` is `Unpin`, which makes `ServeFuture` `Unpin` and
      // `get_mut` available.
      let this = self.get_mut();
      this.0.as_mut().poll(cx)
    }
  }
}