tower_http/auth/
require_authorization.rs

//! Authorize requests by validating the [`Authorization`] header with [`ValidateRequest`].
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use tower_http::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
9//! use http::{Request, Response, StatusCode, header::AUTHORIZATION};
10//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
11//! use bytes::Bytes;
12//! use http_body_util::Full;
13//!
14//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
15//!     Ok(Response::new(Full::default()))
16//! }
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), BoxError> {
20//! let mut service = ServiceBuilder::new()
21//!     // Require the `Authorization` header to be `Bearer passwordlol`
22//!     .layer(ValidateRequestHeaderLayer::bearer("passwordlol"))
23//!     .service_fn(handle);
24//!
25//! // Requests with the correct token are allowed through
26//! let request = Request::builder()
27//!     .header(AUTHORIZATION, "Bearer passwordlol")
28//!     .body(Full::default())
29//!     .unwrap();
30//!
31//! let response = service
32//!     .ready()
33//!     .await?
34//!     .call(request)
35//!     .await?;
36//!
37//! assert_eq!(StatusCode::OK, response.status());
38//!
//! // Requests without a valid token get a `401 Unauthorized` response
40//! let request = Request::builder()
41//!     .body(Full::default())
42//!     .unwrap();
43//!
44//! let response = service
45//!     .ready()
46//!     .await?
47//!     .call(request)
48//!     .await?;
49//!
50//! assert_eq!(StatusCode::UNAUTHORIZED, response.status());
51//! # Ok(())
52//! # }
53//! ```
54//!
//! Custom validation logic can be provided by implementing [`ValidateRequest`].
56
57use crate::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
58use base64::Engine as _;
59use http::{
60    header::{self, HeaderValue},
61    Request, Response, StatusCode,
62};
63use std::{fmt, marker::PhantomData};
64
65const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
66
67impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
68    /// Authorize requests using a username and password pair.
69    ///
70    /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
71    /// `base64_encode("{username}:{password}")`.
72    ///
73    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
74    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
75    pub fn basic(inner: S, username: &str, value: &str) -> Self
76    where
77        ResBody: Default,
78    {
79        Self::custom(inner, Basic::new(username, value))
80    }
81}
82
83impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> {
84    /// Authorize requests using a username and password pair.
85    ///
86    /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
87    /// `base64_encode("{username}:{password}")`.
88    ///
89    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
90    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
91    pub fn basic(username: &str, password: &str) -> Self
92    where
93        ResBody: Default,
94    {
95        Self::custom(Basic::new(username, password))
96    }
97}
98
99impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> {
100    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
101    ///
102    /// The `Authorization` header is required to be `Bearer {token}`.
103    ///
104    /// # Panics
105    ///
106    /// Panics if the token is not a valid [`HeaderValue`].
107    pub fn bearer(inner: S, token: &str) -> Self
108    where
109        ResBody: Default,
110    {
111        Self::custom(inner, Bearer::new(token))
112    }
113}
114
115impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
116    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
117    ///
118    /// The `Authorization` header is required to be `Bearer {token}`.
119    ///
120    /// # Panics
121    ///
122    /// Panics if the token is not a valid [`HeaderValue`].
123    pub fn bearer(token: &str) -> Self
124    where
125        ResBody: Default,
126    {
127        Self::custom(Bearer::new(token))
128    }
129}
130
131/// Type that performs "bearer token" authorization.
132///
133/// See [`ValidateRequestHeader::bearer`] for more details.
134pub struct Bearer<ResBody> {
135    header_value: HeaderValue,
136    _ty: PhantomData<fn() -> ResBody>,
137}
138
139impl<ResBody> Bearer<ResBody> {
140    fn new(token: &str) -> Self
141    where
142        ResBody: Default,
143    {
144        Self {
145            header_value: format!("Bearer {}", token)
146                .parse()
147                .expect("token is not a valid header value"),
148            _ty: PhantomData,
149        }
150    }
151}
152
153impl<ResBody> Clone for Bearer<ResBody> {
154    fn clone(&self) -> Self {
155        Self {
156            header_value: self.header_value.clone(),
157            _ty: PhantomData,
158        }
159    }
160}
161
162impl<ResBody> fmt::Debug for Bearer<ResBody> {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        f.debug_struct("Bearer")
165            .field("header_value", &self.header_value)
166            .finish()
167    }
168}
169
170impl<B, ResBody> ValidateRequest<B> for Bearer<ResBody>
171where
172    ResBody: Default,
173{
174    type ResponseBody = ResBody;
175
176    fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
177        match request.headers().get(header::AUTHORIZATION) {
178            Some(actual) if actual == self.header_value => Ok(()),
179            _ => {
180                let mut res = Response::new(ResBody::default());
181                *res.status_mut() = StatusCode::UNAUTHORIZED;
182                Err(res)
183            }
184        }
185    }
186}
187
188/// Type that performs basic authorization.
189///
190/// See [`ValidateRequestHeader::basic`] for more details.
191pub struct Basic<ResBody> {
192    header_value: HeaderValue,
193    _ty: PhantomData<fn() -> ResBody>,
194}
195
196impl<ResBody> Basic<ResBody> {
197    fn new(username: &str, password: &str) -> Self
198    where
199        ResBody: Default,
200    {
201        let encoded = BASE64.encode(format!("{}:{}", username, password));
202        let header_value = format!("Basic {}", encoded).parse().unwrap();
203        Self {
204            header_value,
205            _ty: PhantomData,
206        }
207    }
208}
209
210impl<ResBody> Clone for Basic<ResBody> {
211    fn clone(&self) -> Self {
212        Self {
213            header_value: self.header_value.clone(),
214            _ty: PhantomData,
215        }
216    }
217}
218
219impl<ResBody> fmt::Debug for Basic<ResBody> {
220    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221        f.debug_struct("Basic")
222            .field("header_value", &self.header_value)
223            .finish()
224    }
225}
226
227impl<B, ResBody> ValidateRequest<B> for Basic<ResBody>
228where
229    ResBody: Default,
230{
231    type ResponseBody = ResBody;
232
233    fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
234        match request.headers().get(header::AUTHORIZATION) {
235            Some(actual) if actual == self.header_value => Ok(()),
236            _ => {
237                let mut res = Response::new(ResBody::default());
238                *res.status_mut() = StatusCode::UNAUTHORIZED;
239                res.headers_mut()
240                    .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
241                Err(res)
242            }
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    //! Behavioral tests for the `basic` and `bearer` validators: accepted credentials,
    //! rejected credentials, case sensitivity, and requests with no `Authorization` header.

    use crate::validate_request::ValidateRequestHeaderLayer;

    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::header;
    use tower::{BoxError, ServiceBuilder, ServiceExt};
    use tower_service::Service;

    #[tokio::test]
    async fn valid_basic_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_basic_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn missing_basic_credentials() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        // No `Authorization` header at all.
        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn basic_password_may_contain_colon() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "b:ar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("foo:b:ar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_bearer_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_prefix() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_value() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("Foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn invalid_bearer_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wat")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn missing_bearer_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        // No `Authorization` header at all.
        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_prefix() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer Foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Inner service that echoes the request body back as the response body.
    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}