1pub use crate::types::TokenInfo;
6
7use std::collections::HashMap;
8use std::io;
9use std::path::{Path, PathBuf};
10use thiserror::Error;
11use tokio::sync::Mutex;
12
13use async_trait::async_trait;
14
15use serde::{Deserialize, Serialize};
16
/// Fingerprint of a scope list, used as the exact-match key in the token map.
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)]
struct ScopeHash(u64);
32
/// 64-bit bloom filter over a scope list; each scope sets up to four bits.
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)]
struct ScopeFilter(u64);

/// Result of a bloom-filter subset test.
///
/// A bloom filter can report false positives but never false negatives, so
/// only `No` is conclusive.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
enum FilterResponse {
    Maybe,
    No,
}

impl ScopeFilter {
    /// Checks whether every bit set in `self` is also set in `filter`.
    ///
    /// Returns `No` as soon as `self` has a bit that `filter` lacks (the
    /// scopes behind `self` cannot all be among those behind `filter`);
    /// otherwise returns `Maybe`.
    fn is_subset_of(self, filter: ScopeFilter) -> FilterResponse {
        let bits_missing_from_filter = self.0 & !filter.0;
        match bits_missing_from_filter {
            0 => FilterResponse::Maybe,
            _ => FilterResponse::No,
        }
    }
}
54
/// A borrowed list of scopes together with precomputed lookup keys.
///
/// `hash` and `filter` are computed by `ScopeSet::from` from the contents of
/// `scopes`.
#[derive(Debug)]
pub(crate) struct ScopeSet<'a, T> {
    /// Fingerprint of the scope list; used as the exact-match key.
    hash: ScopeHash,
    /// Bloom filter used to skip stored tokens that cannot cover these scopes.
    filter: ScopeFilter,
    /// The scopes as supplied by the caller (order and duplicates preserved).
    scopes: &'a [T],
}
62
63impl<T> Clone for ScopeSet<'_, T> {
66 fn clone(&self) -> Self {
67 *self
68 }
69}
70impl<T> Copy for ScopeSet<'_, T> {}
71
72impl<'a, T> ScopeSet<'a, T>
73where
74 T: AsRef<str>,
75{
76 pub fn from(scopes: &'a [T]) -> Self {
84 let (hash, filter) = scopes.iter().fold(
85 (ScopeHash(0), ScopeFilter(0)),
86 |(mut scope_hash, mut scope_filter), scope| {
87 let h = seahash::hash(scope.as_ref().as_bytes());
88
89 for i in 0..4 {
92 let h = (h >> (6 * i)) & 0b11_1111;
94 scope_filter.0 |= 1 << h;
95 }
96
97 scope_hash.0 ^= h;
99 (scope_hash, scope_filter)
100 },
101 );
102 ScopeSet {
103 hash,
104 filter,
105 scopes,
106 }
107 }
108}
109
/// Error returned when storing a token fails.
#[derive(Debug, Error)]
pub enum TokenStorageError {
    /// An underlying I/O operation failed (for example, writing the token file).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// Any other failure, described by a message.
    #[error("{0}")]
    Other(std::borrow::Cow<'static, str>),
}
120
/// A pluggable backend for caching tokens by scope list.
///
/// Implementations must be `Send + Sync` so they can be shared across tasks.
#[async_trait]
pub trait TokenStorage: Send + Sync {
    /// Stores `token` as valid for `scopes`.
    ///
    /// When called through this crate's internal storage dispatcher, `scopes`
    /// has been sorted and deduplicated.
    ///
    /// # Errors
    /// Returns a [`TokenStorageError`] if the token could not be stored.
    async fn set(&self, scopes: &[&str], token: TokenInfo) -> Result<(), TokenStorageError>;

