1use crate::id_token::{IdTokenOrRequest, IdTokenProvider};
4use crate::token::{TokenOrRequest, TokenProvider};
5use crate::{error::Error, token::RequestReason, IdToken, Token};
6
7use std::hash::Hasher;
8use std::sync::RwLock;
9
/// Key type used to index cached tokens; it is the 64-bit output of the
/// hashing helpers in this module.
type Hash = u64;
11
12#[derive(Debug)]
13struct Entry<T> {
14 hash: Hash,
15 token: T,
16}
17
18#[derive(Debug)]
20pub struct TokenCache<T> {
21 cache: RwLock<Vec<Entry<T>>>,
22}
23
24pub enum TokenOrRequestReason<T> {
25 Token(T),
26 RequestReason(RequestReason),
27}
28
29impl<T> TokenCache<T> {
30 pub fn new() -> Self {
31 Self {
32 cache: RwLock::new(Vec::new()),
33 }
34 }
35
36 pub fn get(&self, hash: Hash) -> Result<TokenOrRequestReason<T>, Error>
38 where
39 T: CacheableToken + Clone,
40 {
41 let reason = {
42 let cache = self.cache.read().map_err(|_e| Error::Poisoned)?;
43 match cache.binary_search_by(|i| i.hash.cmp(&hash)) {
44 Ok(i) => {
45 let token = &cache[i].token;
46
47 if !token.has_expired() {
48 return Ok(TokenOrRequestReason::Token(token.clone()));
49 }
50
51 RequestReason::Expired
52 }
53 Err(_) => RequestReason::ParametersChanged,
54 }
55 };
56
57 Ok(TokenOrRequestReason::RequestReason(reason))
58 }
59
60 pub fn insert(&self, token: T, hash: Hash) -> Result<(), Error> {
62 let mut cache = self.cache.write().map_err(|_e| Error::Poisoned)?;
64 match cache.binary_search_by(|i| i.hash.cmp(&hash)) {
65 Ok(i) => cache[i].token = token,
66 Err(i) => {
67 cache.insert(i, Entry { hash, token });
68 }
69 };
70
71 Ok(())
72 }
73}
74
75impl<T> Default for TokenCache<T> {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
/// Implemented by token types that can be stored in a [`TokenCache`].
pub trait CacheableToken {
    /// Returns `true` when the token has expired; `TokenCache::get` does not
    /// hand out tokens for which this returns `true`.
    fn has_expired(&self) -> bool;
}
84
85pub struct CachedTokenProvider<P> {
88 access_tokens: TokenCache<Token>,
89 id_tokens: TokenCache<IdToken>,
90 inner: P,
91}
92
93impl<P: std::fmt::Debug> std::fmt::Debug for CachedTokenProvider<P> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct("CachedTokenProvider")
96 .field("inner", &self.inner)
97 .finish_non_exhaustive()
98 }
99}
100
101impl<P> CachedTokenProvider<P> {
102 pub fn wrap(token_provider: P) -> Self {
104 Self {
105 access_tokens: TokenCache::new(),
106 id_tokens: TokenCache::new(),
107 inner: token_provider,
108 }
109 }
110
111 pub fn inner(&self) -> &P {
113 &self.inner
114 }
115}
116
117impl<P> TokenProvider for CachedTokenProvider<P>
118where
119 P: TokenProvider,
120{
121 fn get_token_with_subject<'a, S, I, T>(
122 &self,
123 subject: Option<T>,
124 scopes: I,
125 ) -> Result<TokenOrRequest, Error>
126 where
127 S: AsRef<str> + 'a,
128 I: IntoIterator<Item = &'a S> + Clone,
129 T: Into<String>,
130 {
131 let scope_hash = hash_scopes(&scopes);
132
133 let reason = match self.access_tokens.get(scope_hash)? {
134 TokenOrRequestReason::Token(token) => return Ok(TokenOrRequest::Token(token)),
135 TokenOrRequestReason::RequestReason(reason) => reason,
136 };
137
138 match self.inner.get_token_with_subject(subject, scopes)? {
139 TokenOrRequest::Token(token) => Ok(TokenOrRequest::Token(token)),
140 TokenOrRequest::Request { request, .. } => Ok(TokenOrRequest::Request {
141 request,
142 reason,
143 scope_hash,
144 }),
145 }
146 }
147
148 fn parse_token_response<S>(
149 &self,
150 hash: u64,
151 response: http::Response<S>,
152 ) -> Result<Token, Error>
153 where
154 S: AsRef<[u8]>,
155 {
156 let token = self.inner.parse_token_response(hash, response)?;
157
158 self.access_tokens.insert(token.clone(), hash)?;
159 Ok(token)
160 }
161}
162
163impl<P> IdTokenProvider for CachedTokenProvider<P>
164where
165 P: IdTokenProvider,
166{
167 fn get_id_token(&self, audience: &str) -> Result<IdTokenOrRequest, Error> {
168 let hash = hash_str(audience);
169
170 let reason = match self.id_tokens.get(hash)? {
171 TokenOrRequestReason::Token(token) => return Ok(IdTokenOrRequest::IdToken(token)),
172 TokenOrRequestReason::RequestReason(reason) => reason,
173 };
174
175 match self.inner.get_id_token(audience)? {
176 IdTokenOrRequest::IdToken(token) => Ok(IdTokenOrRequest::IdToken(token)),
177 IdTokenOrRequest::AccessTokenRequest { request, .. } => {
178 Ok(IdTokenOrRequest::AccessTokenRequest {
179 request,
180 reason,
181 audience_hash: hash,
182 })
183 }
184 IdTokenOrRequest::IdTokenRequest { request, .. } => {
185 Ok(IdTokenOrRequest::IdTokenRequest {
186 request,
187 reason,
188 audience_hash: hash,
189 })
190 }
191 }
192 }
193
194 fn get_id_token_with_access_token<S>(
195 &self,
196 audience: &str,
197 response: crate::id_token::AccessTokenResponse<S>,
198 ) -> Result<crate::id_token::IdTokenRequest, Error>
199 where
200 S: AsRef<[u8]>,
201 {
202 self.inner
203 .get_id_token_with_access_token(audience, response)
204 }
205
206 fn parse_id_token_response<S>(
207 &self,
208 hash: u64,
209 response: http::Response<S>,
210 ) -> Result<IdToken, Error>
211 where
212 S: AsRef<[u8]>,
213 {
214 let token = self.inner.parse_id_token_response(hash, response)?;
215
216 self.id_tokens.insert(token.clone(), hash)?;
217 Ok(token)
218 }
219}
220
221fn hash_str(str: &str) -> Hash {
222 let hash = {
223 let mut hasher = twox_hash::XxHash::default();
224 hasher.write(str.as_bytes());
225 hasher.finish()
226 };
227
228 hash
229}
230
231fn hash_scopes<'a, I, S>(scopes: &I) -> Hash
232where
233 S: AsRef<str> + 'a,
234 I: IntoIterator<Item = &'a S> + Clone,
235{
236 let scopes_str = scopes
237 .clone()
238 .into_iter()
239 .map(|s| s.as_ref())
240 .collect::<Vec<_>>()
241 .join("|");
242
243 hash_str(&scopes_str)
244}
245
#[cfg(test)]
mod test {
    use std::time::{Duration, SystemTime};

