use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
use http_body_util::BodyExt;
use pin_project::pin_project;
use tonic::metadata::GRPC_CONTENT_TYPE;
use tonic::{
body::{empty_body, BoxBody},
server::NamedService,
};
use tower_service::Service;
use tracing::{debug, trace};
use crate::call::content_types::is_grpc_web;
use crate::call::{Encoding, GrpcWebCall};
use crate::BoxError;
/// Service that serves grpc-web requests by translating them into gRPC
/// requests for a wrapped service and translating the responses back.
///
/// Construct it through `GrpcWebService::new` (crate-internal).
#[derive(Debug, Clone)]
pub struct GrpcWebService<S> {
    /// The wrapped service. It receives coerced grpc-web requests and
    /// pass-through HTTP/2 requests.
    inner: S,
}
/// Classification of an incoming request, computed by `RequestKind::new`
/// from the request's headers, method and HTTP version.
#[derive(Debug, PartialEq)]
enum RequestKind<'a> {
    /// The headers satisfy `is_grpc_web`.
    GrpcWeb {
        /// The request's HTTP method. The service only forwards `POST`.
        method: &'a Method,
        /// Request body encoding, as parsed by `Encoding::from_content_type`.
        encoding: Encoding,
        /// Desired response encoding, as parsed by `Encoding::from_accept`.
        accept: Encoding,
    },
    /// Any non-grpc-web request, tagged with its HTTP version so the caller
    /// can decide whether to pass it through.
    Other(http::Version),
}
impl<S> GrpcWebService<S> {
pub(crate) fn new(inner: S) -> Self {
GrpcWebService { inner }
}
}
impl<S> GrpcWebService<S>
where
    S: Service<Request<BoxBody>, Response = Response<BoxBody>> + Send + 'static,
{
    /// Returns a future that resolves at once to an empty-bodied response
    /// carrying `status`. The inner service is not invoked.
    fn response(&self, status: StatusCode) -> ResponseFuture<S::Future> {
        let mut immediate = Response::new(empty_body());
        *immediate.status_mut() = status;

        ResponseFuture {
            case: Case::ImmediateResponse {
                res: Some(immediate),
            },
        }
    }
}
/// Serves grpc-web `POST` requests through the inner gRPC service, passes
/// other HTTP/2 requests through untouched, and answers everything else
/// with an immediate error status.
impl<S> Service<Request<BoxBody>> for GrpcWebService<S>
where
    S: Service<Request<BoxBody>, Response = Response<BoxBody>> + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<BoxError> + Send,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Dispatches on the `RequestKind` of `req`:
    ///
    /// - grpc-web + `POST`: coerced to gRPC, sent to the inner service, and
    ///   the response is coerced back to grpc-web;
    /// - grpc-web + any other method: immediate `405 Method Not Allowed`;
    /// - non-grpc-web over HTTP/2: forwarded to the inner service unchanged;
    /// - anything else: immediate `400 Bad Request`.
    ///
    /// NOTE(review): the two immediate-response arms never call `self.inner`
    /// even though `poll_ready` succeeded. For inner services that reserve
    /// capacity in `poll_ready`, confirm this is acceptable.
    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        match RequestKind::new(req.headers(), req.method(), req.version()) {
            // A grpc-web POST, regardless of HTTP version: translate the
            // request to gRPC, call the inner service, and remember `accept`
            // so the response can be translated back.
            RequestKind::GrpcWeb {
                method: &Method::POST,
                encoding,
                accept,
            } => {
                trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept);
                ResponseFuture {
                    case: Case::GrpcWeb {
                        future: self.inner.call(coerce_request(req, encoding)),
                        accept,
                    },
                }
            }
            // grpc-web content type but not POST. This arm must stay after
            // the POST arm above, which it would otherwise shadow.
            RequestKind::GrpcWeb { .. } => {
                debug!(kind = "simple", error="method not allowed", method = ?req.method());
                self.response(StatusCode::METHOD_NOT_ALLOWED)
            }
            // Any non-grpc-web HTTP/2 request is handed to the inner service as-is.
            RequestKind::Other(Version::HTTP_2) => {
                debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
                ResponseFuture {
                    case: Case::Other {
                        future: self.inner.call(req),
                    },
                }
            }
            // Non-grpc-web traffic over any other HTTP version is rejected.
            RequestKind::Other(_) => {
                debug!(kind = "other h1", content_type = ?req.headers().get(header::CONTENT_TYPE));
                self.response(StatusCode::BAD_REQUEST)
            }
        }
    }
}
/// Future returned by `GrpcWebService::call`.
///
/// It resolves to one of:
/// - the inner service's response, coerced to grpc-web when the request was grpc-web;
/// - the inner service's response, unchanged;
/// - an immediately built error response.
// NOTE(review): `Debug` is presumably omitted because `F` is not required
// to implement it.
#[allow(missing_debug_implementations)]
#[pin_project]
#[must_use = "futures do nothing unless polled"]
pub struct ResponseFuture<F> {
    /// The response path this future is driving. It is pinned because it may
    /// hold the inner service's future.
    #[pin]
    case: Case<F>,
}
/// The response paths a `ResponseFuture` can follow. `CaseProj` is the
/// pin-projected view used while polling.
#[pin_project(project = CaseProj)]
enum Case<F> {
    /// Inner call made on behalf of a grpc-web request. Once `future`
    /// resolves, the response is coerced back to grpc-web using `accept`.
    GrpcWeb {
        #[pin]
        future: F,
        accept: Encoding,
    },
    /// Inner call whose response is returned unchanged.
    Other {
        #[pin]
        future: F,
    },
    /// A response produced without calling the inner service. It is wrapped
    /// in `Option` so it can be moved out with `take()` on the first poll.
    ImmediateResponse {
        res: Option<Response<BoxBody>>,
    },
}
impl<F, E> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response<BoxBody>, E>> + Send + 'static,
    E: Into<BoxError> + Send,
{
    type Output = Result<Response<BoxBody>, E>;

    /// Drives the selected `Case` to completion.
    ///
    /// - `GrpcWeb`: waits for the inner future, propagates its error, and
    ///   coerces a successful response back to grpc-web.
    /// - `Other`: yields the inner future's output as-is.
    /// - `ImmediateResponse`: yields the stored response right away.
    ///
    /// # Panics
    ///
    /// Panics if an `ImmediateResponse` future is polled again after it has
    /// already returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Project straight through the struct into the enum. No intermediate
        // binding is needed because `self` is not used afterwards.
        match self.project().case.project() {
            CaseProj::GrpcWeb { future, accept } => {
                let res = ready!(future.poll(cx))?;
                Poll::Ready(Ok(coerce_response(res, *accept)))
            }
            CaseProj::Other { future } => future.poll(cx),
            CaseProj::ImmediateResponse { res } => {
                // The response is moved out on the first (and only valid) poll.
                let res = res
                    .take()
                    .expect("ImmediateResponse future polled after completion");
                Poll::Ready(Ok(res))
            }
        }
    }
}
/// The wrapper reports the same service name as the service it wraps.
impl<S> NamedService for GrpcWebService<S>
where
    S: NamedService,
{
    const NAME: &'static str = <S as NamedService>::NAME;
}
impl<'a> RequestKind<'a> {
    /// Classifies a request from its `headers`, `method` and `version`.
    ///
    /// When the headers satisfy `is_grpc_web`, the result is
    /// `RequestKind::GrpcWeb`, with request and response encodings parsed
    /// from the headers. Otherwise it is `RequestKind::Other(version)`.
    fn new(headers: &'a HeaderMap, method: &'a Method, version: Version) -> Self {
        if !is_grpc_web(headers) {
            return RequestKind::Other(version);
        }

        let encoding = Encoding::from_content_type(headers);
        let accept = Encoding::from_accept(headers);
        RequestKind::GrpcWeb {
            method,
            encoding,
            accept,
        }
    }
}
/// Rewrites a grpc-web request into a plain gRPC request for the inner service.
///
/// Header changes:
/// - `Content-Length` is removed, since the body is re-wrapped by
///   `GrpcWebCall::request` and its size may change (presumably for text
///   encodings; confirm);
/// - `Content-Type` is overwritten with the gRPC content type;
/// - `TE` is overwritten with `trailers`;
/// - `Accept-Encoding` is overwritten with `identity,deflate,gzip`.
fn coerce_request(mut req: Request<BoxBody>, encoding: Encoding) -> Request<BoxBody> {
    let headers = req.headers_mut();
    headers.remove(header::CONTENT_LENGTH);
    headers.insert(header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
    headers.insert(header::TE, HeaderValue::from_static("trailers"));
    headers.insert(
        header::ACCEPT_ENCODING,
        HeaderValue::from_static("identity,deflate,gzip"),
    );

    req.map(|body| GrpcWebCall::request(body, encoding).boxed_unsync())
}
/// Rewrites a gRPC response from the inner service into a grpc-web response
/// encoded with `encoding`.
///
/// The body is wrapped by `GrpcWebCall::response`. Presumably this re-frames
/// it: trailers end up in the body, and text encodings change its size. Any
/// `Content-Length` reported by the inner service therefore cannot be trusted
/// for the outgoing body and is removed, mirroring what `coerce_request` does
/// on the way in. `Content-Type` is replaced with the grpc-web content type
/// for `encoding`.
fn coerce_response(res: Response<BoxBody>, encoding: Encoding) -> Response<BoxBody> {
    let mut res = res
        .map(|b| GrpcWebCall::response(b, encoding))
        .map(BoxBody::new);

    // A stale length may cause the HTTP layer to reject or truncate the
    // re-framed body; let it derive the length from the new body instead.
    res.headers_mut().remove(header::CONTENT_LENGTH);
    res.headers_mut().insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static(encoding.to_content_type()),
    );

    res
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::call::content_types::*;
    use http::header::{
        ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, CONTENT_TYPE, ORIGIN,
    };

    /// Boxed, `Send` future used as the stub service's `Future` type.
    type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

    /// Stub inner service. It is always ready and answers every request with
    /// an empty-bodied `Response::new` (status 200).
    #[derive(Debug, Clone)]
    struct Svc;

    impl tower_service::Service<Request<BoxBody>> for Svc {
        type Response = Response<BoxBody>;
        type Error = String;
        type Future = BoxFuture<Self::Response, Self::Error>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request<BoxBody>) -> Self::Future {
            Box::pin(async { Ok(Response::new(empty_body())) })
        }
    }

    impl NamedService for Svc {
        const NAME: &'static str = "test";
    }

    /// Requests carrying a grpc-web content type.
    mod grpc_web {
        use super::*;
        use tower_layer::Layer;

        /// A grpc-web `POST` request with an `Origin` header and an empty body.
        fn request() -> Request<BoxBody> {
            Request::builder()
                .method(Method::POST)
                .header(CONTENT_TYPE, GRPC_WEB)
                .header(ORIGIN, "http://example.com")
                .body(empty_body())
                .unwrap()
        }

        /// The service produced by `crate::enable` accepts a plain grpc-web POST.
        #[tokio::test]
        async fn default_cors_config() {
            let mut svc = crate::enable(Svc);
            let res = svc.call(request()).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        /// Wrapping via `crate::GrpcWebLayer` behaves the same for a grpc-web POST.
        #[tokio::test]
        async fn web_layer() {
            let mut svc = crate::GrpcWebLayer::new().layer(Svc);
            let res = svc.call(request()).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        /// The `Origin` header is not required.
        #[tokio::test]
        async fn without_origin() {
            let mut svc = crate::enable(Svc);
            let mut req = request();
            req.headers_mut().remove(ORIGIN);
            let res = svc.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }

        /// Non-POST grpc-web requests get `405 Method Not Allowed`. OPTIONS
        /// is exercised separately in the `options` module.
        #[tokio::test]
        async fn only_post_and_options_allowed() {
            let mut svc = crate::enable(Svc);
            for method in &[
                Method::GET,
                Method::PUT,
                Method::DELETE,
                Method::HEAD,
                Method::PATCH,
            ] {
                let mut req = request();
                *req.method_mut() = method.clone();
                let res = svc.call(req).await.unwrap();
                assert_eq!(
                    res.status(),
                    StatusCode::METHOD_NOT_ALLOWED,
                    "{} should not be allowed",
                    method
                );
            }
        }

        /// Every grpc-web content-type constant is accepted.
        #[tokio::test]
        async fn grpc_web_content_types() {
            let mut svc = crate::enable(Svc);
            for ct in &[GRPC_WEB_TEXT, GRPC_WEB_PROTO, GRPC_WEB_TEXT_PROTO, GRPC_WEB] {
                let mut req = request();
                req.headers_mut()
                    .insert(CONTENT_TYPE, HeaderValue::from_static(ct));
                let res = svc.call(req).await.unwrap();
                assert_eq!(res.status(), StatusCode::OK);
            }
        }
    }

    /// CORS preflight (`OPTIONS`) requests.
    mod options {
        use super::*;

        /// A preflight request asking to POST with an `x-grpc-web` header.
        fn request() -> Request<BoxBody> {
            Request::builder()
                .method(Method::OPTIONS)
                .header(ORIGIN, "http://example.com")
                .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web")
                .header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
                .body(empty_body())
                .unwrap()
        }

        /// The preflight succeeds. It has no grpc-web content type, so it is
        /// presumably answered by the CORS handling that `crate::enable`
        /// adds. Confirm this.
        #[tokio::test]
        async fn valid_grpc_web_preflight() {
            let mut svc = crate::enable(Svc);
            let res = svc.call(request()).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK);
        }
    }

    /// Native gRPC (non-web) requests.
    mod grpc {
        use super::*;

        /// An HTTP/2 request with the gRPC content type.
        fn request() -> Request<BoxBody> {
            Request::builder()
                .version(Version::HTTP_2)
                .header(CONTENT_TYPE, GRPC_CONTENT_TYPE)
                .body(empty_body())
                .unwrap()
        }

        /// Native gRPC over HTTP/2 is passed through to the inner service.
        #[tokio::test]
        async fn h2_is_ok() {
            let mut svc = crate::enable(Svc);
            let req = request();
            let res = svc.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK)
        }

        /// Native gRPC on the builder's default HTTP version (not HTTP/2)
        /// is rejected with `400`.
        #[tokio::test]
        async fn h1_is_err() {
            let mut svc = crate::enable(Svc);
            let req = Request::builder()
                .header(CONTENT_TYPE, GRPC_CONTENT_TYPE)
                .body(empty_body())
                .unwrap();
            let res = svc.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::BAD_REQUEST)
        }

        /// `application/grpc+<subtype>` content types are passed through over HTTP/2.
        #[tokio::test]
        async fn content_type_variants() {
            let mut svc = crate::enable(Svc);
            for variant in &["grpc", "grpc+proto", "grpc+thrift", "grpc+foo"] {
                let mut req = request();
                req.headers_mut().insert(
                    CONTENT_TYPE,
                    HeaderValue::from_maybe_shared(format!("application/{}", variant)).unwrap(),
                );
                let res = svc.call(req).await.unwrap();
                assert_eq!(res.status(), StatusCode::OK)
            }
        }
    }

    /// Requests that are neither grpc-web nor gRPC.
    mod other {
        use super::*;

        /// A request with an unrelated content type on the default HTTP version.
        fn request() -> Request<BoxBody> {
            Request::builder()
                .header(CONTENT_TYPE, "application/text")
                .body(empty_body())
                .unwrap()
        }

        /// Arbitrary non-HTTP/2 traffic is rejected with `400`.
        #[tokio::test]
        async fn h1_is_err() {
            let mut svc = crate::enable(Svc);
            let res = svc.call(request()).await.unwrap();
            assert_eq!(res.status(), StatusCode::BAD_REQUEST)
        }

        /// Arbitrary HTTP/2 traffic is passed through to the inner service.
        #[tokio::test]
        async fn h2_is_ok() {
            let mut svc = crate::enable(Svc);
            let mut req = request();
            *req.version_mut() = Version::HTTP_2;
            let res = svc.call(req).await.unwrap();
            assert_eq!(res.status(), StatusCode::OK)
        }
    }
}