tower_http/follow_redirect/
mod.rs1pub mod policy;
96
97use self::policy::{Action, Attempt, Policy, Standard};
98use futures_util::future::Either;
99use http::{
100 header::LOCATION, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Uri, Version,
101};
102use http_body::Body;
103use iri_string::types::{UriAbsoluteString, UriReferenceStr};
104use pin_project_lite::pin_project;
105use std::{
106 convert::TryFrom,
107 future::Future,
108 mem,
109 pin::Pin,
110 str,
111 task::{ready, Context, Poll},
112};
113use tower::util::Oneshot;
114use tower_layer::Layer;
115use tower_service::Service;
116
117#[derive(Clone, Copy, Debug, Default)]
121pub struct FollowRedirectLayer<P = Standard> {
122 policy: P,
123}
124
125impl FollowRedirectLayer {
126 pub fn new() -> Self {
128 Self::default()
129 }
130}
131
132impl<P> FollowRedirectLayer<P> {
133 pub fn with_policy(policy: P) -> Self {
135 FollowRedirectLayer { policy }
136 }
137}
138
139impl<S, P> Layer<S> for FollowRedirectLayer<P>
140where
141 S: Clone,
142 P: Clone,
143{
144 type Service = FollowRedirect<S, P>;
145
146 fn layer(&self, inner: S) -> Self::Service {
147 FollowRedirect::with_policy(inner, self.policy.clone())
148 }
149}
150
151#[derive(Clone, Copy, Debug)]
155pub struct FollowRedirect<S, P = Standard> {
156 inner: S,
157 policy: P,
158}
159
160impl<S> FollowRedirect<S> {
161 pub fn new(inner: S) -> Self {
163 Self::with_policy(inner, Standard::default())
164 }
165
166 pub fn layer() -> FollowRedirectLayer {
170 FollowRedirectLayer::new()
171 }
172}
173
174impl<S, P> FollowRedirect<S, P>
175where
176 P: Clone,
177{
178 pub fn with_policy(inner: S, policy: P) -> Self {
180 FollowRedirect { inner, policy }
181 }
182
183 pub fn layer_with_policy(policy: P) -> FollowRedirectLayer<P> {
188 FollowRedirectLayer::with_policy(policy)
189 }
190
191 define_inner_service_accessors!();
192}
193
194impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirect<S, P>
195where
196 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
197 ReqBody: Body + Default,
198 P: Policy<ReqBody, S::Error> + Clone,
199{
200 type Response = Response<ResBody>;
201 type Error = S::Error;
202 type Future = ResponseFuture<S, ReqBody, P>;
203
204 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
205 self.inner.poll_ready(cx)
206 }
207
208 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
209 let service = self.inner.clone();
210 let mut service = mem::replace(&mut self.inner, service);
211 let mut policy = self.policy.clone();
212 let mut body = BodyRepr::None;
213 body.try_clone_from(req.body(), &policy);
214 policy.on_request(&mut req);
215 ResponseFuture {
216 method: req.method().clone(),
217 uri: req.uri().clone(),
218 version: req.version(),
219 headers: req.headers().clone(),
220 body,
221 future: Either::Left(service.call(req)),
222 service,
223 policy,
224 }
225 }
226}
227
pin_project! {
    /// Response future for [`FollowRedirect`].
    ///
    /// Keeps everything needed to rebuild and resend the request when a redirect is followed.
    #[derive(Debug)]
    pub struct ResponseFuture<S, B, P>
    where
        S: Service<Request<B>>,
    {
        // Request in flight: the inner service's future for the first request, or a
        // `Oneshot` over a cloned service for each follow-up request.
        #[pin]
        future: Either<S::Future, Oneshot<S, Request<B>>>,
        // Service that is cloned to send each follow-up request.
        service: S,
        // Redirect policy consulted for every redirect response.
        policy: P,
        // Method for the next request; the 301/302/303 handling may rewrite it to GET.
        method: Method,
        // URI of the request currently in flight.
        uri: Uri,
        version: Version,
        // Request headers copied onto each follow-up request.
        headers: HeaderMap<HeaderValue>,
        // Replayable body for the next follow-up request, if any.
        body: BodyRepr<B>,
    }
}
246
247impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P>
248where
249 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
250 ReqBody: Body + Default,
251 P: Policy<ReqBody, S::Error>,
252{
253 type Output = Result<Response<ResBody>, S::Error>;
254
255 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
256 let mut this = self.project();
257 let mut res = ready!(this.future.as_mut().poll(cx)?);
258 res.extensions_mut().insert(RequestUri(this.uri.clone()));
259
260 match res.status() {
261 StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
262 if *this.method == Method::POST {
265 *this.method = Method::GET;
266 *this.body = BodyRepr::Empty;
267 }
268 }
269 StatusCode::SEE_OTHER => {
270 if *this.method != Method::HEAD {
272 *this.method = Method::GET;
273 }
274 *this.body = BodyRepr::Empty;
275 }
276 StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
277 _ => return Poll::Ready(Ok(res)),
278 };
279
280 let body = if let Some(body) = this.body.take() {
281 body
282 } else {
283 return Poll::Ready(Ok(res));
284 };
285
286 let location = res
287 .headers()
288 .get(&LOCATION)
289 .and_then(|loc| resolve_uri(str::from_utf8(loc.as_bytes()).ok()?, this.uri));
290 let location = if let Some(loc) = location {
291 loc
292 } else {
293 return Poll::Ready(Ok(res));
294 };
295
296 let attempt = Attempt {
297 status: res.status(),
298 location: &location,
299 previous: this.uri,
300 };
301 match this.policy.redirect(&attempt)? {
302 Action::Follow => {
303 *this.uri = location;
304 this.body.try_clone_from(&body, &this.policy);
305
306 let mut req = Request::new(body);
307 *req.uri_mut() = this.uri.clone();
308 *req.method_mut() = this.method.clone();
309 *req.version_mut() = *this.version;
310 *req.headers_mut() = this.headers.clone();
311 this.policy.on_request(&mut req);
312 this.future
313 .set(Either::Right(Oneshot::new(this.service.clone(), req)));
314
315 cx.waker().wake_by_ref();
316 Poll::Pending
317 }
318 Action::Stop => Poll::Ready(Ok(res)),
319 }
320 }
321}
322
323#[derive(Clone)]
329pub struct RequestUri(pub Uri);
330
/// Replayable state of the request body, carried between redirect hops.
#[derive(Debug)]
enum BodyRepr<B> {
    /// No body is available for a follow-up request: it was already handed out, or could
    /// not be cloned.
    None,
    /// The body was replaced by an empty one; `B::default()` is produced on demand.
    Empty,
    /// A clone of the body, ready to be sent with the next follow-up request.
    Some(B),
}
337
338impl<B> BodyRepr<B>
339where
340 B: Body + Default,
341{
342 fn take(&mut self) -> Option<B> {
343 match mem::replace(self, BodyRepr::None) {
344 BodyRepr::Some(body) => Some(body),
345 BodyRepr::Empty => {
346 *self = BodyRepr::Empty;
347 Some(B::default())
348 }
349 BodyRepr::None => None,
350 }
351 }
352
353 fn try_clone_from<P, E>(&mut self, body: &B, policy: &P)
354 where
355 P: Policy<B, E>,
356 {
357 match self {
358 BodyRepr::Some(_) | BodyRepr::Empty => {}
359 BodyRepr::None => {
360 if let Some(body) = clone_body(policy, body) {
361 *self = BodyRepr::Some(body);
362 }
363 }
364 }
365 }
366}
367
368fn clone_body<P, B, E>(policy: &P, body: &B) -> Option<B>
369where
370 P: Policy<B, E>,
371 B: Body + Default,
372{
373 if body.size_hint().exact() == Some(0) {
374 Some(B::default())
375 } else {
376 policy.clone_body(body)
377 }
378}
379
380fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
382 let relative = UriReferenceStr::new(relative).ok()?;
383 let base = UriAbsoluteString::try_from(base.to_string()).ok()?;
384 let uri = relative.resolve_against(&base).to_string();
385 Uri::try_from(uri).ok()
386}
387
#[cfg(test)]
mod tests {
    use super::{policy::*, *};
    use crate::test_helpers::Body;
    use http::header::LOCATION;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// Builds a request for `uri` with an empty body.
    fn request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn follows() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Follow))
            .buffer(1)
            .service_fn(handle);
        let res = svc.oneshot(request("http://example.com/42")).await.unwrap();
        // Every hop is followed until `/0`, which does not redirect.
        assert_eq!(*res.body(), 0);
        assert_eq!(
            res.extensions().get::<RequestUri>().unwrap().0,
            "http://example.com/0"
        );
    }

    #[tokio::test]
    async fn stops() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Stop))
            .buffer(1)
            .service_fn(handle);
        let res = svc.oneshot(request("http://example.com/42")).await.unwrap();
        // The first redirect response is returned untouched.
        assert_eq!(*res.body(), 42);
        assert_eq!(
            res.extensions().get::<RequestUri>().unwrap().0,
            "http://example.com/42"
        );
    }

    #[tokio::test]
    async fn limited() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Limited::new(10)))
            .buffer(1)
            .service_fn(handle);
        let res = svc.oneshot(request("http://example.com/42")).await.unwrap();
        // Exactly ten redirects are followed before the limit stops the chain.
        assert_eq!(*res.body(), 42 - 10);
        assert_eq!(
            res.extensions().get::<RequestUri>().unwrap().0,
            "http://example.com/32"
        );
    }

    /// Test endpoint: `/n` answers `301 Moved Permanently` pointing at `/{n - 1}` while `n` is
    /// non-zero. The response body is always `n`.
    async fn handle<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
        let n: u64 = req.uri().path()[1..].parse().unwrap();
        let builder = if n == 0 {
            Response::builder()
        } else {
            Response::builder()
                .status(StatusCode::MOVED_PERMANENTLY)
                .header(LOCATION, format!("/{}", n - 1))
        };
        Ok(builder.body(n).unwrap())
    }
}