tower_http/follow_redirect/
mod.rs

1//! Middleware for following redirections.
2//!
3//! # Overview
4//!
5//! The [`FollowRedirect`] middleware retries requests with the inner [`Service`] to follow HTTP
6//! redirections.
7//!
8//! The middleware tries to clone the original [`Request`] when making a redirected request.
9//! However, since [`Extensions`][http::Extensions] are `!Clone`, any extensions set by outer
10//! middleware will be discarded. Also, the request body cannot always be cloned. When the
//! original body is known to be empty by [`Body::size_hint`], the middleware uses the `Default`
//! implementation of the body type to create a new request body. If you know that the body can be
13//! cloned in some way, you can tell the middleware to clone it by configuring a [`policy`].
14//!
15//! # Examples
16//!
17//! ## Basic usage
18//!
19//! ```
20//! use http::{Request, Response};
21//! use bytes::Bytes;
22//! use http_body_util::Full;
23//! use tower::{Service, ServiceBuilder, ServiceExt};
24//! use tower_http::follow_redirect::{FollowRedirectLayer, RequestUri};
25//!
26//! # #[tokio::main]
27//! # async fn main() -> Result<(), std::convert::Infallible> {
28//! # let http_client = tower::service_fn(|req: Request<_>| async move {
29//! #     let dest = "https://www.rust-lang.org/";
30//! #     let mut res = http::Response::builder();
31//! #     if req.uri() != dest {
32//! #         res = res
33//! #             .status(http::StatusCode::MOVED_PERMANENTLY)
34//! #             .header(http::header::LOCATION, dest);
35//! #     }
36//! #     Ok::<_, std::convert::Infallible>(res.body(Full::<Bytes>::default()).unwrap())
37//! # });
38//! let mut client = ServiceBuilder::new()
39//!     .layer(FollowRedirectLayer::new())
40//!     .service(http_client);
41//!
42//! let request = Request::builder()
43//!     .uri("https://rust-lang.org/")
44//!     .body(Full::<Bytes>::default())
45//!     .unwrap();
46//!
47//! let response = client.ready().await?.call(request).await?;
48//! // Get the final request URI.
49//! assert_eq!(response.extensions().get::<RequestUri>().unwrap().0, "https://www.rust-lang.org/");
50//! # Ok(())
51//! # }
52//! ```
53//!
54//! ## Customizing the `Policy`
55//!
56//! You can use a [`Policy`] value to customize how the middleware handles redirections.
57//!
58//! ```
59//! use http::{Request, Response};
60//! use http_body_util::Full;
61//! use bytes::Bytes;
62//! use tower::{Service, ServiceBuilder, ServiceExt};
63//! use tower_http::follow_redirect::{
64//!     policy::{self, PolicyExt},
65//!     FollowRedirectLayer,
66//! };
67//!
68//! #[derive(Debug)]
69//! enum MyError {
70//!     TooManyRedirects,
71//!     Other(tower::BoxError),
72//! }
73//!
74//! # #[tokio::main]
75//! # async fn main() -> Result<(), MyError> {
76//! # let http_client =
77//! #     tower::service_fn(|_: Request<Full<Bytes>>| async { Ok(Response::new(Full::<Bytes>::default())) });
78//! let policy = policy::Limited::new(10) // Set the maximum number of redirections to 10.
79//!     // Return an error when the limit was reached.
80//!     .or::<_, (), _>(policy::redirect_fn(|_| Err(MyError::TooManyRedirects)))
81//!     // Do not follow cross-origin redirections, and return the redirection responses as-is.
82//!     .and::<_, (), _>(policy::SameOrigin::new());
83//!
84//! let mut client = ServiceBuilder::new()
85//!     .layer(FollowRedirectLayer::with_policy(policy))
86//!     .map_err(MyError::Other)
87//!     .service(http_client);
88//!
89//! // ...
90//! # let _ = client.ready().await?.call(Request::default()).await?;
91//! # Ok(())
92//! # }
93//! ```
94
95pub mod policy;
96
97use self::policy::{Action, Attempt, Policy, Standard};
98use futures_util::future::Either;
99use http::{
100    header::LOCATION, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Uri, Version,
101};
102use http_body::Body;
103use iri_string::types::{UriAbsoluteString, UriReferenceStr};
104use pin_project_lite::pin_project;
105use std::{
106    convert::TryFrom,
107    future::Future,
108    mem,
109    pin::Pin,
110    str,
111    task::{ready, Context, Poll},
112};
113use tower::util::Oneshot;
114use tower_layer::Layer;
115use tower_service::Service;
116
117/// [`Layer`] for retrying requests with a [`Service`] to follow redirection responses.
118///
119/// See the [module docs](self) for more details.
120#[derive(Clone, Copy, Debug, Default)]
121pub struct FollowRedirectLayer<P = Standard> {
122    policy: P,
123}
124
125impl FollowRedirectLayer {
126    /// Create a new [`FollowRedirectLayer`] with a [`Standard`] redirection policy.
127    pub fn new() -> Self {
128        Self::default()
129    }
130}
131
132impl<P> FollowRedirectLayer<P> {
133    /// Create a new [`FollowRedirectLayer`] with the given redirection [`Policy`].
134    pub fn with_policy(policy: P) -> Self {
135        FollowRedirectLayer { policy }
136    }
137}
138
139impl<S, P> Layer<S> for FollowRedirectLayer<P>
140where
141    S: Clone,
142    P: Clone,
143{
144    type Service = FollowRedirect<S, P>;
145
146    fn layer(&self, inner: S) -> Self::Service {
147        FollowRedirect::with_policy(inner, self.policy.clone())
148    }
149}
150
151/// Middleware that retries requests with a [`Service`] to follow redirection responses.
152///
153/// See the [module docs](self) for more details.
154#[derive(Clone, Copy, Debug)]
155pub struct FollowRedirect<S, P = Standard> {
156    inner: S,
157    policy: P,
158}
159
160impl<S> FollowRedirect<S> {
161    /// Create a new [`FollowRedirect`] with a [`Standard`] redirection policy.
162    pub fn new(inner: S) -> Self {
163        Self::with_policy(inner, Standard::default())
164    }
165
166    /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware.
167    ///
168    /// [`Layer`]: tower_layer::Layer
169    pub fn layer() -> FollowRedirectLayer {
170        FollowRedirectLayer::new()
171    }
172}
173
174impl<S, P> FollowRedirect<S, P>
175where
176    P: Clone,
177{
178    /// Create a new [`FollowRedirect`] with the given redirection [`Policy`].
179    pub fn with_policy(inner: S, policy: P) -> Self {
180        FollowRedirect { inner, policy }
181    }
182
183    /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware
184    /// with the given redirection [`Policy`].
185    ///
186    /// [`Layer`]: tower_layer::Layer
187    pub fn layer_with_policy(policy: P) -> FollowRedirectLayer<P> {
188        FollowRedirectLayer::with_policy(policy)
189    }
190
191    define_inner_service_accessors!();
192}
193
194impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirect<S, P>
195where
196    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
197    ReqBody: Body + Default,
198    P: Policy<ReqBody, S::Error> + Clone,
199{
200    type Response = Response<ResBody>;
201    type Error = S::Error;
202    type Future = ResponseFuture<S, ReqBody, P>;
203
204    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
205        self.inner.poll_ready(cx)
206    }
207
208    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
209        let service = self.inner.clone();
210        let mut service = mem::replace(&mut self.inner, service);
211        let mut policy = self.policy.clone();
212        let mut body = BodyRepr::None;
213        body.try_clone_from(req.body(), &policy);
214        policy.on_request(&mut req);
215        ResponseFuture {
216            method: req.method().clone(),
217            uri: req.uri().clone(),
218            version: req.version(),
219            headers: req.headers().clone(),
220            body,
221            future: Either::Left(service.call(req)),
222            service,
223            policy,
224        }
225    }
226}
227
pin_project! {
    /// Response future for [`FollowRedirect`].
    ///
    /// Besides the in-flight request, the future keeps the parts of the request (method, URI,
    /// version, headers and, when available, a body) that are needed to build a follow-up
    /// request if the response turns out to be a redirection.
    #[derive(Debug)]
    pub struct ResponseFuture<S, B, P>
    where
        S: Service<Request<B>>,
    {
        // The request currently in flight: `Left` for the initial call made by
        // `FollowRedirect::call`, `Right` for a redirected request issued from `poll`.
        #[pin]
        future: Either<S::Future, Oneshot<S, Request<B>>>,
        // Service that is cloned to issue redirected requests.
        service: S,
        // Policy consulted for every redirection response.
        policy: P,
        // Method to use for the next redirected request.
        method: Method,
        // URI of the request currently in flight.
        uri: Uri,
        // HTTP version copied onto redirected requests.
        version: Version,
        // Headers copied onto redirected requests.
        headers: HeaderMap<HeaderValue>,
        // Body (if any) that can be sent with the next redirected request.
        body: BodyRepr<B>,
    }
}
246
247impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P>
248where
249    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
250    ReqBody: Body + Default,
251    P: Policy<ReqBody, S::Error>,
252{
253    type Output = Result<Response<ResBody>, S::Error>;
254
255    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
256        let mut this = self.project();
257        let mut res = ready!(this.future.as_mut().poll(cx)?);
258        res.extensions_mut().insert(RequestUri(this.uri.clone()));
259
260        match res.status() {
261            StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
262                // User agents MAY change the request method from POST to GET
263                // (RFC 7231 section 6.4.2. and 6.4.3.).
264                if *this.method == Method::POST {
265                    *this.method = Method::GET;
266                    *this.body = BodyRepr::Empty;
267                }
268            }
269            StatusCode::SEE_OTHER => {
270                // A user agent can perform a GET or HEAD request (RFC 7231 section 6.4.4.).
271                if *this.method != Method::HEAD {
272                    *this.method = Method::GET;
273                }
274                *this.body = BodyRepr::Empty;
275            }
276            StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
277            _ => return Poll::Ready(Ok(res)),
278        };
279
280        let body = if let Some(body) = this.body.take() {
281            body
282        } else {
283            return Poll::Ready(Ok(res));
284        };
285
286        let location = res
287            .headers()
288            .get(&LOCATION)
289            .and_then(|loc| resolve_uri(str::from_utf8(loc.as_bytes()).ok()?, this.uri));
290        let location = if let Some(loc) = location {
291            loc
292        } else {
293            return Poll::Ready(Ok(res));
294        };
295
296        let attempt = Attempt {
297            status: res.status(),
298            location: &location,
299            previous: this.uri,
300        };
301        match this.policy.redirect(&attempt)? {
302            Action::Follow => {
303                *this.uri = location;
304                this.body.try_clone_from(&body, &this.policy);
305
306                let mut req = Request::new(body);
307                *req.uri_mut() = this.uri.clone();
308                *req.method_mut() = this.method.clone();
309                *req.version_mut() = *this.version;
310                *req.headers_mut() = this.headers.clone();
311                this.policy.on_request(&mut req);
312                this.future
313                    .set(Either::Right(Oneshot::new(this.service.clone(), req)));
314
315                cx.waker().wake_by_ref();
316                Poll::Pending
317            }
318            Action::Stop => Poll::Ready(Ok(res)),
319        }
320    }
321}
322
323/// Response [`Extensions`][http::Extensions] value that represents the effective request URI of
324/// a response returned by a [`FollowRedirect`] middleware.
325///
326/// The value differs from the original request's effective URI if the middleware has followed
327/// redirections.
328#[derive(Clone)]
329pub struct RequestUri(pub Uri);
330
/// The request body (if any) that is available for the next redirected request.
#[derive(Debug)]
enum BodyRepr<B> {
    /// A concrete body; `take` hands it out once and leaves `None` behind.
    Some(B),
    /// The next request is sent with an empty body; `take` yields `B::default()` every time
    /// and the state stays `Empty`.
    Empty,
    /// No body is available; `take` yields nothing.
    None,
}
337
338impl<B> BodyRepr<B>
339where
340    B: Body + Default,
341{
342    fn take(&mut self) -> Option<B> {
343        match mem::replace(self, BodyRepr::None) {
344            BodyRepr::Some(body) => Some(body),
345            BodyRepr::Empty => {
346                *self = BodyRepr::Empty;
347                Some(B::default())
348            }
349            BodyRepr::None => None,
350        }
351    }
352
353    fn try_clone_from<P, E>(&mut self, body: &B, policy: &P)
354    where
355        P: Policy<B, E>,
356    {
357        match self {
358            BodyRepr::Some(_) | BodyRepr::Empty => {}
359            BodyRepr::None => {
360                if let Some(body) = clone_body(policy, body) {
361                    *self = BodyRepr::Some(body);
362                }
363            }
364        }
365    }
366}
367
368fn clone_body<P, B, E>(policy: &P, body: &B) -> Option<B>
369where
370    P: Policy<B, E>,
371    B: Body + Default,
372{
373    if body.size_hint().exact() == Some(0) {
374        Some(B::default())
375    } else {
376        policy.clone_body(body)
377    }
378}
379
380/// Try to resolve a URI reference `relative` against a base URI `base`.
381fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
382    let relative = UriReferenceStr::new(relative).ok()?;
383    let base = UriAbsoluteString::try_from(base.to_string()).ok()?;
384    let uri = relative.resolve_against(&base).to_string();
385    Uri::try_from(uri).ok()
386}
387
#[cfg(test)]
mod tests {
    use super::{policy::*, *};
    use crate::test_helpers::Body;
    use http::header::LOCATION;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// The request every test starts from: `http://example.com/42` with an empty body.
    fn initial_request() -> Request<Body> {
        Request::builder()
            .uri("http://example.com/42")
            .body(Body::empty())
            .unwrap()
    }

