1use crate::header::PROXY_AUTHENTICATE;
6use crate::headers::{authorization::Credentials, HeaderMapExt, ProxyAuthorization};
7use crate::{Request, Response, StatusCode};
8use rama_core::{Context, Layer, Service};
9use rama_net::user::{auth::Authority, UserId};
10use rama_utils::macros::define_inner_service_accessors;
11use std::fmt;
12use std::marker::PhantomData;
13
/// Layer that produces [`ProxyAuthService`] instances around an inner service.
///
/// Type parameters:
/// - `A`: the value stored in `proxy_auth`; the [`Layer`] impl requires it to be
///   an `Authority<C, L>`.
/// - `C`: the credentials type (marker only, no value is stored).
/// - `L`: the labels type (marker only, defaults to `()`).
pub struct ProxyAuthLayer<A, C, L = ()> {
    /// Authority value that is cloned into every service created by this layer.
    proxy_auth: A,
    /// Anonymous-access flag configured on this layer (`false` after `new`).
    allow_anonymous: bool,
    /// Ties the otherwise unused `C` and `L` parameters to the type.
    _phantom: PhantomData<fn(C, L)>,
}
22
23impl<A: fmt::Debug, C, L> fmt::Debug for ProxyAuthLayer<A, C, L> {
24 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
25 f.debug_struct("ProxyAuthLayer")
26 .field("proxy_auth", &self.proxy_auth)
27 .field(
28 "_phantom",
29 &format_args!("{}", std::any::type_name::<fn(C, L) -> ()>()),
30 )
31 .finish()
32 }
33}
34
35impl<A: Clone, C, L> Clone for ProxyAuthLayer<A, C, L> {
36 fn clone(&self) -> Self {
37 Self {
38 proxy_auth: self.proxy_auth.clone(),
39 allow_anonymous: self.allow_anonymous,
40 _phantom: PhantomData,
41 }
42 }
43}
44
45impl<A, C> ProxyAuthLayer<A, C, ()> {
46 pub const fn new(proxy_auth: A) -> Self {
48 ProxyAuthLayer {
49 proxy_auth,
50 allow_anonymous: false,
51 _phantom: PhantomData,
52 }
53 }
54
55 pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
57 self.allow_anonymous = allow_anonymous;
58 self
59 }
60
61 pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
63 self.allow_anonymous = allow_anonymous;
64 self
65 }
66}
67
68impl<A, C, L> ProxyAuthLayer<A, C, L> {
69 pub fn with_labels<L2>(self) -> ProxyAuthLayer<A, C, L2> {
79 ProxyAuthLayer {
80 proxy_auth: self.proxy_auth,
81 allow_anonymous: self.allow_anonymous,
82 _phantom: PhantomData,
83 }
84 }
85}
86
87impl<A, C, L, S> Layer<S> for ProxyAuthLayer<A, C, L>
88where
89 A: Authority<C, L> + Clone,
90 C: Credentials + Clone + Send + Sync + 'static,
91{
92 type Service = ProxyAuthService<A, C, S, L>;
93
94 fn layer(&self, inner: S) -> Self::Service {
95 ProxyAuthService::new(self.proxy_auth.clone(), inner)
96 }
97}
98
/// Service that checks proxy credentials before handing a request to the
/// wrapped service `S`.
///
/// Type parameters:
/// - `A`: the value stored in `proxy_auth`; the `Service` impl requires it to
///   be an `Authority<C, L>`.
/// - `C`: the credentials type (marker only, no value is stored).
/// - `S`: the wrapped service.
/// - `L`: the labels type (marker only, defaults to `()`).
pub struct ProxyAuthService<A, C, S, L = ()> {
    /// Authority asked to authorize the credentials of each request.
    proxy_auth: A,
    /// When `true`, requests without credentials are forwarded as anonymous.
    allow_anonymous: bool,
    /// The wrapped service that receives accepted requests.
    inner: S,
    /// Ties the otherwise unused `C` and `L` parameters to the type.
    _phantom: PhantomData<fn(C, L)>,
}
112
113impl<A, C, S, L> ProxyAuthService<A, C, S, L> {
114 pub const fn new(proxy_auth: A, inner: S) -> Self {
116 Self {
117 proxy_auth,
118 allow_anonymous: false,
119 inner,
120 _phantom: PhantomData,
121 }
122 }
123
124 pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
126 self.allow_anonymous = allow_anonymous;
127 self
128 }
129
130 pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
132 self.allow_anonymous = allow_anonymous;
133 self
134 }
135
136 define_inner_service_accessors!();
137}
138
139impl<A: fmt::Debug, C, S: fmt::Debug, L> fmt::Debug for ProxyAuthService<A, C, S, L> {
140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 f.debug_struct("ProxyAuthService")
142 .field("proxy_auth", &self.proxy_auth)
143 .field("allow_anonymous", &self.allow_anonymous)
144 .field("inner", &self.inner)
145 .field(
146 "_phantom",
147 &format_args!("{}", std::any::type_name::<fn(C, L) -> ()>()),
148 )
149 .finish()
150 }
151}
152
153impl<A: Clone, C, S: Clone, L> Clone for ProxyAuthService<A, C, S, L> {
154 fn clone(&self) -> Self {
155 ProxyAuthService {
156 proxy_auth: self.proxy_auth.clone(),
157 allow_anonymous: self.allow_anonymous,
158 inner: self.inner.clone(),
159 _phantom: PhantomData,
160 }
161 }
162}
163
164impl<A, C, L, S, State, ReqBody, ResBody> Service<State, Request<ReqBody>>
165 for ProxyAuthService<A, C, S, L>
166where
167 A: Authority<C, L>,
168 C: Credentials + Clone + Send + Sync + 'static,
169 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
170 L: 'static,
171 ReqBody: Send + 'static,
172 ResBody: Default + Send + 'static,
173 State: Clone + Send + Sync + 'static,
174{
175 type Response = S::Response;
176 type Error = S::Error;
177
178 async fn serve(
179 &self,
180 mut ctx: Context<State>,
181 req: Request<ReqBody>,
182 ) -> Result<Self::Response, Self::Error> {
183 if let Some(credentials) = req
184 .headers()
185 .typed_get::<ProxyAuthorization<C>>()
186 .map(|h| h.0)
187 .or_else(|| ctx.get::<C>().cloned())
188 {
189 if let Some(ext) = self.proxy_auth.authorized(credentials).await {
190 ctx.extend(ext);
191 self.inner.serve(ctx, req).await
192 } else {
193 Ok(Response::builder()
194 .status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
195 .header(PROXY_AUTHENTICATE, C::SCHEME)
196 .body(Default::default())
197 .unwrap())
198 }
199 } else if self.allow_anonymous {
200 ctx.insert(UserId::Anonymous);
201 self.inner.serve(ctx, req).await
202 } else {
203 Ok(Response::builder()
204 .status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
205 .header(PROXY_AUTHENTICATE, C::SCHEME)
206 .body(Default::default())
207 .unwrap())
208 }
209 }
210}