rama_http/layer/
ua.rs

1//! User-Agent (see also `rama-ua`) http layer support
2//!
3//! # Example
4//!
5//! ```
6//! use rama_http::{
7//!     service::client::HttpClientExt, IntoResponse, Request, Response, StatusCode,
8//!     layer::ua::{PlatformKind, UserAgent, UserAgentClassifierLayer, UserAgentKind, UserAgentInfo},
9//! };
10//! use rama_core::{Context, Layer, service::service_fn};
11//! use std::convert::Infallible;
12//!
13//! const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";
14//!
15//! async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
16//!     let ua: &UserAgent = ctx.get().unwrap();
17//!
18//!     assert_eq!(ua.header_str(), UA);
19//!     assert_eq!(ua.info(), Some(UserAgentInfo{ kind: UserAgentKind::Chromium, version: Some(124) }));
20//!     assert_eq!(ua.platform(), Some(PlatformKind::Windows));
21//!
22//!     Ok(StatusCode::OK.into_response())
23//! }
24//!
25//! # #[tokio::main]
26//! # async fn main() {
27//! let service = UserAgentClassifierLayer::new().layer(service_fn(handle));
28//!
29//! let _ = service
30//!     .get("http://www.example.com")
31//!     .typed_header(headers::UserAgent::from_static(UA))
32//!     .send(Context::default())
33//!     .await
34//!     .unwrap();
35//! # }
36//! ```
37
38use crate::{
39    headers::{self, HeaderMapExt},
40    HeaderName, Request,
41};
42use rama_core::{Context, Layer, Service};
43use rama_utils::macros::define_inner_service_accessors;
44use std::{
45    fmt::{self, Debug},
46    future::Future,
47};
48
49pub use rama_ua::{
50    DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind,
51    UserAgentOverwrites,
52};
53
/// A [`Service`] that classifies the [`UserAgent`] of incoming [`Request`]s.
///
/// The [`Extensions`] of the [`Context`] are updated with the [`UserAgent`]
/// if the [`Request`] contains a valid [`UserAgent`] header.
///
/// When an overwrite header is configured and present on the [`Request`],
/// its value is parsed as [`UserAgentOverwrites`] and applied on top of
/// (or, for the user agent string itself, instead of) what the regular
/// `User-Agent` header provided.
///
/// [`Extensions`]: rama_core::context::Extensions
/// [`Context`]: rama_core::Context
pub struct UserAgentClassifier<S> {
    /// The wrapped service that receives the request after classification.
    inner: S,
    /// Name of the optional request header carrying [`UserAgentOverwrites`];
    /// `None` means only the `User-Agent` header is consulted.
    overwrite_header: Option<HeaderName>,
}
65
impl<S> UserAgentClassifier<S> {
    /// Create a new [`UserAgentClassifier`] [`Service`].
    ///
    /// `inner` is the service to wrap. `overwrite_header`, when `Some`, names
    /// the request header whose value may carry [`UserAgentOverwrites`];
    /// pass `None` to rely on the regular `User-Agent` header only.
    pub const fn new(inner: S, overwrite_header: Option<HeaderName>) -> Self {
        Self {
            inner,
            overwrite_header,
        }
    }

