sylvia_iot_data/routes/
middleware.rs1use std::{
4 collections::HashMap,
5 task::{Context, Poll},
6};
7
8use axum::{
9 extract::Request,
10 response::{IntoResponse, Response},
11};
12use futures::future::BoxFuture;
13use reqwest;
14use serde::{self, Deserialize};
15use tower::{Layer, Service};
16
17use sylvia_iot_corelib::{err::ErrResp, http as sylvia_http};
18
/// Token information attached to every request that passed [`AuthMiddleware`].
///
/// The middleware inserts one value of this type into the request extensions after the
/// auth service answered `200 OK` for the request's `Authorization` header.
#[derive(Clone)]
pub struct GetTokenInfoData {
    /// The credential part of the `Authorization` header: the second whitespace-separated
    /// field (the value following the scheme, e.g. `abc` in `Bearer abc`).
    pub token: String,
    /// Copied from the auth service's `userId` field.
    pub user_id: String,
    /// Copied from the auth service's `account` field.
    pub account: String,
    /// Copied from the auth service's `roles` field (presumably role name -> enabled flag;
    /// confirm against the auth service API).
    pub roles: HashMap<String, bool>,
    /// Copied from the auth service's `name` field.
    pub name: String,
    /// Copied from the auth service's `clientId` field.
    pub client_id: String,
    /// Copied from the auth service's `scopes` field.
    pub scopes: Vec<String>,
}
30
/// Tower layer that wraps services in an [`AuthMiddleware`].
///
/// Holds only the URI of the token-info endpoint that the middleware queries.
#[derive(Clone)]
pub struct AuthService {
    /// Endpoint that receives a `GET` carrying the incoming `Authorization` header.
    auth_uri: String,
}
35
36#[derive(Clone)]
37pub struct AuthMiddleware<S> {
38 client: reqwest::Client,
39 auth_uri: String,
40 service: S,
41}
42
43#[derive(Clone, Deserialize)]
45struct GetTokenInfo {
46 data: GetTokenInfoDataInner,
47}
48
49#[derive(Clone, Deserialize)]
50struct GetTokenInfoDataInner {
51 #[serde(rename = "userId")]
52 pub user_id: String,
53 pub account: String,
54 pub roles: HashMap<String, bool>,
55 pub name: String,
56 #[serde(rename = "clientId")]
57 pub client_id: String,
58 pub scopes: Vec<String>,
59}
60
61impl AuthService {
62 pub fn new(auth_uri: String) -> Self {
63 AuthService { auth_uri }
64 }
65}
66
67impl<S> Layer<S> for AuthService {
68 type Service = AuthMiddleware<S>;
69
70 fn layer(&self, inner: S) -> Self::Service {
71 AuthMiddleware {
72 client: reqwest::Client::new(),
73 auth_uri: self.auth_uri.clone(),
74 service: inner,
75 }
76 }
77}
78
79impl<S> Service<Request> for AuthMiddleware<S>
80where
81 S: Service<Request, Response = Response> + Clone + Send + 'static,
82 S::Future: Send + 'static,
83{
84 type Response = S::Response;
85 type Error = S::Error;
86 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
87
88 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89 self.service.poll_ready(cx)
90 }
91
92 fn call(&mut self, mut req: Request) -> Self::Future {
93 let mut svc = self.service.clone();
94 let client = self.client.clone();
95 let auth_uri = self.auth_uri.clone();
96
97 Box::pin(async move {
98 let token = match sylvia_http::parse_header_auth(&req) {
99 Err(e) => return Ok(e.into_response()),
100 Ok(token) => match token {
101 None => {
102 let e = ErrResp::ErrParam(Some("missing token".to_string()));
103 return Ok(e.into_response());
104 }
105 Some(token) => token,
106 },
107 };
108
109 let token_req = match client
110 .request(reqwest::Method::GET, auth_uri.as_str())
111 .header(reqwest::header::AUTHORIZATION, token.as_str())
112 .build()
113 {
114 Err(e) => {
115 let e = ErrResp::ErrRsc(Some(format!("request auth error: {}", e)));
116 return Ok(e.into_response());
117 }
118 Ok(req) => req,
119 };
120 let resp = match client.execute(token_req).await {
121 Err(e) => {
122 let e = ErrResp::ErrIntMsg(Some(format!("auth error: {}", e)));
123 return Ok(e.into_response());
124 }
125 Ok(resp) => match resp.status() {
126 reqwest::StatusCode::UNAUTHORIZED => {
127 return Ok(ErrResp::ErrAuth(None).into_response())
128 }
129 reqwest::StatusCode::OK => resp,
130 _ => {
131 let e = ErrResp::ErrIntMsg(Some(format!(
132 "auth error with status code: {}",
133 resp.status()
134 )));
135 return Ok(e.into_response());
136 }
137 },
138 };
139 let token_info = match resp.json::<GetTokenInfo>().await {
140 Err(e) => {
141 let e = ErrResp::ErrIntMsg(Some(format!("read auth body error: {}", e)));
142 return Ok(e.into_response());
143 }
144 Ok(info) => info,
145 };
146
147 let mut split = token.split_whitespace();
148 split.next(); let token = match split.next() {
150 None => {
151 let e = ErrResp::ErrUnknown(Some("parse token error".to_string()));
152 return Ok(e.into_response());
153 }
154 Some(token) => token.to_string(),
155 };
156
157 req.extensions_mut().insert(GetTokenInfoData {
158 token,
159 user_id: token_info.data.user_id,
160 account: token_info.data.account,
161 roles: token_info.data.roles,
162 name: token_info.data.name,
163 client_id: token_info.data.client_id,
164 scopes: token_info.data.scopes,
165 });
166
167 let res = svc.call(req).await?;
168 Ok(res)
169 })
170 }
171}