use http::{Request, Response, Uri};
use std::{
borrow::Cow,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;
/// Layer that wraps services in [`NormalizePath`], normalizing request paths
/// before they reach the inner service.
///
/// Construct it with [`NormalizePathLayer::trim_trailing_slash`].
#[derive(Debug, Copy, Clone)]
pub struct NormalizePathLayer {}

impl NormalizePathLayer {
    /// Creates a layer whose services trim trailing slashes from request paths.
    ///
    /// For example `/foo/` becomes `/foo`, and `///foo//?a=a` becomes
    /// `/foo?a=a`. Query strings are kept.
    pub fn trim_trailing_slash() -> Self {
        Self {}
    }
}
impl<S> Layer<S> for NormalizePathLayer {
    type Service = NormalizePath<S>;

    /// Wraps `inner` in a [`NormalizePath`] that trims trailing slashes.
    fn layer(&self, inner: S) -> NormalizePath<S> {
        NormalizePath::<S>::trim_trailing_slash(inner)
    }
}
/// Middleware that normalizes the path of each request's URI and then
/// forwards the request to the wrapped service.
///
/// Build it directly with [`NormalizePath::trim_trailing_slash`] or through
/// [`NormalizePathLayer`].
#[derive(Debug, Copy, Clone)]
pub struct NormalizePath<S> {
    inner: S,
}

impl<S> NormalizePath<S> {
    /// Wraps `inner` so that request paths have trailing slashes removed
    /// (e.g. `/foo/` becomes `/foo`) before `inner` receives the request.
    pub fn trim_trailing_slash(inner: S) -> Self {
        NormalizePath { inner }
    }

    // Macro defined elsewhere in the crate; presumably generates accessors
    // for the wrapped `inner` service.
    define_inner_service_accessors!();
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Readiness is delegated unchanged to the wrapped service.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Rewrites the request URI's path in place, then hands the request to
    /// the wrapped service and returns its future as-is.
    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
        let uri = request.uri_mut();
        normalize_trailing_slash(uri);
        self.inner.call(request)
    }
}
/// Rewrites `uri` in place so its path has a single leading slash and no
/// trailing slashes.
///
/// Examples: `/foo/` becomes `/foo`, `///foo//` becomes `/foo`, and `////`
/// becomes `/`. The query string, if present, is carried over unchanged
/// (`/foo/?a=a` becomes `/foo?a=a`). Slashes in the middle of the path are
/// not touched.
///
/// Paths that are already normalized (including the index path `/`) are
/// left as-is without cloning or re-parsing the URI.
fn normalize_trailing_slash(uri: &mut Uri) {
    let path = uri.path();

    // Nothing to do when there is no trailing slash and no repeated leading
    // slash. `/` is its own normal form, so it is skipped here too; that
    // avoids a full clone/format/parse/rebuild for the most common path.
    if path == "/" || (!path.ends_with('/') && !path.starts_with("//")) {
        return;
    }

    let new_path = format!("/{}", path.trim_matches('/'));

    let mut parts = uri.clone().into_parts();
    if let Some(path_and_query) = &parts.path_and_query {
        let new_path_and_query: Cow<'_, str> = match path_and_query.query() {
            Some(query) => Cow::Owned(format!("{}?{}", new_path, query)),
            None => Cow::Borrowed(new_path.as_str()),
        };
        // The new path only drops `/` characters from a path that already
        // parsed, and the query is reused verbatim, so this cannot fail.
        parts.path_and_query = Some(
            new_path_and_query
                .parse()
                .expect("trimmed path plus original query must be a valid path-and-query"),
        );
    }

    // Best effort: only the path changed. If rebuilding the URI ever fails,
    // keep the original URI rather than rejecting the request.
    if let Ok(new_uri) = Uri::from_parts(parts) {
        *uri = new_uri;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// End-to-end: the layer rewrites the URI seen by the inner service.
    #[tokio::test]
    async fn works() {
        async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
            Ok(Response::new(request.uri().to_string()))
        }

        let mut svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash())
            .service_fn(handle);

        let body = svc
            .ready()
            .await
            .unwrap()
            .call(Request::builder().uri("/foo/").body(()).unwrap())
            .await
            .unwrap()
            .into_body();

        assert_eq!(body, "/foo");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        let mut uri = "/foo".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn maintains_query() {
        let mut uri = "/foo/?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        let mut uri = "/foo////".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        let mut uri = "/".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        let mut uri = "////".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        let mut uri = "////?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        let mut uri = "///foo".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    /// The index path keeps its query string untouched.
    #[test]
    fn maintains_query_on_index() {
        let mut uri = "/?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/?a=a");
    }

    /// Absolute-form URIs keep their scheme and authority; only the path is
    /// normalized.
    #[test]
    fn preserves_scheme_and_authority_on_absolute_uri() {
        let mut uri = "http://example.com/foo//?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "http://example.com/foo?a=a");
    }
}