sylvia_iot_broker/routes/
middleware.rs1use std::{
4 collections::{HashMap, HashSet},
5 task::{Context, Poll},
6};
7
8use axum::{
9 extract::Request,
10 http::Method,
11 response::{IntoResponse, Response},
12};
13use futures::future::BoxFuture;
14use reqwest;
15use serde::{self, Deserialize};
16use tower::{Layer, Service};
17
18use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
19
/// Per-method access requirements as supplied by callers: `(roles, scopes)`.
///
/// An empty list means that dimension is not checked for the method.
pub type RoleScopeType = (Vec<&'static str>, Vec<String>);

/// Set form of [`RoleScopeType`], built once when the layer wraps a service so that each
/// request only performs set lookups.
type RoleScopeInner = (HashSet<&'static str>, HashSet<String>);

/// Token details that [`AuthMiddleware`] inserts into the request extensions once the auth
/// service has accepted the token and the role/scope checks have passed.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The access token: the second whitespace-separated word of the authorization value
    /// (the first word, presumably the scheme such as `Bearer`, is dropped).
    pub token: String,
    /// User identifier reported by the auth service (`userId`).
    pub user_id: String,
    /// Account name reported by the auth service.
    pub account: String,
    /// Role name to flag; only roles whose flag is `true` count when checking role requirements.
    pub roles: HashMap<String, bool>,
    /// Display name reported by the auth service.
    pub name: String,
    /// Client identifier reported by the auth service (`clientId`).
    pub client_id: String,
    /// Scopes granted to the token.
    pub scopes: Vec<String>,
}

35#[derive(Clone)]
36pub struct AuthService {
37 auth_uri: String,
38 role_scopes: HashMap<Method, RoleScopeType>,
39}
40
41#[derive(Clone)]
42pub struct AuthMiddleware<S> {
43 client: reqwest::Client,
44 auth_uri: String,
45 role_scopes: HashMap<Method, RoleScopeInner>,
46 service: S,
47}
48
49#[derive(Deserialize)]
51struct GetTokenInfo {
52 data: GetTokenInfoDataInner,
53}
54
55#[derive(Deserialize)]
56struct GetTokenInfoDataInner {
57 #[serde(rename = "userId")]
58 user_id: String,
59 account: String,
60 roles: HashMap<String, bool>,
61 name: String,
62 #[serde(rename = "clientId")]
63 client_id: String,
64 scopes: Vec<String>,
65}
66
67impl AuthService {
68 pub fn new(auth_uri: String, role_scopes: HashMap<Method, RoleScopeType>) -> Self {
69 AuthService {
70 role_scopes,
71 auth_uri,
72 }
73 }
74}
75
76impl<S> Layer<S> for AuthService {
77 type Service = AuthMiddleware<S>;
78
79 fn layer(&self, inner: S) -> Self::Service {
80 let mut role_scopes: HashMap<Method, RoleScopeInner> = HashMap::new();
81 for (k, (r, s)) in self.role_scopes.iter() {
82 role_scopes.insert(
83 k.clone(),
84 (
85 r.iter().map(|&r| r).collect(),
86 s.iter().map(|s| s.clone()).collect(),
87 ),
88 );
89 }
90
91 AuthMiddleware {
92 client: reqwest::Client::new(),
93 auth_uri: self.auth_uri.clone(),
94 role_scopes,
95 service: inner,
96 }
97 }
98}
99
100impl<S> Service<Request> for AuthMiddleware<S>
101where
102 S: Service<Request, Response = Response> + Clone + Send + 'static,
103 S::Future: Send + 'static,
104{
105 type Response = S::Response;
106 type Error = S::Error;
107 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
108
109 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110 self.service.poll_ready(cx)
111 }
112
113 fn call(&mut self, mut req: Request) -> Self::Future {
114 let mut svc = self.service.clone();
115 let client = self.client.clone();
116 let auth_uri = self.auth_uri.clone();
117 let role_scopes = self.role_scopes.clone();
118
119 Box::pin(async move {
120 let token = match sylvia_http::parse_header_auth(&req) {
121 Err(e) => return Ok(e.into_response()),
122 Ok(token) => match token {
123 None => {
124 let e = ErrResp::ErrParam(Some("missing token".to_string()));
125 return Ok(e.into_response());
126 }
127 Some(token) => token,
128 },
129 };
130
131 let token_req = match client
132 .request(reqwest::Method::GET, auth_uri.as_str())
133 .header(reqwest::header::AUTHORIZATION, token.as_str())
134 .build()
135 {
136 Err(e) => {
137 let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
138 return Ok(e.into_response());
139 }
140 Ok(req) => req,
141 };
142 let resp = match client.execute(token_req).await {
143 Err(e) => {
144 let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
145 return Ok(e.into_response());
146 }
147 Ok(resp) => match resp.status() {
148 reqwest::StatusCode::UNAUTHORIZED => {
149 return Ok(ErrResp::ErrAuth(None).into_response())
150 }
151 reqwest::StatusCode::OK => resp,
152 _ => {
153 let e = ErrResp::ErrIntMsg(Some(format!(
154 "auth error with status code: {}",
155 resp.status()
156 )));
157 return Ok(e.into_response());
158 }
159 },
160 };
161 let token_info = match resp.json::<GetTokenInfo>().await {
162 Err(e) => {
163 let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
164 return Ok(e.into_response());
165 }
166 Ok(info) => info,
167 };
168
169 if let Some((api_roles, api_scopes)) = role_scopes.get(req.method()) {
170 if api_roles.len() > 0 {
171 let roles: HashSet<&str> = token_info
172 .data
173 .roles
174 .iter()
175 .filter(|(_, &v)| v)
176 .map(|(k, _)| k.as_str())
177 .collect();
178 if api_roles.is_disjoint(&roles) {
179 let e = ErrResp::ErrPerm(Some("invalid role".to_string()));
180 return Ok(e.into_response());
181 }
182 }
183 if api_scopes.len() > 0 {
184 let api_scopes: HashSet<&str> = api_scopes.iter().map(|s| s.as_str()).collect();
185 let scopes: HashSet<&str> =
186 token_info.data.scopes.iter().map(|s| s.as_str()).collect();
187 if api_scopes.is_disjoint(&scopes) {
188 let e = ErrResp::ErrPerm(Some("invalid scope".to_string()));
189 return Ok(e.into_response());
190 }
191 }
192 }
193
194 let mut split = token.split_whitespace();
195 split.next(); let token = match split.next() {
197 None => {
198 let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
199 return Ok(e.into_response());
200 }
201 Some(token) => token.to_string(),
202 };
203
204 req.extensions_mut().insert(GetTokenInfoData {
205 token,
206 user_id: token_info.data.user_id,
207 account: token_info.data.account,
208 roles: token_info.data.roles,
209 name: token_info.data.name,
210 client_id: token_info.data.client_id,
211 scopes: token_info.data.scopes,
212 });
213
214 let res = svc.call(req).await?;
215 Ok(res)
216 })
217 }
218}