// rama_http/layer/remove_header/request.rs
use crate::{HeaderName, Request, Response};
use rama_core::{Context, Layer, Service};
use rama_utils::macros::define_inner_service_accessors;
use std::{borrow::Cow, fmt, future::Future};
/// Layer that wraps an inner service in a [`RemoveRequestHeader`] middleware.
///
/// The layer only stores the removal mode; the mode is cloned into every
/// service produced by `layer`.
#[derive(Debug, Clone)]
pub struct RemoveRequestHeaderLayer {
    // Which request headers the produced middleware strips.
    mode: RemoveRequestHeaderMode,
}
/// Selects which request headers get removed before the inner service runs.
#[derive(Debug, Clone)]
enum RemoveRequestHeaderMode {
    // Handed to `super::remove_headers_by_prefix`; presumably removes every
    // header whose name starts with this string (matching rules live in that helper).
    Prefix(Cow<'static, str>),
    // Handed to `super::remove_headers_by_exact_name`.
    Exact(HeaderName),
    // Delegates to `super::remove_hop_by_hop_request_headers`; the exact set of
    // headers removed is defined by that helper.
    Hop,
}
impl RemoveRequestHeaderLayer {
    /// Create a layer in prefix mode.
    ///
    /// `prefix` is converted into a `Cow<'static, str>` and stored; it is later
    /// passed to the prefix-based header removal helper for every request.
    pub fn prefix(prefix: impl Into<Cow<'static, str>>) -> Self {
        let mode = RemoveRequestHeaderMode::Prefix(prefix.into());
        Self { mode }
    }

    /// Create a layer in exact-name mode for the given `header`.
    pub fn exact(header: HeaderName) -> Self {
        let mode = RemoveRequestHeaderMode::Exact(header);
        Self { mode }
    }

    /// Create a layer in hop-by-hop mode.
    pub fn hop_by_hop() -> Self {
        let mode = RemoveRequestHeaderMode::Hop;
        Self { mode }
    }
}
impl<S> Layer<S> for RemoveRequestHeaderLayer {
    type Service = RemoveRequestHeader<S>;

    /// Wrap `inner` in a [`RemoveRequestHeader`] that carries a clone of this
    /// layer's removal mode.
    fn layer(&self, inner: S) -> Self::Service {
        let mode = self.mode.clone();
        RemoveRequestHeader { inner, mode }
    }
}
/// Middleware that removes headers from a request, according to its mode,
/// before passing the request on to the wrapped service.
pub struct RemoveRequestHeader<S> {
    // The wrapped service that receives the request after header removal.
    inner: S,
    // Which request headers are stripped.
    mode: RemoveRequestHeaderMode,
}
impl<S> RemoveRequestHeader<S> {
    /// Wrap `inner` in a middleware configured in prefix mode.
    ///
    /// Equivalent to building a [`RemoveRequestHeaderLayer::prefix`] layer and
    /// applying it to `inner`.
    pub fn prefix(prefix: impl Into<Cow<'static, str>>, inner: S) -> Self {
        let layer = RemoveRequestHeaderLayer::prefix(prefix);
        layer.layer(inner)
    }

    /// Wrap `inner` in a middleware configured in exact-name mode for `header`.
    pub fn exact(header: HeaderName, inner: S) -> Self {
        let layer = RemoveRequestHeaderLayer::exact(header);
        layer.layer(inner)
    }

    /// Wrap `inner` in a middleware configured in hop-by-hop mode.
    pub fn hop_by_hop(inner: S) -> Self {
        let layer = RemoveRequestHeaderLayer::hop_by_hop();
        layer.layer(inner)
    }

    // NOTE(review): presumably expands to accessor methods for the `inner`
    // field; confirm against the macro definition in `rama_utils`.
    define_inner_service_accessors!();
}
impl<S: fmt::Debug> fmt::Debug for RemoveRequestHeader<S> {
    /// Format as a `RemoveRequestHeader` struct showing `inner` then `mode`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { inner, mode } = self;
        let mut out = f.debug_struct("RemoveRequestHeader");
        out.field("inner", inner);
        out.field("mode", mode);
        out.finish()
    }
}
impl<S: Clone> Clone for RemoveRequestHeader<S> {
fn clone(&self) -> Self {
RemoveRequestHeader {
inner: self.inner.clone(),
mode: self.mode.clone(),
}
}
}
impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for RemoveRequestHeader<S>
where
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
    State: Clone + Send + Sync + 'static,
    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Strip headers from `req` according to the configured mode, then hand the
    /// request to the inner service.
    ///
    /// The header removal happens synchronously, before the inner service's
    /// future is created; the returned future is the inner service's own.
    fn serve(
        &self,
        ctx: Context<State>,
        mut req: Request<ReqBody>,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
        let headers = req.headers_mut();
        match &self.mode {
            RemoveRequestHeaderMode::Exact(header) => {
                super::remove_headers_by_exact_name(headers, header);
            }
            RemoveRequestHeaderMode::Prefix(prefix) => {
                super::remove_headers_by_prefix(headers, prefix);
            }
            RemoveRequestHeaderMode::Hop => {
                super::remove_hop_by_hop_request_headers(headers);
            }
        }

        self.inner.serve(ctx, req)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Body, Response};
    use rama_core::{service::service_fn, Layer, Service};
    use std::convert::Infallible;

    /// Prefix mode: `x-foo-bar` is removed for prefix `x-foo`, while `foo` is kept.
    #[tokio::test]
    async fn remove_request_header_prefix() {
        let inner = service_fn(|_ctx: Context<()>, req: Request| async move {
            let headers = req.headers();
            assert!(headers.get("x-foo-bar").is_none());
            assert_eq!(
                headers.get("foo").map(|v| v.to_str().unwrap()),
                Some("bar")
            );
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = RemoveRequestHeaderLayer::prefix("x-foo").layer(inner);

        let req = Request::builder()
            .header("x-foo-bar", "baz")
            .header("foo", "bar")
            .body(Body::empty())
            .unwrap();

        let _ = svc.serve(Context::default(), req).await.unwrap();
    }

    /// Exact mode: `x-foo` is removed, while `x-foo-bar` is kept.
    #[tokio::test]
    async fn remove_request_header_exact() {
        let inner = service_fn(|_ctx: Context<()>, req: Request| async move {
            let headers = req.headers();
            assert!(headers.get("x-foo").is_none());
            assert_eq!(
                headers.get("x-foo-bar").map(|v| v.to_str().unwrap()),
                Some("baz")
            );
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let target = HeaderName::from_static("x-foo");
        let svc = RemoveRequestHeaderLayer::exact(target).layer(inner);

        let req = Request::builder()
            .header("x-foo", "baz")
            .header("x-foo-bar", "baz")
            .body(Body::empty())
            .unwrap();

        let _ = svc.serve(Context::default(), req).await.unwrap();
    }

    /// Hop-by-hop mode: `connection` is removed, while `foo` is kept.
    #[tokio::test]
    async fn remove_request_header_hop_by_hop() {
        let inner = service_fn(|_ctx: Context<()>, req: Request| async move {
            let headers = req.headers();
            assert!(headers.get("connection").is_none());
            assert_eq!(
                headers.get("foo").map(|v| v.to_str().unwrap()),
                Some("bar")
            );
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = RemoveRequestHeaderLayer::hop_by_hop().layer(inner);

        let req = Request::builder()
            .header("connection", "close")
            .header("foo", "bar")
            .body(Body::empty())
            .unwrap();

        let _ = svc.serve(Context::default(), req).await.unwrap();
    }
}