    /// Returns a previously stored token for `scopes`, or `None` if there is none.
    ///
    /// When called through this crate's internal storage dispatcher, `scopes`
    /// has been sorted and deduplicated. How scopes are matched is up to the
    /// implementation.
    async fn get(&self, scopes: &[&str]) -> Option<TokenInfo>;
}
132
/// The token store used internally; dispatches to one of three backends.
pub(crate) enum Storage {
    /// Tokens held only in process memory.
    Memory { tokens: Mutex<JSONTokens> },
    /// Tokens cached in memory and rewritten to a JSON file on every update.
    Disk(DiskStorage),
    /// An arbitrary `TokenStorage` implementation.
    Custom(Box<dyn TokenStorage>),
}
138
139impl Storage {
140 pub(crate) async fn set<T>(
141 &self,
142 scopes: ScopeSet<'_, T>,
143 token: TokenInfo,
144 ) -> Result<(), TokenStorageError>
145 where
146 T: AsRef<str>,
147 {
148 match self {
149 Storage::Memory { tokens } => Ok(tokens.lock().await.set(scopes, token)?),
150 Storage::Disk(disk_storage) => Ok(disk_storage.set(scopes, token).await?),
151 Storage::Custom(custom_storage) => {
152 let mut str_scopes = scopes
153 .scopes
154 .iter()
155 .map(|scope| scope.as_ref())
156 .collect::<Vec<_>>();
157 str_scopes.sort_unstable();
158 str_scopes.dedup();
159
160 custom_storage.set(&str_scopes[..], token).await
161 }
162 }
163 }
164
165 pub(crate) async fn get<T>(&self, scopes: ScopeSet<'_, T>) -> Option<TokenInfo>
166 where
167 T: AsRef<str>,
168 {
169 match self {
170 Storage::Memory { tokens } => tokens.lock().await.get(scopes),
171 Storage::Disk(disk_storage) => disk_storage.get(scopes).await,
172 Storage::Custom(custom_storage) => {
173 let mut str_scopes = scopes
174 .scopes
175 .iter()
176 .map(|scope| scope.as_ref())
177 .collect::<Vec<_>>();
178 str_scopes.sort_unstable();
179 str_scopes.dedup();
180
181 custom_storage.get(&str_scopes[..]).await
182 }
183 }
184 }
185}
186
/// A cached token together with the scopes it was stored for.
///
/// Only `scopes` and `token` are serialized. `hash` and `filter` are
/// recomputed from `scopes` when the token is deserialized.
#[derive(Debug, Clone)]
struct JSONToken {
    /// Scopes the token was stored for.
    scopes: Vec<String>,
    /// The cached token itself.
    token: TokenInfo,
    /// Fingerprint of `scopes`; also the key in `JSONTokens::token_map`.
    hash: ScopeHash,
    /// Bloom filter of `scopes`, used to prune superset searches.
    filter: ScopeFilter,
}
196
197impl<'de> Deserialize<'de> for JSONToken {
198 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
199 where
200 D: serde::Deserializer<'de>,
201 {
202 #[derive(Deserialize)]
203 struct RawJSONToken {
204 scopes: Vec<String>,
205 token: TokenInfo,
206 }
207 let RawJSONToken { scopes, token } = RawJSONToken::deserialize(deserializer)?;
208 let ScopeSet { hash, filter, .. } = ScopeSet::from(&scopes);
209 Ok(JSONToken {
210 scopes,
211 token,
212 hash,
213 filter,
214 })
215 }
216}
217
218impl Serialize for JSONToken {
219 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
220 where
221 S: serde::Serializer,
222 {
223 #[derive(Serialize)]
224 struct RawJSONToken<'a> {
225 scopes: &'a [String],
226 token: &'a TokenInfo,
227 }
228 RawJSONToken {
229 scopes: &self.scopes,
230 token: &self.token,
231 }
232 .serialize(serializer)
233 }
234}
235
/// All cached tokens, indexed by the fingerprint of their scope list.
///
/// Serialized as a plain sequence of `JSONToken`s; the map keys are rebuilt
/// from each entry's `hash` on deserialization.
#[derive(Debug, Clone)]
pub(crate) struct JSONTokens {
    token_map: HashMap<ScopeHash, JSONToken>,
}
241
242impl Serialize for JSONTokens {
243 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
244 where
245 S: serde::Serializer,
246 {
247 serializer.collect_seq(self.token_map.values())
248 }
249}
250
251impl<'de> Deserialize<'de> for JSONTokens {
252 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
253 where
254 D: serde::Deserializer<'de>,
255 {
256 struct V;
257 impl<'de> serde::de::Visitor<'de> for V {
258 type Value = JSONTokens;
259
260 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
262 formatter.write_str("a sequence of JSONToken's")
263 }
264
265 fn visit_seq<M>(self, mut access: M) -> Result<Self::Value, M::Error>
266 where
267 M: serde::de::SeqAccess<'de>,
268 {
269 let mut token_map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
270 while let Some(json_token) = access.next_element::<JSONToken>()? {
271 token_map.insert(json_token.hash, json_token);
272 }
273 Ok(JSONTokens { token_map })
274 }
275 }
276
277 deserializer.deserialize_seq(V)
280 }
281}
282
283impl JSONTokens {
284 pub(crate) fn new() -> Self {
285 JSONTokens {
286 token_map: HashMap::new(),
287 }
288 }
289
290 async fn load_from_file(filename: &Path) -> Result<Self, io::Error> {
291 let contents = tokio::fs::read(filename).await?;
292 serde_json::from_slice(&contents).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
293 }
294
295 fn get<T>(
296 &self,
297 ScopeSet {
298 hash,
299 filter,
300 scopes,
301 }: ScopeSet<T>,
302 ) -> Option<TokenInfo>
303 where
304 T: AsRef<str>,
305 {
306 if let Some(json_token) = self.token_map.get(&hash) {
307 return Some(json_token.token.clone());
308 }
309
310 let requested_scopes_are_subset_of = |other_scopes: &[String]| {
311 scopes
312 .iter()
313 .all(|s| other_scopes.iter().any(|t| t.as_str() == s.as_ref()))
314 };
315 self.token_map
318 .values()
319 .filter(|json_token| filter.is_subset_of(json_token.filter) == FilterResponse::Maybe)
320 .find(|v: &&JSONToken| requested_scopes_are_subset_of(&v.scopes))
321 .map(|t: &JSONToken| t.token.clone())
322 }
323
324 fn set<T>(
325 &mut self,
326 ScopeSet {
327 hash,
328 filter,
329 scopes,
330 }: ScopeSet<T>,
331 token: TokenInfo,
332 ) -> Result<(), io::Error>
333 where
334 T: AsRef<str>,
335 {
336 use std::collections::hash_map::Entry;
337 match self.token_map.entry(hash) {
338 Entry::Occupied(mut entry) => {
339 entry.get_mut().token = token;
340 }
341 Entry::Vacant(entry) => {
342 let json_token = JSONToken {
343 scopes: scopes.iter().map(|x| x.as_ref().to_owned()).collect(),
344 token,
345 hash,
346 filter,
347 };
348 entry.insert(json_token);
349 }
350 }
351 Ok(())
352 }
353}
354
/// Token storage backed by an in-memory cache that is rewritten in full to
/// `filename` (as JSON) on every `set`.
pub(crate) struct DiskStorage {
    /// In-memory copy of all cached tokens; the source of truth for `get`.
    tokens: Mutex<JSONTokens>,
    /// Path of the JSON token file.
    filename: PathBuf,
}
359
360impl DiskStorage {
361 pub(crate) async fn new(filename: impl Into<PathBuf>) -> Result<Self, io::Error> {
362 let filename = filename.into();
363 let tokens = match JSONTokens::load_from_file(&filename).await {
364 Ok(tokens) => tokens,
365 Err(e) if e.kind() == io::ErrorKind::NotFound => JSONTokens::new(),
366 Err(e) => return Err(e),
367 };
368
369 Ok(DiskStorage {
370 tokens: Mutex::new(tokens),
371 filename,
372 })
373 }
374
375 pub(crate) async fn set<T>(
376 &self,
377 scopes: ScopeSet<'_, T>,
378 token: TokenInfo,
379 ) -> Result<(), io::Error>
380 where
381 T: AsRef<str>,
382 {
383 use tokio::io::AsyncWriteExt;
384 let json = {
385 use std::ops::Deref;
386 let mut lock = self.tokens.lock().await;
387 lock.set(scopes, token)?;
388 serde_json::to_string(lock.deref())
389 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
390 };
391 let mut f = open_writeable_file(&self.filename).await?;
392 f.write_all(json.as_bytes()).await?;
393 Ok(())
394 }
395
396 pub(crate) async fn get<T>(&self, scopes: ScopeSet<'_, T>) -> Option<TokenInfo>
397 where
398 T: AsRef<str>,
399 {
400 self.tokens.lock().await.get(scopes)
401 }
402}
403
404#[cfg(unix)]
405async fn open_writeable_file(
406 filename: impl AsRef<Path>,
407) -> Result<tokio::fs::File, tokio::io::Error> {
408 use std::os::unix::fs::OpenOptionsExt;
411 let opts: tokio::fs::OpenOptions = {
412 let mut opts = std::fs::OpenOptions::new();
413 opts.write(true).create(true).truncate(true).mode(0o600);
414 opts.into()
415 };
416 opts.open(filename).await
417}
418
/// Opens `filename` for writing, creating it if needed and truncating it.
///
/// No permission bits are requested on non-Unix platforms.
#[cfg(not(unix))]
async fn open_writeable_file(
    filename: impl AsRef<Path>,
) -> Result<tokio::fs::File, tokio::io::Error> {
    let mut options = tokio::fs::OpenOptions::new();
    options.write(true).create(true).truncate(true);
    options.open(filename).await
}
427
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn test_scope_filter() {
        let foo = ScopeSet::from(&["foo"]).filter;
        let bar = ScopeSet::from(&["bar"]).filter;
        let foobar = ScopeSet::from(&["foo", "bar"]).filter;
        let empty = ScopeSet::<&str>::from(&[]).filter;

        // A subset's bits are always contained in the superset's bits, so
        // these must report `Maybe` (bloom filters have no false negatives).
        assert_eq!(foo.is_subset_of(foobar), FilterResponse::Maybe);
        assert_eq!(bar.is_subset_of(foobar), FilterResponse::Maybe);
        assert_eq!(foo.is_subset_of(foo), FilterResponse::Maybe);
        assert_eq!(empty.is_subset_of(foo), FilterResponse::Maybe);

        // Bloom filters can give false positives, so `No` is only guaranteed
        // here because "foo" and "bar" happen to set distinguishable bits.
        assert_eq!(foo.is_subset_of(bar), FilterResponse::No);
        assert_eq!(bar.is_subset_of(foo), FilterResponse::No);
        assert_eq!(foobar.is_subset_of(foo), FilterResponse::No);
        assert_eq!(foobar.is_subset_of(bar), FilterResponse::No);
    }

    #[tokio::test]
    async fn test_disk_storage() {
        let new_token = |access_token: &str| TokenInfo {
            access_token: Some(access_token.to_owned()),
            refresh_token: None,
            expires_at: None,
            id_token: None,
        };
        let scope_set = ScopeSet::from(&["myscope"]);

        let tempdir = tempfile::Builder::new()
            .tempdir()
            .expect("Tempdir to be created");

        let filename = tempdir.path().join("tokenstorage.json");

        {
            let storage = DiskStorage::new(&filename).await.unwrap();
            assert!(storage.get(scope_set).await.is_none());
            storage
                .set(scope_set, new_token("my_access_token"))
                .await
                .unwrap();
            assert_eq!(
                storage.get(scope_set).await,
                Some(new_token("my_access_token"))
            );
        }

        // Poll until the token file exists, yielding between checks instead
        // of hammering the filesystem.
        async fn find_file(path: &Path) {
            while tokio::fs::metadata(path).await.is_err() {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        }

        tokio::time::timeout(Duration::from_secs(1), find_file(&filename))
            .await
            .unwrap_or_else(|_| panic!("File not created at {}", filename.to_string_lossy()));

        {
            // A fresh storage must see the token persisted by the first one.
            let storage = DiskStorage::new(&filename).await.unwrap();
            assert_eq!(
                storage.get(scope_set).await,
                Some(new_token("my_access_token"))
            );
        }
    }
}