    use super::*;

    /// `hash_scopes` must produce the same value as feeding the `|`-joined
    /// scopes to the hasher, for both borrowed and owned scope strings.
    #[test]
    fn test_hash_scopes() {
        use std::hash::Hasher;

        let mut reference = twox_hash::XxHash::default();
        reference.write(b"scope1|");
        reference.write(b"scope2|");
        reference.write(b"scope3");
        let expected = reference.finish();

        let borrowed = ["scope1", "scope2", "scope3"];
        assert_eq!(expected, hash_scopes(&borrowed.iter()));

        let owned = [
            "scope1".to_owned(),
            "scope2".to_owned(),
            "scope3".to_owned(),
        ];
        assert_eq!(expected, hash_scopes(&owned.iter()));
    }

    /// Walks a single cache slot through miss -> expired -> valid.
    #[test]
    fn test_cache() {
        let key = hash_scopes(&["scope1", "scope2"].iter());
        let valid = mock_token(100);
        let stale = mock_token(-100);
        let cache = TokenCache::new();

        // Nothing stored yet.
        assert!(matches!(
            cache.get(key).unwrap(),
            TokenOrRequestReason::RequestReason(RequestReason::ParametersChanged)
        ));

        // A stored but expired token is reported as expired.
        cache.insert(stale, key).unwrap();
        assert!(matches!(
            cache.get(key).unwrap(),
            TokenOrRequestReason::RequestReason(RequestReason::Expired)
        ));

        // Replacing it with a valid token turns the lookup into a hit.
        cache.insert(valid, key).unwrap();
        assert!(matches!(
            cache.get(key).unwrap(),
            TokenOrRequestReason::Token(..)
        ));
    }

