1use crate::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
58use base64::Engine as _;
59use http::{
60 header::{self, HeaderValue},
61 Request, Response, StatusCode,
62};
63use std::{fmt, marker::PhantomData};
64
/// The `base64` crate's `general_purpose::STANDARD` engine, used to encode the
/// `username:password` pair of Basic credentials.
const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
66
67impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
68 pub fn basic(inner: S, username: &str, value: &str) -> Self
76 where
77 ResBody: Default,
78 {
79 Self::custom(inner, Basic::new(username, value))
80 }
81}
82
83impl<ResBody> ValidateRequestHeaderLayer<Basic<ResBody>> {
84 pub fn basic(username: &str, password: &str) -> Self
92 where
93 ResBody: Default,
94 {
95 Self::custom(Basic::new(username, password))
96 }
97}
98
99impl<S, ResBody> ValidateRequestHeader<S, Bearer<ResBody>> {
100 pub fn bearer(inner: S, token: &str) -> Self
108 where
109 ResBody: Default,
110 {
111 Self::custom(inner, Bearer::new(token))
112 }
113}
114
115impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
116 pub fn bearer(token: &str) -> Self
124 where
125 ResBody: Default,
126 {
127 Self::custom(Bearer::new(token))
128 }
129}
130
/// Validator that accepts requests whose `Authorization` header equals a fixed
/// `Bearer {token}` value.
pub struct Bearer<ResBody> {
    /// The complete expected header value, `Bearer ` prefix included.
    header_value: HeaderValue,
    // Marker only: no `ResBody` value is stored. The `fn() -> ResBody` form is
    // presumably chosen so auto traits (Send/Sync) don't depend on `ResBody` — confirm.
    _ty: PhantomData<fn() -> ResBody>,
}
138
139impl<ResBody> Bearer<ResBody> {
140 fn new(token: &str) -> Self
141 where
142 ResBody: Default,
143 {
144 Self {
145 header_value: format!("Bearer {}", token)
146 .parse()
147 .expect("token is not a valid header value"),
148 _ty: PhantomData,
149 }
150 }
151}
152
153impl<ResBody> Clone for Bearer<ResBody> {
154 fn clone(&self) -> Self {
155 Self {
156 header_value: self.header_value.clone(),
157 _ty: PhantomData,
158 }
159 }
160}
161
162impl<ResBody> fmt::Debug for Bearer<ResBody> {
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 f.debug_struct("Bearer")
165 .field("header_value", &self.header_value)
166 .finish()
167 }
168}
169
170impl<B, ResBody> ValidateRequest<B> for Bearer<ResBody>
171where
172 ResBody: Default,
173{
174 type ResponseBody = ResBody;
175
176 fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
177 match request.headers().get(header::AUTHORIZATION) {
178 Some(actual) if actual == self.header_value => Ok(()),
179 _ => {
180 let mut res = Response::new(ResBody::default());
181 *res.status_mut() = StatusCode::UNAUTHORIZED;
182 Err(res)
183 }
184 }
185 }
186}
187
/// Validator that accepts requests whose `Authorization` header equals a fixed
/// `Basic {base64("{username}:{password}")}` value.
pub struct Basic<ResBody> {
    /// The complete expected header value, `Basic ` prefix included.
    header_value: HeaderValue,
    // Marker only: no `ResBody` value is stored. The `fn() -> ResBody` form is
    // presumably chosen so auto traits (Send/Sync) don't depend on `ResBody` — confirm.
    _ty: PhantomData<fn() -> ResBody>,
}
195
196impl<ResBody> Basic<ResBody> {
197 fn new(username: &str, password: &str) -> Self
198 where
199 ResBody: Default,
200 {
201 let encoded = BASE64.encode(format!("{}:{}", username, password));
202 let header_value = format!("Basic {}", encoded).parse().unwrap();
203 Self {
204 header_value,
205 _ty: PhantomData,
206 }
207 }
208}
209
210impl<ResBody> Clone for Basic<ResBody> {
211 fn clone(&self) -> Self {
212 Self {
213 header_value: self.header_value.clone(),
214 _ty: PhantomData,
215 }
216 }
217}
218
219impl<ResBody> fmt::Debug for Basic<ResBody> {
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 f.debug_struct("Basic")
222 .field("header_value", &self.header_value)
223 .finish()
224 }
225}
226
227impl<B, ResBody> ValidateRequest<B> for Basic<ResBody>
228where
229 ResBody: Default,
230{
231 type ResponseBody = ResBody;
232
233 fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
234 match request.headers().get(header::AUTHORIZATION) {
235 Some(actual) if actual == self.header_value => Ok(()),
236 _ => {
237 let mut res = Response::new(ResBody::default());
238 *res.status_mut() = StatusCode::UNAUTHORIZED;
239 res.headers_mut()
240 .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
241 Err(res)
242 }
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use crate::validate_request::ValidateRequestHeaderLayer;

    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::header;
    use tower::{BoxError, ServiceBuilder, ServiceExt};
    use tower_service::Service;

    #[tokio::test]
    async fn valid_basic_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_basic_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    // A request with no `Authorization` header at all must be rejected and still
    // receive the Basic challenge.
    #[tokio::test]
    async fn missing_basic_credentials() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn valid_bearer_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_prefix() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_value() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::basic("foo", "bar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("Foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn invalid_bearer_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wat")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    // A request with no `Authorization` header at all must be rejected.
    #[tokio::test]
    async fn missing_bearer_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_prefix() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_token() {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::bearer("foobar"))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer Foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}