1use crate::{
39 headers::{self, HeaderMapExt},
40 HeaderName, Request,
41};
42use rama_core::{Context, Layer, Service};
43use rama_utils::macros::define_inner_service_accessors;
44use std::{
45 fmt::{self, Debug},
46 future::Future,
47};
48
49pub use rama_ua::{
50 DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind,
51 UserAgentOverwrites,
52};
53
/// Middleware service that derives a [`UserAgent`] from each incoming
/// [`Request`] and inserts it into the [`Context`] before handing the
/// request to the wrapped service.
///
/// The `User-Agent` request header is the default source. When an overwrite
/// header is configured, its value is parsed as [`UserAgentOverwrites`] and
/// can replace or refine the detected agent.
pub struct UserAgentClassifier<S> {
    /// The wrapped service that receives the request after classification.
    inner: S,
    /// Name of the request header that may carry [`UserAgentOverwrites`];
    /// `None` means no overwrite header is consulted.
    overwrite_header: Option<HeaderName>,
}
65
66impl<S> UserAgentClassifier<S> {
67 pub const fn new(inner: S, overwrite_header: Option<HeaderName>) -> Self {
69 Self {
70 inner,
71 overwrite_header,
72 }
73 }
74
75 define_inner_service_accessors!();
76}
77
78impl<S> Debug for UserAgentClassifier<S>
79where
80 S: Debug,
81{
82 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83 f.debug_struct("UserAgentClassifier")
84 .field("inner", &self.inner)
85 .finish()
86 }
87}
88
89impl<S> Clone for UserAgentClassifier<S>
90where
91 S: Clone,
92{
93 fn clone(&self) -> Self {
94 Self {
95 inner: self.inner.clone(),
96 overwrite_header: self.overwrite_header.clone(),
97 }
98 }
99}
100
101impl<S> Default for UserAgentClassifier<S>
102where
103 S: Default,
104{
105 fn default() -> Self {
106 Self {
107 inner: S::default(),
108 overwrite_header: None,
109 }
110 }
111}
112
113impl<S, State, Body> Service<State, Request<Body>> for UserAgentClassifier<S>
114where
115 S: Service<State, Request<Body>>,
116 State: Clone + Send + Sync + 'static,
117{
118 type Response = S::Response;
119 type Error = S::Error;
120
121 fn serve(
122 &self,
123 mut ctx: Context<State>,
124 req: Request<Body>,
125 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
126 let mut user_agent = req
127 .headers()
128 .typed_get::<headers::UserAgent>()
129 .map(|ua| UserAgent::new(ua.to_string()));
130
131 if let Some(overwrites) = self
132 .overwrite_header
133 .as_ref()
134 .and_then(|header| req.headers().get(header))
135 .map(|header| header.as_bytes())
136 .and_then(|value| serde_html_form::from_bytes::<UserAgentOverwrites>(value).ok())
137 {
138 if let Some(ua) = overwrites.ua {
139 user_agent = Some(UserAgent::new(ua));
140 }
141 if let Some(ref mut ua) = user_agent {
142 if let Some(http_agent) = overwrites.http {
143 ua.with_http_agent(http_agent);
144 }
145 if let Some(tls_agent) = overwrites.tls {
146 ua.with_tls_agent(tls_agent);
147 }
148 if let Some(preserve_ua) = overwrites.preserve_ua {
149 ua.with_preserve_ua_header(preserve_ua);
150 }
151 }
152 }
153
154 if let Some(ua) = user_agent.take() {
155 ctx.insert(ua);
156 }
157
158 self.inner.serve(ctx, req)
159 }
160}
161
/// A [`Layer`] that wraps services in a [`UserAgentClassifier`].
///
/// The layer carries the optional overwrite header name and hands a clone
/// of it to every classifier it creates.
#[derive(Debug, Clone, Default)]
pub struct UserAgentClassifierLayer {
    /// Name of the request header that may carry [`UserAgentOverwrites`];
    /// `None` (the default) means no overwrite header is configured.
    overwrite_header: Option<HeaderName>,
}
169
170impl UserAgentClassifierLayer {
171 pub const fn new() -> Self {
173 Self {
174 overwrite_header: None,
175 }
176 }
177
178 pub fn overwrite_header(mut self, header: HeaderName) -> Self {
181 self.overwrite_header = Some(header);
182 self
183 }
184
185 pub fn set_overwrite_header(&mut self, header: HeaderName) -> &mut Self {
188 self.overwrite_header = Some(header);
189 self
190 }
191}
192
193impl<S> Layer<S> for UserAgentClassifierLayer {
194 type Service = UserAgentClassifier<S>;
195
196 fn layer(&self, inner: S) -> Self::Service {
197 UserAgentClassifier::new(inner, self.overwrite_header.clone())
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::required_header::AddRequiredRequestHeadersLayer;
    use crate::service::client::HttpClientExt;
    use crate::{headers, IntoResponse, Response, StatusCode};
    use rama_core::service::service_fn;
    use rama_core::Context;
    use std::convert::Infallible;

    /// User agent string shared by the Chromium-based tests.
    const CHROMIUM_WINDOWS_UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

    /// Header name used by the overwrite tests.
    const OVERWRITE_HEADER: &str = "x-proxy-ua";

    /// Handler asserting that the context holds [`CHROMIUM_WINDOWS_UA`],
    /// classified as Chromium 124 on Windows.
    async fn expect_chromium_on_windows<S>(
        ctx: Context<S>,
        _req: Request,
    ) -> Result<Response, Infallible> {
        let classified: &UserAgent = ctx.get().unwrap();

        assert_eq!(classified.header_str(), CHROMIUM_WINDOWS_UA);
        let info = classified.info().unwrap();
        assert_eq!(info.kind, UserAgentKind::Chromium);
        assert_eq!(info.version, Some(124));
        assert_eq!(classified.platform(), Some(PlatformKind::Windows));

        Ok(StatusCode::OK.into_response())
    }

    /// No explicit `User-Agent` is set on the request; the handler expects
    /// rama's own `NAME/VERSION` value (presumably added by
    /// `AddRequiredRequestHeadersLayer` — confirm against that layer).
    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_rama() {
        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let classified: &UserAgent = ctx.get().unwrap();

            let expected = format!("{}/{}", rama_utils::info::NAME, rama_utils::info::VERSION);
            assert_eq!(classified.header_str(), expected.as_str());
            assert!(classified.info().is_none());
            assert!(classified.platform().is_none());

            Ok(StatusCode::OK.into_response())
        }

        let layers = (
            AddRequiredRequestHeadersLayer::default(),
            UserAgentClassifierLayer::new(),
        );
        let service = layers.layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .send(Context::default())
            .await
            .unwrap();
    }

    /// The agent is taken from the typed `User-Agent` request header.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_chrome() {
        let service =
            UserAgentClassifierLayer::new().layer(service_fn(expect_chromium_on_windows));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(CHROMIUM_WINDOWS_UA))
            .send(Context::default())
            .await
            .unwrap();
    }

    /// The agent is supplied only through the overwrite header's `ua` field.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua() {
        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static(OVERWRITE_HEADER))
            .layer(service_fn(expect_chromium_on_windows));

        let encoded_overwrites = serde_html_form::to_string(&UserAgentOverwrites {
            ua: Some(CHROMIUM_WINDOWS_UA.to_owned()),
            ..Default::default()
        })
        .unwrap();

        let _ = service
            .get("http://www.example.com")
            .header(OVERWRITE_HEADER, encoded_overwrites)
            .send(Context::default())
            .await
            .unwrap();
    }

    /// Every overwrite field is supplied: ua string, http agent, tls agent
    /// and the preserve-ua flag.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua_all() {
        const UA: &str = "iPhone App/1.0";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let classified: &UserAgent = ctx.get().unwrap();

            assert_eq!(classified.header_str(), UA);
            assert!(classified.info().is_none());
            assert!(classified.platform().is_none());
            assert_eq!(classified.http_agent(), HttpAgent::Safari);
            assert_eq!(classified.tls_agent(), TlsAgent::Boringssl);
            assert!(classified.preserve_ua_header());

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static(OVERWRITE_HEADER))
            .layer(service_fn(handle));

        let encoded_overwrites = serde_html_form::to_string(&UserAgentOverwrites {
            ua: Some(UA.to_owned()),
            http: Some(HttpAgent::Safari),
            tls: Some(TlsAgent::Boringssl),
            preserve_ua: Some(true),
        })
        .unwrap();

        let _ = service
            .get("http://www.example.com")
            .header(OVERWRITE_HEADER, encoded_overwrites)
            .send(Context::default())
            .await
            .unwrap();
    }
}