    /// A cached token must be served without the wrapped provider being
    /// called; `PanicProvider` panics if it is.
    #[test]
    fn test_cache_wrapper() {
        let cached_provider = CachedTokenProvider::wrap(PanicProvider);
        let key = hash_scopes(&["scope1", "scope2"].iter());

        cached_provider
            .access_tokens
            .insert(mock_token(100), key)
            .unwrap();

        let tor = cached_provider.get_token(&["scope1", "scope2"]).unwrap();
        assert!(matches!(tor, TokenOrRequest::Token(..)));
    }

    /// Builds a token whose expiry timestamp lies `expires_in` seconds from
    /// now; zero or negative values place it at or before the current time.
    fn mock_token(expires_in: i64) -> Token {
        let now = SystemTime::now();
        let offset = Duration::from_secs(expires_in.unsigned_abs());
        let expires_in_timestamp = if expires_in > 0 {
            now + offset
        } else {
            now - offset
        };

        Token {
            access_token: "access-token".to_string(),
            refresh_token: "refresh-token".to_string(),
            token_type: "token-type".to_string(),
            expires_in: Some(expires_in),
            expires_in_timestamp: Some(expires_in_timestamp),
        }
    }

    /// Provider whose every method panics; used to prove that a code path
    /// never reaches the wrapped provider.
    struct PanicProvider;

    impl TokenProvider for PanicProvider {
        fn get_token_with_subject<'a, S, I, T>(
            &self,
            _subject: Option<T>,
            _scopes: I,
        ) -> Result<TokenOrRequest, Error>
        where
            S: AsRef<str> + 'a,
            I: IntoIterator<Item = &'a S> + Clone,
            T: Into<String>,
        {
            panic!("should not have been reached")
        }

        fn parse_token_response<S>(
            &self,
            _hash: u64,
            _response: http::Response<S>,
        ) -> Result<Token, Error>
        where
            S: AsRef<[u8]>,
        {
            panic!("should not have been reached")
        }
    }

    impl IdTokenProvider for PanicProvider {
        fn get_id_token(&self, _audience: &str) -> Result<IdTokenOrRequest, Error> {
            panic!("should not have been reached")
        }

        fn get_id_token_with_access_token<S>(
            &self,
            _audience: &str,
            _response: crate::id_token::AccessTokenResponse<S>,
        ) -> Result<crate::id_token::IdTokenRequest, Error>
        where
            S: AsRef<[u8]>,
        {
            panic!("should not have been reached")
        }

        fn parse_id_token_response<S>(
            &self,
            _hash: u64,
            _response: http::Response<S>,
        ) -> Result<IdToken, Error>
        where
            S: AsRef<[u8]>,
        {
            panic!("should not have been reached")
        }
    }
}