tame_oauth/
token_cache.rs

1//! Provides functionality for caching access tokens and id tokens.
2
3use crate::id_token::{IdTokenOrRequest, IdTokenProvider};
4use crate::token::{TokenOrRequest, TokenProvider};
5use crate::{error::Error, token::RequestReason, IdToken, Token};
6
7use std::hash::Hasher;
8use std::sync::RwLock;
9
10type Hash = u64;
11
12#[derive(Debug)]
13struct Entry<T> {
14    hash: Hash,
15    token: T,
16}
17
18/// An in-memory cache for caching tokens.
19#[derive(Debug)]
20pub struct TokenCache<T> {
21    cache: RwLock<Vec<Entry<T>>>,
22}
23
24pub enum TokenOrRequestReason<T> {
25    Token(T),
26    RequestReason(RequestReason),
27}
28
29impl<T> TokenCache<T> {
30    pub fn new() -> Self {
31        Self {
32            cache: RwLock::new(Vec::new()),
33        }
34    }
35
36    /// Get a token from the cache that matches the hash
37    pub fn get(&self, hash: Hash) -> Result<TokenOrRequestReason<T>, Error>
38    where
39        T: CacheableToken + Clone,
40    {
41        let reason = {
42            let cache = self.cache.read().map_err(|_e| Error::Poisoned)?;
43            match cache.binary_search_by(|i| i.hash.cmp(&hash)) {
44                Ok(i) => {
45                    let token = &cache[i].token;
46
47                    if !token.has_expired() {
48                        return Ok(TokenOrRequestReason::Token(token.clone()));
49                    }
50
51                    RequestReason::Expired
52                }
53                Err(_) => RequestReason::ParametersChanged,
54            }
55        };
56
57        Ok(TokenOrRequestReason::RequestReason(reason))
58    }
59
60    /// Insert a token into the cache
61    pub fn insert(&self, token: T, hash: Hash) -> Result<(), Error> {
62        // Last token wins, which...should?...be fine
63        let mut cache = self.cache.write().map_err(|_e| Error::Poisoned)?;
64        match cache.binary_search_by(|i| i.hash.cmp(&hash)) {
65            Ok(i) => cache[i].token = token,
66            Err(i) => {
67                cache.insert(i, Entry { hash, token });
68            }
69        };
70
71        Ok(())
72    }
73}
74
75impl<T> Default for TokenCache<T> {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
/// A token that can report whether it is still usable.
///
/// [`TokenCache::get`] only hands out cached tokens whose
/// [`has_expired`](CacheableToken::has_expired) returns `false`.
pub trait CacheableToken {
    /// Returns `true` once the token must no longer be used and a new one has
    /// to be requested.
    fn has_expired(&self) -> bool;
}
84
85/// Wraps a `TokenProvider` in a cache, only invokes the inner `TokenProvider` if
86/// the token in cache is expired, or if it doesn't exist.
87pub struct CachedTokenProvider<P> {
88    access_tokens: TokenCache<Token>,
89    id_tokens: TokenCache<IdToken>,
90    inner: P,
91}
92
93impl<P: std::fmt::Debug> std::fmt::Debug for CachedTokenProvider<P> {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        f.debug_struct("CachedTokenProvider")
96            .field("inner", &self.inner)
97            .finish_non_exhaustive()
98    }
99}
100
101impl<P> CachedTokenProvider<P> {
102    /// Wraps a token provider with a cache
103    pub fn wrap(token_provider: P) -> Self {
104        Self {
105            access_tokens: TokenCache::new(),
106            id_tokens: TokenCache::new(),
107            inner: token_provider,
108        }
109    }
110
111    /// Gets a reference to the wrapped (uncached) token provider
112    pub fn inner(&self) -> &P {
113        &self.inner
114    }
115}
116
117impl<P> TokenProvider for CachedTokenProvider<P>
118where
119    P: TokenProvider,
120{
121    fn get_token_with_subject<'a, S, I, T>(
122        &self,
123        subject: Option<T>,
124        scopes: I,
125    ) -> Result<TokenOrRequest, Error>
126    where
127        S: AsRef<str> + 'a,
128        I: IntoIterator<Item = &'a S> + Clone,
129        T: Into<String>,
130    {
131        let scope_hash = hash_scopes(&scopes);
132
133        let reason = match self.access_tokens.get(scope_hash)? {
134            TokenOrRequestReason::Token(token) => return Ok(TokenOrRequest::Token(token)),
135            TokenOrRequestReason::RequestReason(reason) => reason,
136        };
137
138        match self.inner.get_token_with_subject(subject, scopes)? {
139            TokenOrRequest::Token(token) => Ok(TokenOrRequest::Token(token)),
140            TokenOrRequest::Request { request, .. } => Ok(TokenOrRequest::Request {
141                request,
142                reason,
143                scope_hash,
144            }),
145        }
146    }
147
148    fn parse_token_response<S>(
149        &self,
150        hash: u64,
151        response: http::Response<S>,
152    ) -> Result<Token, Error>
153    where
154        S: AsRef<[u8]>,
155    {
156        let token = self.inner.parse_token_response(hash, response)?;
157
158        self.access_tokens.insert(token.clone(), hash)?;
159        Ok(token)
160    }
161}
162
163impl<P> IdTokenProvider for CachedTokenProvider<P>
164where
165    P: IdTokenProvider,
166{
167    fn get_id_token(&self, audience: &str) -> Result<IdTokenOrRequest, Error> {
168        let hash = hash_str(audience);
169
170        let reason = match self.id_tokens.get(hash)? {
171            TokenOrRequestReason::Token(token) => return Ok(IdTokenOrRequest::IdToken(token)),
172            TokenOrRequestReason::RequestReason(reason) => reason,
173        };
174
175        match self.inner.get_id_token(audience)? {
176            IdTokenOrRequest::IdToken(token) => Ok(IdTokenOrRequest::IdToken(token)),
177            IdTokenOrRequest::AccessTokenRequest { request, .. } => {
178                Ok(IdTokenOrRequest::AccessTokenRequest {
179                    request,
180                    reason,
181                    audience_hash: hash,
182                })
183            }
184            IdTokenOrRequest::IdTokenRequest { request, .. } => {
185                Ok(IdTokenOrRequest::IdTokenRequest {
186                    request,
187                    reason,
188                    audience_hash: hash,
189                })
190            }
191        }
192    }
193
194    fn get_id_token_with_access_token<S>(
195        &self,
196        audience: &str,
197        response: crate::id_token::AccessTokenResponse<S>,
198    ) -> Result<crate::id_token::IdTokenRequest, Error>
199    where
200        S: AsRef<[u8]>,
201    {
202        self.inner
203            .get_id_token_with_access_token(audience, response)
204    }
205
206    fn parse_id_token_response<S>(
207        &self,
208        hash: u64,
209        response: http::Response<S>,
210    ) -> Result<IdToken, Error>
211    where
212        S: AsRef<[u8]>,
213    {
214        let token = self.inner.parse_id_token_response(hash, response)?;
215
216        self.id_tokens.insert(token.clone(), hash)?;
217        Ok(token)
218    }
219}
220
221fn hash_str(str: &str) -> Hash {
222    let hash = {
223        let mut hasher = twox_hash::XxHash::default();
224        hasher.write(str.as_bytes());
225        hasher.finish()
226    };
227
228    hash
229}
230
231fn hash_scopes<'a, I, S>(scopes: &I) -> Hash
232where
233    S: AsRef<str> + 'a,
234    I: IntoIterator<Item = &'a S> + Clone,
235{
236    let scopes_str = scopes
237        .clone()
238        .into_iter()
239        .map(|s| s.as_ref())
240        .collect::<Vec<_>>()
241        .join("|");
242
243    hash_str(&scopes_str)
244}
245
#[cfg(test)]
mod test {
    use std::time::{Duration, SystemTime};

