sylvia_iot_broker/routes/
middleware.rs

1//! Provides the authentication middleware by sending the Bearer token to [`sylvia-iot-auth`].
2
3use std::{
4    collections::{HashMap, HashSet},
5    task::{Context, Poll},
6};
7
8use axum::{
9    extract::Request,
10    http::Method,
11    response::{IntoResponse, Response},
12};
13use futures::future::BoxFuture;
14use reqwest;
15use serde::{self, Deserialize};
16use tower::{Layer, Service};
17
18use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
19
/// Roles and scopes required for one HTTP method, as `(roles, scopes)`.
///
/// An empty list disables that check; a non-empty list requires the token to match at least one
/// of its entries.
pub type RoleScopeType = (Vec<&'static str>, Vec<String>);
/// Set-based form of [`RoleScopeType`], built once per layer so per-request checks are lookups.
type RoleScopeInner = (HashSet<&'static str>, HashSet<String>);
22
/// Token owner information that [`AuthMiddleware`] inserts into the request extensions after a
/// successful authentication, for use by downstream handlers.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token: the credential part of the `Authorization` header, without the scheme.
    pub token: String,
    /// User ID reported by the auth service (`userId` in its response).
    pub user_id: String,
    /// Account reported by the auth service.
    pub account: String,
    /// Role name to flag map; the middleware only treats roles flagged `true` as held.
    pub roles: HashMap<String, bool>,
    /// Name reported by the auth service.
    pub name: String,
    /// Client ID reported by the auth service (`clientId` in its response).
    pub client_id: String,
    /// Scopes reported for the token.
    pub scopes: Vec<String>,
}
34
35#[derive(Clone)]
36pub struct AuthService {
37    auth_uri: String,
38    role_scopes: HashMap<Method, RoleScopeType>,
39}
40
41#[derive(Clone)]
42pub struct AuthMiddleware<S> {
43    client: reqwest::Client,
44    auth_uri: String,
45    role_scopes: HashMap<Method, RoleScopeInner>,
46    service: S,
47}
48
49/// The user/client information of the token.
50#[derive(Deserialize)]
51struct GetTokenInfo {
52    data: GetTokenInfoDataInner,
53}
54
55#[derive(Deserialize)]
56struct GetTokenInfoDataInner {
57    #[serde(rename = "userId")]
58    user_id: String,
59    account: String,
60    roles: HashMap<String, bool>,
61    name: String,
62    #[serde(rename = "clientId")]
63    client_id: String,
64    scopes: Vec<String>,
65}
66
67impl AuthService {
68    pub fn new(auth_uri: String, role_scopes: HashMap<Method, RoleScopeType>) -> Self {
69        AuthService {
70            role_scopes,
71            auth_uri,
72        }
73    }
74}
75
76impl<S> Layer<S> for AuthService {
77    type Service = AuthMiddleware<S>;
78
79    fn layer(&self, inner: S) -> Self::Service {
80        let mut role_scopes: HashMap<Method, RoleScopeInner> = HashMap::new();
81        for (k, (r, s)) in self.role_scopes.iter() {
82            role_scopes.insert(
83                k.clone(),
84                (
85                    r.iter().map(|&r| r).collect(),
86                    s.iter().map(|s| s.clone()).collect(),
87                ),
88            );
89        }
90
91        AuthMiddleware {
92            client: reqwest::Client::new(),
93            auth_uri: self.auth_uri.clone(),
94            role_scopes,
95            service: inner,
96        }
97    }
98}
99
100impl<S> Service<Request> for AuthMiddleware<S>
101where
102    S: Service<Request, Response = Response> + Clone + Send + 'static,
103    S::Future: Send + 'static,
104{
105    type Response = S::Response;
106    type Error = S::Error;
107    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
108
109    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110        self.service.poll_ready(cx)
111    }
112
113    fn call(&mut self, mut req: Request) -> Self::Future {
114        let mut svc = self.service.clone();
115        let client = self.client.clone();
116        let auth_uri = self.auth_uri.clone();
117        let role_scopes = self.role_scopes.clone();
118
119        Box::pin(async move {
120            let token = match sylvia_http::parse_header_auth(&req) {
121                Err(e) => return Ok(e.into_response()),
122                Ok(token) => match token {
123                    None => {
124                        let e = ErrResp::ErrParam(Some("missing token".to_string()));
125                        return Ok(e.into_response());
126                    }
127                    Some(token) => token,
128                },
129            };
130
131            let token_req = match client
132                .request(reqwest::Method::GET, auth_uri.as_str())
133                .header(reqwest::header::AUTHORIZATION, token.as_str())
134                .build()
135            {
136                Err(e) => {
137                    let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
138                    return Ok(e.into_response());
139                }
140                Ok(req) => req,
141            };
142            let resp = match client.execute(token_req).await {
143                Err(e) => {
144                    let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
145                    return Ok(e.into_response());
146                }
147                Ok(resp) => match resp.status() {
148                    reqwest::StatusCode::UNAUTHORIZED => {
149                        return Ok(ErrResp::ErrAuth(None).into_response())
150                    }
151                    reqwest::StatusCode::OK => resp,
152                    _ => {
153                        let e = ErrResp::ErrIntMsg(Some(format!(
154                            "auth error with status code: {}",
155                            resp.status()
156                        )));
157                        return Ok(e.into_response());
158                    }
159                },
160            };
161            let token_info = match resp.json::<GetTokenInfo>().await {
162                Err(e) => {
163                    let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
164                    return Ok(e.into_response());
165                }
166                Ok(info) => info,
167            };
168
169            if let Some((api_roles, api_scopes)) = role_scopes.get(req.method()) {
170                if api_roles.len() > 0 {
171                    let roles: HashSet<&str> = token_info
172                        .data
173                        .roles
174                        .iter()
175                        .filter(|(_, &v)| v)
176                        .map(|(k, _)| k.as_str())
177                        .collect();
178                    if api_roles.is_disjoint(&roles) {
179                        let e = ErrResp::ErrPerm(Some("invalid role".to_string()));
180                        return Ok(e.into_response());
181                    }
182                }
183                if api_scopes.len() > 0 {
184                    let api_scopes: HashSet<&str> = api_scopes.iter().map(|s| s.as_str()).collect();
185                    let scopes: HashSet<&str> =
186                        token_info.data.scopes.iter().map(|s| s.as_str()).collect();
187                    if api_scopes.is_disjoint(&scopes) {
188                        let e = ErrResp::ErrPerm(Some("invalid scope".to_string()));
189                        return Ok(e.into_response());
190                    }
191                }
192            }
193
194            let mut split = token.split_whitespace();
195            split.next(); // skip "Bearer".
196            let token = match split.next() {
197                None => {
198                    let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
199                    return Ok(e.into_response());
200                }
201                Some(token) => token.to_string(),
202            };
203
204            req.extensions_mut().insert(GetTokenInfoData {
205                token,
206                user_id: token_info.data.user_id,
207                account: token_info.data.account,
208                roles: token_info.data.roles,
209                name: token_info.data.name,
210                client_id: token_info.data.client_id,
211                scopes: token_info.data.scopes,
212            });
213
214            let res = svc.call(req).await?;
215            Ok(res)
216        })
217    }
218}