    // NOTE(review): presumably generates the accessor methods for the wrapped
    // `inner` service — confirm against `rama_utils::macros`.
    define_inner_service_accessors!();
}
77
78impl<S> Debug for UserAgentClassifier<S>
79where
80    S: Debug,
81{
82    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83        f.debug_struct("UserAgentClassifier")
84            .field("inner", &self.inner)
85            .finish()
86    }
87}
88
89impl<S> Clone for UserAgentClassifier<S>
90where
91    S: Clone,
92{
93    fn clone(&self) -> Self {
94        Self {
95            inner: self.inner.clone(),
96            overwrite_header: self.overwrite_header.clone(),
97        }
98    }
99}
100
101impl<S> Default for UserAgentClassifier<S>
102where
103    S: Default,
104{
105    fn default() -> Self {
106        Self {
107            inner: S::default(),
108            overwrite_header: None,
109        }
110    }
111}
112
113impl<S, State, Body> Service<State, Request<Body>> for UserAgentClassifier<S>
114where
115    S: Service<State, Request<Body>>,
116    State: Clone + Send + Sync + 'static,
117{
118    type Response = S::Response;
119    type Error = S::Error;
120
121    fn serve(
122        &self,
123        mut ctx: Context<State>,
124        req: Request<Body>,
125    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
126        let mut user_agent = req
127            .headers()
128            .typed_get::<headers::UserAgent>()
129            .map(|ua| UserAgent::new(ua.to_string()));
130
131        if let Some(overwrites) = self
132            .overwrite_header
133            .as_ref()
134            .and_then(|header| req.headers().get(header))
135            .map(|header| header.as_bytes())
136            .and_then(|value| serde_html_form::from_bytes::<UserAgentOverwrites>(value).ok())
137        {
138            if let Some(ua) = overwrites.ua {
139                user_agent = Some(UserAgent::new(ua));
140            }
141            if let Some(ref mut ua) = user_agent {
142                if let Some(http_agent) = overwrites.http {
143                    ua.with_http_agent(http_agent);
144                }
145                if let Some(tls_agent) = overwrites.tls {
146                    ua.with_tls_agent(tls_agent);
147                }
148                if let Some(preserve_ua) = overwrites.preserve_ua {
149                    ua.with_preserve_ua_header(preserve_ua);
150                }
151            }
152        }
153
154        if let Some(ua) = user_agent.take() {
155            ctx.insert(ua);
156        }
157
158        self.inner.serve(ctx, req)
159    }
160}
161
#[derive(Debug, Clone, Default)]
/// A [`Layer`] that wraps a [`Service`] with a [`UserAgentClassifier`].
///
/// This [`Layer`] is used to classify the [`UserAgent`] of incoming [`Request`]s.
///
/// By default no overwrite header is configured, in which case only the
/// regular `User-Agent` header is taken into account.
pub struct UserAgentClassifierLayer {
    /// Header name handed to every [`UserAgentClassifier`] produced by this
    /// layer; `None` disables overwrites.
    overwrite_header: Option<HeaderName>,
}
169
170impl UserAgentClassifierLayer {
171    /// Create a new [`UserAgentClassifierLayer`].
172    pub const fn new() -> Self {
173        Self {
174            overwrite_header: None,
175        }
176    }
177
178    /// Define a custom header to allow overwriting certain
179    /// [`UserAgent`] information.
180    pub fn overwrite_header(mut self, header: HeaderName) -> Self {
181        self.overwrite_header = Some(header);
182        self
183    }
184
185    /// Define a custom header to allow overwriting certain
186    /// [`UserAgent`] information.
187    pub fn set_overwrite_header(&mut self, header: HeaderName) -> &mut Self {
188        self.overwrite_header = Some(header);
189        self
190    }
191}
192
193impl<S> Layer<S> for UserAgentClassifierLayer {
194    type Service = UserAgentClassifier<S>;
195
196    fn layer(&self, inner: S) -> Self::Service {
197        UserAgentClassifier::new(inner, self.overwrite_header.clone())
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::required_header::AddRequiredRequestHeadersLayer;
    use crate::service::client::HttpClientExt;
    use crate::{headers, IntoResponse, Response, StatusCode};
    use rama_core::service::service_fn;
    use rama_core::Context;
    use std::convert::Infallible;

    /// No `User-Agent` header is set explicitly by this test, yet the handler
    /// expects the `NAME/VERSION` string from `rama_utils::info`.
    // NOTE(review): presumably `AddRequiredRequestHeadersLayer` injects that
    // default header before the classifier runs — confirm.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_rama() {
        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(
                ua.header_str(),
                format!("{}/{}", rama_utils::info::NAME, rama_utils::info::VERSION).as_str(),
            );
            // This user agent string is expected to yield no known info or platform.
            assert!(ua.info().is_none());
            assert!(ua.platform().is_none());

            Ok(StatusCode::OK.into_response())
        }

        let service = (
            AddRequiredRequestHeadersLayer::default(),
            UserAgentClassifierLayer::new(),
        )
            .layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .send(Context::default())
            .await
            .unwrap();
    }

    /// A browser user agent sent through the regular `User-Agent` header is
    /// classified (kind, version and platform) and made available in the context.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_chrome() {
        const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            let ua_info = ua.info().unwrap();
            assert_eq!(ua_info.kind, UserAgentKind::Chromium);
            assert_eq!(ua_info.version, Some(124));
            assert_eq!(ua.platform(), Some(PlatformKind::Windows));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new().layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(UA))
            .send(Context::default())
            .await
            .unwrap();
    }

    /// The user agent string is supplied only through the configured overwrite
    /// header (`x-proxy-ua`); it must be classified just like a regular
    /// `User-Agent` header value.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua() {
        const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            let ua_info = ua.info().unwrap();
            assert_eq!(ua_info.kind, UserAgentKind::Chromium);
            assert_eq!(ua_info.version, Some(124));
            assert_eq!(ua.platform(), Some(PlatformKind::Windows));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .header(
                "x-proxy-ua",
                // The overwrite header value is the form-encoded overwrites struct.
                serde_html_form::to_string(&UserAgentOverwrites {
                    ua: Some(UA.to_owned()),
                    ..Default::default()
                })
                .unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }

    /// Every overwrite field is set: besides the user agent string, the http
    /// agent, tls agent and preserve-ua flag must all reach the handler.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua_all() {
        const UA: &str = "iPhone App/1.0";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            // The string itself is not recognised; only the explicit overwrites apply.
            assert!(ua.info().is_none());
            assert!(ua.platform().is_none());
            assert_eq!(ua.http_agent(), HttpAgent::Safari);
            assert_eq!(ua.tls_agent(), TlsAgent::Boringssl);
            assert!(ua.preserve_ua_header());

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .header(
                "x-proxy-ua",
                serde_html_form::to_string(&UserAgentOverwrites {
                    ua: Some(UA.to_owned()),
                    http: Some(HttpAgent::Safari),
                    tls: Some(TlsAgent::Boringssl),
                    preserve_ua: Some(true),
                })
                .unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }
}