// product_os_router/default_headers.rs
use std::prelude::v1::*;
use std::{
future::Future,
pin::Pin,
task::{ Context, Poll },
};
use futures_util::ready;
use axum::http::{header::HeaderMap, Request, Response};
use tower_layer::Layer;
use tower_service::Service;
use pin_project::pin_project;
use crate::BoxError;
/// [`Layer`] that wraps a service in [`DefaultHeaders`] middleware.
///
/// A header from `default_headers` is added to a response only when the
/// response does not already contain a header with that name.
#[derive(Clone, Debug)]
pub struct DefaultHeadersLayer {
    /// Fallback headers copied into each wrapped service.
    default_headers: HeaderMap,
}
impl DefaultHeadersLayer {
pub fn new(default_headers: HeaderMap) -> Self {
Self { default_headers }
}
}
impl<S> Layer<S> for DefaultHeadersLayer {
type Service = DefaultHeaders<S>;
fn layer(&self, inner: S) -> Self::Service {
Self::Service {
default_headers: self.default_headers.clone(),
inner,
}
}
}
/// Middleware that adds fallback headers to the wrapped service's responses.
///
/// Headers the inner service already set are never overwritten.
#[derive(Clone, Debug)]
pub struct DefaultHeaders<S> {
    /// Fallback headers; cloned into every [`ResponseFuture`].
    default_headers: HeaderMap,
    /// The wrapped service.
    inner: S,
}
impl<S> DefaultHeaders<S> {}
impl<S, Request, ResBody> Service<Request> for DefaultHeaders<S>
where
S: Service<Request, Response = Response<ResBody>>,
S::Error: Into<BoxError>,
{
type Response = S::Response;
type Error = BoxError;
type Future = ResponseFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, req: Request) -> Self::Future {
let default_headers = self.default_headers.clone();
let response_future = self.inner.call(req);
ResponseFuture {
default_headers,
response_future,
}
}
}
#[pin_project]
pub struct ResponseFuture<F> {
#[pin]
default_headers: HeaderMap,
#[pin]
response_future: F,
}
impl<F, ResBody, Error> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<ResBody>, Error>>,
Error: Into<BoxError>
{
type Output = Result<Response<ResBody>, BoxError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let response_future: Pin<&mut F> = this.response_future;
let default_headers: Pin<&mut HeaderMap> = this.default_headers;
match response_future.poll(cx) {
Poll::Ready(result) => {
match result {
Ok(mut response) => {
let headers = response.headers_mut();
for (name, value) in default_headers.iter() {
if !headers.contains_key(name) {
headers.insert(name, value.clone());
}
}
Poll::Ready(Ok(response))
}
Err(e) => {
Poll::Ready(Err(e.into()))
}
}
},
Poll::Pending => Poll::Pending
}
}
}