    /// The effective request URI recorded on `res` by the middleware.
    fn effective_uri<B>(res: &Response<B>) -> &Uri {
        &res.extensions().get::<RequestUri>().unwrap().0
    }

    /// With an always-follow policy the chain is followed all the way down to `/0`.
    #[tokio::test]
    async fn follows() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Follow))
            .buffer(1)
            .service_fn(handle);
        let res = svc.oneshot(initial_request()).await.unwrap();
        assert_eq!(*res.body(), 0);
        assert_eq!(*effective_uri(&res), "http://example.com/0");
    }

    /// With an always-stop policy the first redirection response is returned untouched.
    #[tokio::test]
    async fn stops() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Stop))
            .buffer(1)
            .service_fn(handle);
        let res = svc.oneshot(initial_request()).await.unwrap();
        assert_eq!(*res.body(), 42);
        assert_eq!(*effective_uri(&res), "http://example.com/42");
    }

    /// A limit of ten redirections stops the chain at `/32`.
    #[tokio::test]
    async fn limited() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Limited::new(10)))
            .buffer(1)
            .service_fn(handle);
        let res = svc.oneshot(initial_request()).await.unwrap();
        assert_eq!(*res.body(), 42 - 10);
        assert_eq!(*effective_uri(&res), "http://example.com/32");
    }

    /// A server with an endpoint `GET /{n}` which redirects to `/{n-1}` unless `n` equals zero,
    /// returning `n` as the response body.
    async fn handle<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
        let n: u64 = req.uri().path()[1..].parse().unwrap();
        let builder = if n > 0 {
            Response::builder()
                .status(StatusCode::MOVED_PERMANENTLY)
                .header(LOCATION, format!("/{}", n - 1))
        } else {
            Response::builder()
        };
        Ok::<_, Infallible>(builder.body(n).unwrap())
    }
}