// sylvia_iot_data/routes/middleware.rs

//! Provides the authentication middleware, which validates each request's Bearer token
//! by sending it to `sylvia-iot-auth`.
2
3use std::{
4    collections::HashMap,
5    task::{Context, Poll},
6};
7
8use axum::{
9    extract::Request,
10    response::{IntoResponse, Response},
11};
12use futures::future::BoxFuture;
13use reqwest;
14use serde::{self, Deserialize};
15use tower::{Layer, Service};
16
17use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
18
/// Token information that the authentication middleware attaches to an authenticated request.
///
/// An instance is inserted into the request extensions before the request reaches the wrapped
/// service, so handlers can read who (user) and what (client) issued the request.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token itself, without the leading authorization scheme word.
    pub token: String,
    /// User ID reported by the auth service (`userId`).
    pub user_id: String,
    /// Account of the user.
    pub account: String,
    /// Roles of the user; presumably role name -> enabled flag (TODO confirm with auth API).
    pub roles: HashMap<String, bool>,
    /// Display name of the user.
    pub name: String,
    /// Client ID reported by the auth service (`clientId`).
    pub client_id: String,
    /// Scopes granted to the token.
    pub scopes: Vec<String>,
}

/// Tower layer that wraps services with token authentication against `sylvia-iot-auth`.
#[derive(Clone)]
pub struct AuthService {
    /// Endpoint that the middleware queries with a `GET` request to resolve a token.
    auth_uri: String,
}

36#[derive(Clone)]
37pub struct AuthMiddleware<S> {
38    client: reqwest::Client,
39    auth_uri: String,
40    service: S,
41}
42
43/// The user/client information of the token.
44#[derive(Clone, Deserialize)]
45struct GetTokenInfo {
46    data: GetTokenInfoDataInner,
47}
48
49#[derive(Clone, Deserialize)]
50struct GetTokenInfoDataInner {
51    #[serde(rename = "userId")]
52    pub user_id: String,
53    pub account: String,
54    pub roles: HashMap<String, bool>,
55    pub name: String,
56    #[serde(rename = "clientId")]
57    pub client_id: String,
58    pub scopes: Vec<String>,
59}
60
61impl AuthService {
62    pub fn new(auth_uri: String) -> Self {
63        AuthService { auth_uri }
64    }
65}
66
67impl<S> Layer<S> for AuthService {
68    type Service = AuthMiddleware<S>;
69
70    fn layer(&self, inner: S) -> Self::Service {
71        AuthMiddleware {
72            client: reqwest::Client::new(),
73            auth_uri: self.auth_uri.clone(),
74            service: inner,
75        }
76    }
77}
78
79impl<S> Service<Request> for AuthMiddleware<S>
80where
81    S: Service<Request, Response = Response> + Clone + Send + 'static,
82    S::Future: Send + 'static,
83{
84    type Response = S::Response;
85    type Error = S::Error;
86    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
87
88    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89        self.service.poll_ready(cx)
90    }
91
92    fn call(&mut self, mut req: Request) -> Self::Future {
93        let mut svc = self.service.clone();
94        let client = self.client.clone();
95        let auth_uri = self.auth_uri.clone();
96
97        Box::pin(async move {
98            let token = match sylvia_http::parse_header_auth(&req) {
99                Err(e) => return Ok(e.into_response()),
100                Ok(token) => match token {
101                    None => {
102                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
103                        return Ok(e.into_response());
104                    }
105                    Some(token) => token,
106                },
107            };
108
109            let token_req = match client
110                .request(reqwest::Method::GET, auth_uri.as_str())
111                .header(reqwest::header::AUTHORIZATION, token.as_str())
112                .build()
113            {
114                Err(e) => {
115                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
116                    return Ok(e.into_response());
117                }
118                Ok(req) => req,
119            };
120            let resp = match client.execute(token_req).await {
121                Err(e) => {
122                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
123                    return Ok(e.into_response());
124                }
125                Ok(resp) => match resp.status() {
126                    reqwest::StatusCode::UNAUTHORIZED => {
127                        return Ok(ErrResp::ErrAuth(None).into_response())
128                    }
129                    reqwest::StatusCode::OK => resp,
130                    _ => {
131                        let e = ErrResp::ErrIntMsg(Some(format!(
132                            "auth error with status code: {}",
133                            resp.status()
134                        )));
135                        return Ok(e.into_response());
136                    }
137                },
138            };
139            let token_info = match resp.json::<GetTokenInfo>().await {
140                Err(e) => {
141                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
142                    return Ok(e.into_response());
143                }
144                Ok(info) => info,
145            };
146
147            let mut split = token.split_whitespace();
148            split.next(); // skip "Bearer".
149            let token = match split.next() {
150                None => {
151                    let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
152                    return Ok(e.into_response());
153                }
154                Some(token) => token.to_string(),
155            };
156
157            req.extensions_mut().insert(GetTokenInfoData {
158                token,
159                user_id: token_info.data.user_id,
160                account: token_info.data.account,
161                roles: token_info.data.roles,
162                name: token_info.data.name,
163                client_id: token_info.data.client_id,
164                scopes: token_info.data.scopes,
165            });
166
167            let res = svc.call(req).await?;
168            Ok(res)
169        })
170    }
171}