1use crate::header::PROXY_AUTHENTICATE;
6use crate::headers::{HeaderMapExt, ProxyAuthorization, authorization::Credentials};
7use crate::{Request, Response, StatusCode};
8use rama_core::{Context, Layer, Service};
9use rama_net::user::{UserId, auth::Authority};
10use rama_utils::macros::define_inner_service_accessors;
11use std::fmt;
12use std::marker::PhantomData;
13
/// [`Layer`] that wraps an inner service in a [`ProxyAuthService`].
///
/// `A` is the authority used to validate credentials of type `C`; `L` is an
/// extra type parameter forwarded to the [`Authority`] bound (defaults to `()`).
pub struct ProxyAuthLayer<A, C, L = ()> {
    /// Authority consulted to validate extracted credentials.
    proxy_auth: A,
    /// Whether credential-less requests may pass as anonymous.
    allow_anonymous: bool,
    /// Ties `C` and `L` to the type without owning values of them.
    _phantom: PhantomData<fn(C, L)>,
}
22
23impl<A: fmt::Debug, C, L> fmt::Debug for ProxyAuthLayer<A, C, L> {
24 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
25 f.debug_struct("ProxyAuthLayer")
26 .field("proxy_auth", &self.proxy_auth)
27 .field(
28 "_phantom",
29 &format_args!("{}", std::any::type_name::<fn(C, L) -> ()>()),
30 )
31 .finish()
32 }
33}
34
35impl<A: Clone, C, L> Clone for ProxyAuthLayer<A, C, L> {
36 fn clone(&self) -> Self {
37 Self {
38 proxy_auth: self.proxy_auth.clone(),
39 allow_anonymous: self.allow_anonymous,
40 _phantom: PhantomData,
41 }
42 }
43}
44
45impl<A, C> ProxyAuthLayer<A, C, ()> {
46 pub const fn new(proxy_auth: A) -> Self {
48 ProxyAuthLayer {
49 proxy_auth,
50 allow_anonymous: false,
51 _phantom: PhantomData,
52 }
53 }
54
55 pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
57 self.allow_anonymous = allow_anonymous;
58 self
59 }
60
61 pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
63 self.allow_anonymous = allow_anonymous;
64 self
65 }
66}
67
68impl<A, C, L> ProxyAuthLayer<A, C, L> {
69 pub fn with_labels<L2>(self) -> ProxyAuthLayer<A, C, L2> {
79 ProxyAuthLayer {
80 proxy_auth: self.proxy_auth,
81 allow_anonymous: self.allow_anonymous,
82 _phantom: PhantomData,
83 }
84 }
85}
86
87impl<A, C, L, S> Layer<S> for ProxyAuthLayer<A, C, L>
88where
89 A: Authority<C, L> + Clone,
90 C: Credentials + Clone + Send + Sync + 'static,
91{
92 type Service = ProxyAuthService<A, C, S, L>;
93
94 fn layer(&self, inner: S) -> Self::Service {
95 ProxyAuthService::new(self.proxy_auth.clone(), inner)
96 }
97
98 fn into_layer(self, inner: S) -> Self::Service {
99 ProxyAuthService::new(self.proxy_auth, inner)
100 }
101}
102
/// Middleware that validates proxy credentials before calling the inner
/// service `S`.
///
/// Credentials of type `C` are checked by the authority `A`; `L` is an extra
/// type parameter forwarded to the [`Authority`] bound (defaults to `()`).
pub struct ProxyAuthService<A, C, S, L = ()> {
    /// Authority consulted to validate extracted credentials.
    proxy_auth: A,
    /// When `true`, requests without credentials are forwarded as
    /// [`UserId::Anonymous`] instead of being rejected.
    allow_anonymous: bool,
    /// The wrapped service that receives authorized requests.
    inner: S,
    /// Ties `C` and `L` to the type without owning values of them.
    _phantom: PhantomData<fn(C, L)>,
}
116
117impl<A, C, S, L> ProxyAuthService<A, C, S, L> {
118 pub const fn new(proxy_auth: A, inner: S) -> Self {
120 Self {
121 proxy_auth,
122 allow_anonymous: false,
123 inner,
124 _phantom: PhantomData,
125 }
126 }
127
128 pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
130 self.allow_anonymous = allow_anonymous;
131 self
132 }
133
134 pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
136 self.allow_anonymous = allow_anonymous;
137 self
138 }
139
140 define_inner_service_accessors!();
141}
142
143impl<A: fmt::Debug, C, S: fmt::Debug, L> fmt::Debug for ProxyAuthService<A, C, S, L> {
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 f.debug_struct("ProxyAuthService")
146 .field("proxy_auth", &self.proxy_auth)
147 .field("allow_anonymous", &self.allow_anonymous)
148 .field("inner", &self.inner)
149 .field(
150 "_phantom",
151 &format_args!("{}", std::any::type_name::<fn(C, L) -> ()>()),
152 )
153 .finish()
154 }
155}
156
157impl<A: Clone, C, S: Clone, L> Clone for ProxyAuthService<A, C, S, L> {
158 fn clone(&self) -> Self {
159 ProxyAuthService {
160 proxy_auth: self.proxy_auth.clone(),
161 allow_anonymous: self.allow_anonymous,
162 inner: self.inner.clone(),
163 _phantom: PhantomData,
164 }
165 }
166}
167
168impl<A, C, L, S, State, ReqBody, ResBody> Service<State, Request<ReqBody>>
169 for ProxyAuthService<A, C, S, L>
170where
171 A: Authority<C, L>,
172 C: Credentials + Clone + Send + Sync + 'static,
173 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
174 L: 'static,
175 ReqBody: Send + 'static,
176 ResBody: Default + Send + 'static,
177 State: Clone + Send + Sync + 'static,
178{
179 type Response = S::Response;
180 type Error = S::Error;
181
182 async fn serve(
183 &self,
184 mut ctx: Context<State>,
185 req: Request<ReqBody>,
186 ) -> Result<Self::Response, Self::Error> {
187 if let Some(credentials) = req
188 .headers()
189 .typed_get::<ProxyAuthorization<C>>()
190 .map(|h| h.0)
191 .or_else(|| ctx.get::<C>().cloned())
192 {
193 if let Some(ext) = self.proxy_auth.authorized(credentials).await {
194 ctx.extend(ext);
195 self.inner.serve(ctx, req).await
196 } else {
197 Ok(Response::builder()
198 .status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
199 .header(PROXY_AUTHENTICATE, C::SCHEME)
200 .body(Default::default())
201 .unwrap())
202 }
203 } else if self.allow_anonymous {
204 ctx.insert(UserId::Anonymous);
205 self.inner.serve(ctx, req).await
206 } else {
207 Ok(Response::builder()
208 .status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
209 .header(PROXY_AUTHENTICATE, C::SCHEME)
210 .body(Default::default())
211 .unwrap())
212 }
213 }
214}