yup_oauth2/storage.rs

// partially (c) 2016 Google Inc. (Lewin Bormann, lewinb@google.com)
//
// See project root for licensing information.
//
pub use crate::types::TokenInfo;

use std::collections::HashMap;
use std::io;
use std::path::{Path, PathBuf};
use thiserror::Error;
use tokio::sync::Mutex;

use async_trait::async_trait;

use serde::{Deserialize, Serialize};

// The storage layer allows retrieving tokens for sets of scopes that have
// previously been granted a token. One wrinkle is that a token granted for a
// set of scopes X is also valid for any subset of X's scopes. So when
// retrieving a token for a set of scopes provided by the caller, it's
// beneficial to compare that set against all previously stored tokens to see
// if it is a subset of any existing set. For example, a token stored for
// {"foo", "bar"} can also be returned when only {"foo"} is requested.
// To do this efficiently we store a bloom filter along with each token,
// representing the set of scopes the token is associated with. The bloom
// filter allows entries that are definitively not a superset of the requested
// scopes to be skipped cheaply.
// The current implementation uses a 64-bit bloom filter with 4 hash functions.

/// ScopeHash is a hash value derived from a list of scopes. The hash value
/// represents a fingerprint of the set of scopes *independent* of the ordering.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
struct ScopeHash(u64);

/// ScopeFilter represents a filter for a set of scopes. It can definitively
/// prove that a given list of scopes is not a subset of another.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
struct ScopeFilter(u64);

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum FilterResponse {
    Maybe,
    No,
}

impl ScopeFilter {
    /// Determine if this ScopeFilter could be a subset of the provided filter.
    fn is_subset_of(self, filter: ScopeFilter) -> FilterResponse {
        if self.0 & filter.0 == self.0 {
            FilterResponse::Maybe
        } else {
            FilterResponse::No
        }
    }
}

/// A set of scopes.
#[derive(Debug)]
pub(crate) struct ScopeSet<'a, T> {
    hash: ScopeHash,
    filter: ScopeFilter,
    scopes: &'a [T],
}

// Implement Clone manually. The auto derive would add a `T: Clone` bound, but
// we want ScopeSet to be Clone regardless of whether T is Clone.
impl<T> Clone for ScopeSet<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<T> Copy for ScopeSet<'_, T> {}

impl<'a, T> ScopeSet<'a, T>
where
    T: AsRef<str>,
{
    /// Convert an array or slice of scopes into a ScopeSet. Array references
    /// (e.g. `&["scope_a"]`) are automatically coerced to slices when passed.
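    ///
    /// A minimal usage sketch (the scope strings here are placeholders, not
    /// real API scopes):
    ///
    /// ```ignore
    /// let scopes = ScopeSet::from(&["scope_a", "scope_b"]);
    /// ```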
    // implement an inherent from method even though From is implemented. This
    // is because passing an array ref like &[&str; 1] (&["foo"]) will be auto
    // deref'd to a slice on function boundaries, but it will not implement the
    // From trait. This inherent method just serves to auto deref from array
    // refs to slices and proxy to the From impl.
    pub fn from(scopes: &'a [T]) -> Self {
        let (hash, filter) = scopes.iter().fold(
            (ScopeHash(0), ScopeFilter(0)),
            |(mut scope_hash, mut scope_filter), scope| {
                let h = seahash::hash(scope.as_ref().as_bytes());

                // Use the first 4 6-bit chunks of the seahash as the 4 hash values
                // in the bloom filter.
                for i in 0..4 {
                    // h is a hash-derived value in the range 0..64
                    let h = (h >> (6 * i)) & 0b11_1111;
                    scope_filter.0 |= 1 << h;
                }

                // XOR the hashes together to get an order-independent fingerprint.
                scope_hash.0 ^= h;
                (scope_hash, scope_filter)
            },
        );
        ScopeSet {
            hash,
            filter,
            scopes,
        }
    }
}

/// Errors that occur while caching tokens in storage.
#[derive(Debug, Error)]
pub enum TokenStorageError {
    /// Error while performing an I/O action.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// Other errors.
    #[error("{0}")]
    Other(std::borrow::Cow<'static, str>),
}

/// Implement this trait to provide your own token storage solution. You need a way to
/// store and retrieve tokens, each keyed by a set of scopes.
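///
/// A minimal sketch of a custom implementation (illustrative only, not part of
/// this crate), assuming an in-memory map keyed by the joined scope list is
/// sufficient and that `TokenInfo` and `TokenStorageError` are in scope:
///
/// ```rust,ignore
/// use std::collections::HashMap;
/// use tokio::sync::Mutex;
///
/// struct ExampleMemoryStorage {
///     tokens: Mutex<HashMap<String, TokenInfo>>,
/// }
///
/// #[async_trait::async_trait]
/// impl TokenStorage for ExampleMemoryStorage {
///     async fn set(&self, scopes: &[&str], token: TokenInfo) -> Result<(), TokenStorageError> {
///         // The scopes are passed in sorted and deduplicated, so joining them
///         // yields a stable key for exact-match lookups.
///         self.tokens.lock().await.insert(scopes.join(" "), token);
///         Ok(())
///     }
///
///     async fn get(&self, scopes: &[&str]) -> Option<TokenInfo> {
///         self.tokens.lock().await.get(&scopes.join(" ")).cloned()
///     }
/// }
/// ```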
#[async_trait]
pub trait TokenStorage: Send + Sync {
    /// Store a token for the given set of scopes so that it can be retrieved later by `get()`.
    /// `TokenInfo` can be serialized with serde.
    async fn set(&self, scopes: &[&str], token: TokenInfo) -> Result<(), TokenStorageError>;

    /// Retrieve a token stored by `set()` for the given set of scopes.
    async fn get(&self, scopes: &[&str]) -> Option<TokenInfo>;
}

pub(crate) enum Storage {
    Memory { tokens: Mutex<JSONTokens> },
    Disk(DiskStorage),
    Custom(Box<dyn TokenStorage>),
}

impl Storage {
    pub(crate) async fn set<T>(
        &self,
        scopes: ScopeSet<'_, T>,
        token: TokenInfo,
    ) -> Result<(), TokenStorageError>
    where
        T: AsRef<str>,
    {
        match self {
            Storage::Memory { tokens } => Ok(tokens.lock().await.set(scopes, token)?),
            Storage::Disk(disk_storage) => Ok(disk_storage.set(scopes, token).await?),
            Storage::Custom(custom_storage) => {
                // Custom storage implementations receive the scopes as a
                // sorted, deduplicated list of strings.
                let mut str_scopes = scopes
                    .scopes
                    .iter()
                    .map(|scope| scope.as_ref())
                    .collect::<Vec<_>>();
                str_scopes.sort_unstable();
                str_scopes.dedup();

                custom_storage.set(&str_scopes[..], token).await
            }
        }
    }

    pub(crate) async fn get<T>(&self, scopes: ScopeSet<'_, T>) -> Option<TokenInfo>
    where
        T: AsRef<str>,
    {
        match self {
            Storage::Memory { tokens } => tokens.lock().await.get(scopes),
            Storage::Disk(disk_storage) => disk_storage.get(scopes).await,
            Storage::Custom(custom_storage) => {
                // Normalize the scope list the same way as in set() so lookups
                // use the same representation.
                let mut str_scopes = scopes
                    .scopes
                    .iter()
                    .map(|scope| scope.as_ref())
                    .collect::<Vec<_>>();
                str_scopes.sort_unstable();
                str_scopes.dedup();

                custom_storage.get(&str_scopes[..]).await
            }
        }
    }
}

/// A single stored token.
#[derive(Debug, Clone)]
struct JSONToken {
    scopes: Vec<String>,
    token: TokenInfo,
    hash: ScopeHash,
    filter: ScopeFilter,
}

impl<'de> Deserialize<'de> for JSONToken {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct RawJSONToken {
            scopes: Vec<String>,
            token: TokenInfo,
        }
        let RawJSONToken { scopes, token } = RawJSONToken::deserialize(deserializer)?;
        // Recompute the scope hash and bloom filter from the stored scopes so
        // that lookups work after loading tokens from disk.
        let ScopeSet { hash, filter, .. } = ScopeSet::from(&scopes);
        Ok(JSONToken {
            scopes,
            token,
            hash,
            filter,
        })
    }
}

impl Serialize for JSONToken {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        #[derive(Serialize)]
        struct RawJSONToken<'a> {
            scopes: &'a [String],
            token: &'a TokenInfo,
        }
        RawJSONToken {
            scopes: &self.scopes,
            token: &self.token,
        }
        .serialize(serializer)
    }
}

/// List of tokens stored as a JSON sequence.
#[derive(Debug, Clone)]
pub(crate) struct JSONTokens {
    token_map: HashMap<ScopeHash, JSONToken>,
}

impl Serialize for JSONTokens {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.collect_seq(self.token_map.values())
    }
}

impl<'de> Deserialize<'de> for JSONTokens {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct V;
        impl<'de> serde::de::Visitor<'de> for V {
            type Value = JSONTokens;

            // Format a message stating what data this Visitor expects to receive.
            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a sequence of JSONTokens")
            }

            fn visit_seq<M>(self, mut access: M) -> Result<Self::Value, M::Error>
            where
                M: serde::de::SeqAccess<'de>,
            {
                let mut token_map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
                while let Some(json_token) = access.next_element::<JSONToken>()? {
                    token_map.insert(json_token.hash, json_token);
                }
                Ok(JSONTokens { token_map })
            }
        }

        // Instantiate our Visitor and ask the Deserializer to drive
        // it over the input data.
        deserializer.deserialize_seq(V)
    }
}

impl JSONTokens {
    pub(crate) fn new() -> Self {
        JSONTokens {
            token_map: HashMap::new(),
        }
    }

    async fn load_from_file(filename: &Path) -> Result<Self, io::Error> {
        let contents = tokio::fs::read(filename).await?;
        serde_json::from_slice(&contents).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
    }

    fn get<T>(
        &self,
        ScopeSet {
            hash,
            filter,
            scopes,
        }: ScopeSet<T>,
    ) -> Option<TokenInfo>
    where
        T: AsRef<str>,
    {
        // Fast path: an exact match on the fingerprint of the requested scopes.
        if let Some(json_token) = self.token_map.get(&hash) {
            return Some(json_token.token.clone());
        }

        let requested_scopes_are_subset_of = |other_scopes: &[String]| {
            scopes
                .iter()
                .all(|s| other_scopes.iter().any(|t| t.as_str() == s.as_ref()))
        };
        // No exact match for the scopes provided. Search for any tokens that
        // exist for a superset of the scopes requested, using the bloom filter
        // to skip entries that definitively can't match.
        self.token_map
            .values()
            .filter(|json_token| filter.is_subset_of(json_token.filter) == FilterResponse::Maybe)
            .find(|v: &&JSONToken| requested_scopes_are_subset_of(&v.scopes))
            .map(|t: &JSONToken| t.token.clone())
    }

    fn set<T>(
        &mut self,
        ScopeSet {
            hash,
            filter,
            scopes,
        }: ScopeSet<T>,
        token: TokenInfo,
    ) -> Result<(), io::Error>
    where
        T: AsRef<str>,
    {
        use std::collections::hash_map::Entry;
        match self.token_map.entry(hash) {
            Entry::Occupied(mut entry) => {
                entry.get_mut().token = token;
            }
            Entry::Vacant(entry) => {
                let json_token = JSONToken {
                    scopes: scopes.iter().map(|x| x.as_ref().to_owned()).collect(),
                    token,
                    hash,
                    filter,
                };
                entry.insert(json_token);
            }
        }
        Ok(())
    }
}

pub(crate) struct DiskStorage {
    tokens: Mutex<JSONTokens>,
    filename: PathBuf,
}

impl DiskStorage {
    pub(crate) async fn new(filename: impl Into<PathBuf>) -> Result<Self, io::Error> {
        let filename = filename.into();
        // Start with an empty token cache if the file doesn't exist yet; any
        // other I/O error is propagated to the caller.
        let tokens = match JSONTokens::load_from_file(&filename).await {
            Ok(tokens) => tokens,
            Err(e) if e.kind() == io::ErrorKind::NotFound => JSONTokens::new(),
            Err(e) => return Err(e),
        };

        Ok(DiskStorage {
            tokens: Mutex::new(tokens),
            filename,
        })
    }

    pub(crate) async fn set<T>(
        &self,
        scopes: ScopeSet<'_, T>,
        token: TokenInfo,
    ) -> Result<(), io::Error>
    where
        T: AsRef<str>,
    {
        use tokio::io::AsyncWriteExt;
        // Serialize the updated token map while holding the lock, then write
        // the JSON out to disk.
        let json = {
            use std::ops::Deref;
            let mut lock = self.tokens.lock().await;
            lock.set(scopes, token)?;
            serde_json::to_string(lock.deref())
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
        };
        let mut f = open_writeable_file(&self.filename).await?;
        f.write_all(json.as_bytes()).await?;
        Ok(())
    }

    pub(crate) async fn get<T>(&self, scopes: ScopeSet<'_, T>) -> Option<TokenInfo>
    where
        T: AsRef<str>,
    {
        self.tokens.lock().await.get(scopes)
    }
}

#[cfg(unix)]
async fn open_writeable_file(
    filename: impl AsRef<Path>,
) -> Result<tokio::fs::File, tokio::io::Error> {
    // Ensure that if the file is created, it's only readable and writable by
    // the current user (mode 0o600).
    use std::os::unix::fs::OpenOptionsExt;
    let opts: tokio::fs::OpenOptions = {
        let mut opts = std::fs::OpenOptions::new();
        opts.write(true).create(true).truncate(true).mode(0o600);
        opts.into()
    };
    opts.open(filename).await
}

#[cfg(not(unix))]
async fn open_writeable_file(
    filename: impl AsRef<Path>,
) -> Result<tokio::fs::File, tokio::io::Error> {
    // It's not clear how to restrict a newly created file to the current user
    // on Windows and other non-unix platforms, so just create the file with
    // default permissions.
    tokio::fs::File::create(filename).await
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn test_scope_filter() {
        let foo = ScopeSet::from(&["foo"]).filter;
        let bar = ScopeSet::from(&["bar"]).filter;
        let foobar = ScopeSet::from(&["foo", "bar"]).filter;

        // foo and bar are both subsets of foobar. This condition should hold no
        // matter what changes are made to the bloom filter implementation.
        assert!(foo.is_subset_of(foobar) == FilterResponse::Maybe);
        assert!(bar.is_subset_of(foobar) == FilterResponse::Maybe);

        // These conditions hold under the current bloom filter implementation
        // because "foo" and "bar" don't collide, but if the bloom filter
        // implementation changes it could be valid for them to return Maybe.
        assert!(foo.is_subset_of(bar) == FilterResponse::No);
        assert!(bar.is_subset_of(foo) == FilterResponse::No);
        assert!(foobar.is_subset_of(foo) == FilterResponse::No);
        assert!(foobar.is_subset_of(bar) == FilterResponse::No);
    }
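
    // A small sketch exercising the behavior described in the module comment:
    // a token stored for a set of scopes is also returned for any subset of
    // those scopes, and the scope hash is independent of ordering.
    #[test]
    fn test_subset_lookup() {
        // The fingerprint must not depend on the order of the scopes.
        assert_eq!(
            ScopeSet::from(&["foo", "bar"]).hash,
            ScopeSet::from(&["bar", "foo"]).hash
        );

        let token = TokenInfo {
            access_token: Some("my_access_token".to_owned()),
            refresh_token: None,
            expires_at: None,
            id_token: None,
        };
        let mut tokens = JSONTokens::new();
        tokens.set(ScopeSet::from(&["foo", "bar"]), token).unwrap();

        // Exact matches and subsets of the stored scope set are served.
        assert!(tokens.get(ScopeSet::from(&["foo", "bar"])).is_some());
        assert!(tokens.get(ScopeSet::from(&["foo"])).is_some());
        assert!(tokens.get(ScopeSet::from(&["bar"])).is_some());
        // A scope that was never granted yields no token.
        assert!(tokens.get(ScopeSet::from(&["baz"])).is_none());
    }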

    #[tokio::test]
    async fn test_disk_storage() {
        let new_token = |access_token: &str| TokenInfo {
            access_token: Some(access_token.to_owned()),
            refresh_token: None,
            expires_at: None,
            id_token: None,
        };
        let scope_set = ScopeSet::from(&["myscope"]);

        let tempdir = tempfile::Builder::new()
            .tempdir()
            .expect("Tempdir to be created");

        let filename = tempdir.path().join("tokenstorage.json");

        {
            let storage = DiskStorage::new(&filename).await.unwrap();
            assert!(storage.get(scope_set).await.is_none());
            storage
                .set(scope_set, new_token("my_access_token"))
                .await
                .unwrap();
            assert_eq!(
                storage.get(scope_set).await,
                Some(new_token("my_access_token"))
            );
        }

        // Poll until the token file appears on disk; the timeout below bounds
        // how long this can spin.
        async fn find_file(path: &Path) {
            loop {
                if tokio::fs::metadata(path).await.is_ok() {
                    break;
                }
            }
        }

        tokio::time::timeout(Duration::from_secs(1), find_file(&filename))
            .await
            .unwrap_or_else(|_| panic!("File not created at {}", filename.to_string_lossy()));

        {
            // Create a new DiskStorage instance and verify the tokens were read from disk correctly.
            let storage = DiskStorage::new(&filename).await.unwrap();
            assert_eq!(
                storage.get(scope_set).await,
                Some(new_token("my_access_token"))
            );
        }
    }
503}