    use super::*;

    #[test]
    fn test_hash_scopes() {
        use std::hash::Hasher;

        // Feeding the pieces one by one must match hashing the `|`-joined scopes.
        let mut reference = twox_hash::XxHash::default();
        reference.write(b"scope1|");
        reference.write(b"scope2|");
        reference.write(b"scope3");
        let expected = reference.finish();

        let borrowed = ["scope1", "scope2", "scope3"];
        assert_eq!(expected, hash_scopes(&borrowed.iter()));

        let owned: Vec<String> = borrowed.iter().map(|s| s.to_string()).collect();
        assert_eq!(expected, hash_scopes(&owned.iter()));
    }

    #[test]
    fn test_cache() {
        let cache = TokenCache::new();
        let key = hash_scopes(&["scope1", "scope2"].iter());

        // Nothing has been cached yet.
        assert!(matches!(
            cache.get(key).unwrap(),
            TokenOrRequestReason::RequestReason(RequestReason::ParametersChanged)
        ));

        // A stale token is reported as expired instead of being returned.
        cache.insert(mock_token(-100), key).unwrap();
        assert!(matches!(
            cache.get(key).unwrap(),
            TokenOrRequestReason::RequestReason(RequestReason::Expired)
        ));

        // Replacing it with a fresh token makes the lookup succeed.
        cache.insert(mock_token(100), key).unwrap();
        assert!(matches!(
            cache.get(key).unwrap(),
            TokenOrRequestReason::Token(..)
        ));
    }

