framework_cqrs_lib/cqrs/infra/token/services/
jwt_rsa.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
use std::fmt::Debug;
use std::sync::Arc;

use async_trait::async_trait;
use jsonwebtoken::{Algorithm, decode, decode_header, DecodingKey, Validation};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde::de::DeserializeOwned;

use crate::cqrs::infra::cache::CacheAsync;
use crate::cqrs::core::token::TokenService;
use crate::cqrs::models::errors::{Error, ResultErr};

/// Token service that decodes RS256-signed JWTs.
///
/// The verification key is looked up by the token's `kid` header, first in
/// `cache` and, on a miss, from the auth backend over HTTP.
pub struct JwtRSATokenService {
    /// Cache of JWKs serialized as JSON strings, keyed by `kid`.
    pub cache: Arc<CacheAsync>,
    /// HTTP client used to fetch a JWK from the auth backend on a cache miss.
    pub http_client: Arc<Client>,
    /// Base URL of the auth backend; `/v1/jwks/{kid}/public` is appended to it.
    pub auth_back_url: String,
}

impl JwtRSATokenService {
    pub fn new(cache: Arc<CacheAsync>, http_client: Arc<Client>, auth_back_url: String) -> Self {
        Self {
            cache,
            http_client,
            auth_back_url,
        }
    }
}

#[async_trait]
impl TokenService for JwtRSATokenService {
    async fn decode<CLAIMS: Debug + Serialize + DeserializeOwned>(&self, token: &str) -> ResultErr<CLAIMS> {
        let header = decode_header(token).map_err(|err| {
            let message = err.to_string();
            Error::Simple(format!("decode header token : {message}"))
        })?;

        let kid = header.kid.ok_or(Error::Simple("jwt invalid, pas de kid dans l'entete".to_string()))?;
        let maybe_data = self.cache.get(&kid).await;

        let jwk = match maybe_data {
            Some(data) => {
                let jwk = serde_json::from_str::<JWK>(data.as_str())
                    .map_err(|err| Error::Simple(err.to_string()))?;
                Ok(jwk)
            }
            None => {
                let url = format!("{}/v1/jwks/{kid}/public", self.auth_back_url);
                let response = self.http_client
                    .get(url)
                    .send()
                    .await.map_err(|err| Error::Simple(err.to_string()))?;
                if response.status() == 200 {
                    let jwk = response.json::<JWK>().await.map_err(|err| Error::Simple(err.to_string()))?;
                    let stringify = serde_json::to_string(&jwk)
                        .map_err(|err| Error::Simple(err.to_string()))?;
                    self.cache.upsert(kid, stringify).await;
                    Ok(jwk)
                } else {
                    Err(Error::Simple("erreur lors du call authbacku".to_string()))
                }
            }
        }?;

        let decoding_key = DecodingKey::from_rsa_components(
            jwk.n.as_str(),
            jwk.e.as_str(),
        ).map_err(|err| Error::Simple({
            let error_message = err.to_string();
            format!("decoding key : {error_message}")
        }))?;


        decode::<CLAIMS>(token, &decoding_key, &Validation::new(Algorithm::RS256))
            .map(|token_data| token_data.claims)
            .map_err(|err| {
                let message = err.to_string();
                Error::Simple(format!("decode token : {message}"))
            })
    }
}

/// Minimal JSON Web Key as returned by the auth backend and stored in the cache.
///
/// Only the two fields handed to `DecodingKey::from_rsa_components` are kept.
/// The field order is also the key order of the JSON written to the cache.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct JWK {
    /// First argument to `DecodingKey::from_rsa_components`
    /// (presumably the base64url-encoded RSA modulus — confirm against the auth backend).
    n: String,
    /// Second argument to `DecodingKey::from_rsa_components`
    /// (presumably the base64url-encoded RSA public exponent — confirm against the auth backend).
    e: String,
}