    #[test]
    fn test_cache_wrapper() {
        let provider = CachedTokenProvider::wrap(PanicProvider);
        let scopes = ["scope1", "scope2"];

        provider
            .access_tokens
            .insert(mock_token(100), hash_scopes(&scopes.iter()))
            .unwrap();

        // `PanicProvider` panics if consulted, so reaching the assertion proves
        // the cache answered the request.
        let outcome = provider.get_token(&scopes).unwrap();
        assert!(matches!(outcome, TokenOrRequest::Token(..)));
    }

    /// Builds a `Token` whose expiry lies `expires_in` seconds from now, or in
    /// the past when `expires_in` is not positive.
    fn mock_token(expires_in: i64) -> Token {
        let now = SystemTime::now();
        let delta = Duration::from_secs(expires_in.unsigned_abs());
        let expires_in_timestamp = if expires_in > 0 { now + delta } else { now - delta };

        Token {
            access_token: "access-token".to_string(),
            refresh_token: "refresh-token".to_string(),
            token_type: "token-type".to_string(),
            expires_in: Some(expires_in),
            expires_in_timestamp: Some(expires_in_timestamp),
        }
    }

    /// Mock provider that panics on every call. Tests use it to prove that the
    /// cache wrapper answered a request without consulting the inner provider.
    struct PanicProvider;

    impl TokenProvider for PanicProvider {
        fn get_token_with_subject<'a, S, I, T>(
            &self,
            _subject: Option<T>,
            _scopes: I,
        ) -> Result<TokenOrRequest, Error>
        where
            S: AsRef<str> + 'a,
            I: IntoIterator<Item = &'a S> + Clone,
            T: Into<String>,
        {
            panic!("should not have been reached")
        }

        fn parse_token_response<S>(
            &self,
            _hash: u64,
            _response: http::Response<S>,
        ) -> Result<Token, Error>
        where
            S: AsRef<[u8]>,
        {
            panic!("should not have been reached")
        }
    }

    impl IdTokenProvider for PanicProvider {
        fn get_id_token(&self, _audience: &str) -> Result<IdTokenOrRequest, Error> {
            panic!("should not have been reached")
        }

        fn parse_id_token_response<S>(
            &self,
            _hash: u64,
            _response: http::Response<S>,
        ) -> Result<IdToken, Error>
        where
            S: AsRef<[u8]>,
        {
            panic!("should not have been reached")
        }

        fn get_id_token_with_access_token<S>(
            &self,
            _audience: &str,
            _response: crate::id_token::AccessTokenResponse<S>,
        ) -> Result<crate::id_token::IdTokenRequest, Error>
        where
            S: AsRef<[u8]>,
        {
            panic!("should not have been reached")
        }
    }
}